tglite.TBatch.split_data

TBatch.split_data(data: Tensor) → Tuple[Tensor, Tensor, Tensor | None]

Splits the data row-wise into two or three tensors (source, destination, and, if the batch has negative nodes, negative), each containing a number of rows equal to the batch size.

Parameters:

data (Tensor) – The source data to be split.

Raises:

TError – If the length of data is not three times the batch size when negative nodes are present, or not two times the batch size otherwise.

Returns:

A tuple (src, dst, neg), where neg is None if no negative nodes are specified.
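The following is a minimal sketch of the documented behavior, not tglite's actual implementation. It assumes the rows of data are laid out as all source rows first, then all destination rows, then (optionally) all negative rows, matching the order of the returned tuple. The standalone function split_data, its batch_size and has_neg parameters, and the local TError class are hypothetical stand-ins for state that TBatch holds internally. Plain slicing is used, so the sketch works on both Python lists and PyTorch tensors.

```python
from typing import Optional, Sequence, Tuple


class TError(Exception):
    """Hypothetical stand-in for tglite's TError exception."""


def split_data(
    data: Sequence, batch_size: int, has_neg: bool
) -> Tuple[Sequence, Sequence, Optional[Sequence]]:
    """Sketch of TBatch.split_data, assuming a [src | dst | neg] row layout."""
    # Expect 3 chunks when negative nodes are present, otherwise 2.
    expected = batch_size * (3 if has_neg else 2)
    if len(data) != expected:
        raise TError(f"expected {expected} rows, got {len(data)}")
    src = data[:batch_size]                   # first batch_size rows
    dst = data[batch_size:2 * batch_size]     # next batch_size rows
    neg = data[2 * batch_size:] if has_neg else None  # remaining rows, if any
    return src, dst, neg
```

For example, with a batch size of 2 and negative nodes present, split_data(list(range(6)), 2, True) returns ([0, 1], [2, 3], [4, 5]); passing five rows instead raises TError.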