Batching
Define a collate_fn that converts each batch into tensors with the proper data types and places them on the training device, then test it on a sample batch of data.
import torch

from ray.train.torch import get_device


def collate_fn(batch):
    # Expected dtype for each feature column.
    dtypes = {"embedding": torch.float32, "label": torch.int64}
    tensor_batch = {}
    for key, dtype in dtypes.items():
        if key in batch:
            # Convert the NumPy column to a tensor and move it to the training device.
            tensor_batch[key] = torch.as_tensor(
                batch[key],
                dtype=dtype,
                device=get_device(),
            )
    return tensor_batch
# Sample batch
sample_batch = train_ds.take_batch(batch_size=3)
collate_fn(batch=sample_batch)
{'embedding': tensor([[ 0.4219, 0.3688, -0.1833, ..., 0.6288, 0.2298, -0.3989],
[ 0.0385, 0.3297, 0.2076, ..., 0.3434, -0.5492, 0.0362],
[ 0.1881, 0.1737, -0.3069, ..., 0.3336, 0.1783, -0.0299]]),
'label': tensor([11, 34, 7])}
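Once the output looks correct, the same collate_fn can be applied to every batch during training by passing it to the dataset's iteration API. The lines below are a minimal sketch, assuming train_ds is a Ray Dataset; the batch size of 256 is an arbitrary example value, and the loop body is left as a placeholder.
# Minimal sketch: iterate over typed, device-placed batches during training.
for batch in train_ds.iter_torch_batches(batch_size=256, collate_fn=collate_fn):
    embeddings, labels = batch["embedding"], batch["label"]
    # ... forward pass, loss computation, and optimizer step go here.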