Batching

Define a collate function that converts each batch into tensors with the expected data types and places them on the training device, then take a sample batch of data to verify the output.

import torch
from ray.train.torch import get_device

def collate_fn(batch):
    # Expected dtype for each column in the batch.
    dtypes = {"embedding": torch.float32, "label": torch.int64}
    tensor_batch = {}
    for key, dtype in dtypes.items():
        if key in batch:
            # Convert the column to a tensor with the proper dtype
            # and move it to the training device.
            tensor_batch[key] = torch.as_tensor(
                batch[key],
                dtype=dtype,
                device=get_device(),
            )
    return tensor_batch
# Sample batch
sample_batch = train_ds.take_batch(batch_size=3)
collate_fn(batch=sample_batch)
{'embedding': tensor([[ 0.4219,  0.3688, -0.1833,  ...,  0.6288,  0.2298, -0.3989],
         [ 0.0385,  0.3297,  0.2076,  ...,  0.3434, -0.5492,  0.0362],
         [ 0.1881,  0.1737, -0.3069,  ...,  0.3336,  0.1783, -0.0299]]),
 'label': tensor([11, 34,  7])}
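
The sketch below shows one way this collate function is typically wired into a training loop, assuming the dataset is registered under the name "train" and consumed with iter_torch_batches inside the per-worker training function; the function name train_loop_per_worker and the config keys are illustrative assumptions, not part of the section above.

import ray.train

def train_loop_per_worker(config):
    # Each worker receives its shard of the dataset registered as "train"
    # (assumed name for this sketch).
    train_shard = ray.train.get_dataset_shard("train")
    for epoch in range(config["num_epochs"]):
        # iter_torch_batches applies collate_fn to every batch, so the model
        # receives tensors with the right dtypes already on the worker's device.
        for batch in train_shard.iter_torch_batches(
            batch_size=config["batch_size"],
            collate_fn=collate_fn,
        ):
            features, labels = batch["embedding"], batch["label"]
            ...  # forward pass, loss, backward, optimizer step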