05 · Define the Ray Train Loop (DDP per-worker)
This is the per-worker training function that Ray executes on each process/GPU. It keeps your PyTorch code intact while Ray handles process launch, device placement, and data sharding.
Key points:
- Inputs via `config`: we pass hyperparameters like `num_epochs` and a `global_batch_size`.
- Model & optimizer: `load_model_ray_train()` returns a model already wrapped by Ray Train (DDP + correct device). We use `Adam` and `CrossEntropyLoss` for MNIST.
- Batch sizing: we split the global batch across workers: `per_worker_batch = global_batch_size // world_size`.
- Data sharding: `build_data_loader_ray_train(...)` returns a DataLoader wrapped with a `DistributedSampler`; each worker sees a disjoint shard.
- Epoch control: `data_loader.sampler.set_epoch(epoch)` ensures proper shuffling across epochs in distributed mode.
- Training step: a standard PyTorch loop (forward → loss → zero_grad → backward → step).
- Metrics & checkpointing: `print_metrics_ray_train(...)` logs loss; `save_checkpoint_and_metrics_ray_train(...)` calls `ray.train.report(...)` (rank-0 saves the checkpoint). A minimal sketch of this reporting pattern follows this list.
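On the last point, the reporting pattern that a helper like `save_checkpoint_and_metrics_ray_train` typically follows looks roughly like this. The function name and file layout below are illustrative assumptions, not this walkthrough's actual implementation; the `ray.train.report(...)` call and rank check are the part the bullet describes.

```python
import os
import tempfile

import torch
import ray.train
from ray.train import Checkpoint


def report_checkpoint_sketch(model, metrics: dict) -> None:
    """Illustrative sketch: report metrics every epoch, attach a checkpoint on rank 0."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint = None
        if ray.train.get_context().get_world_rank() == 0:
            # Unwrap DDP if prepare_model() wrapped the model.
            state_dict = (
                model.module.state_dict()
                if hasattr(model, "module")
                else model.state_dict()
            )
            torch.save(state_dict, os.path.join(tmpdir, "model.pt"))
            checkpoint = Checkpoint.from_directory(tmpdir)
        # Every worker calls report(); only rank 0 attaches a checkpoint.
        ray.train.report(metrics, checkpoint=checkpoint)
```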
The per-worker training function is passed to TorchTrainer, which runs it concurrently on all workers.
Let's see what this data-parallel training loop looks like with Ray Train and PyTorch.
```python
# 05. Define the Ray Train per-worker training loop
import ray.train
from torch.nn import CrossEntropyLoss
from torch.optim import Adam


def train_loop_ray_train(config: dict):
    # config holds hyperparameters passed from TorchTrainer
    # (e.g. num_epochs, global_batch_size)

    # Define loss function for MNIST classification
    criterion = CrossEntropyLoss()

    # Build and prepare the model for distributed training.
    # load_model_ray_train() calls ray.train.torch.prepare_model()
    # → moves the model to the GPU and wraps it in DistributedDataParallel (DDP).
    model = load_model_ray_train()

    # Standard optimizer (learning rate fixed for demo)
    optimizer = Adam(model.parameters(), lr=1e-5)

    # Calculate the batch size for each worker
    global_batch_size = config["global_batch_size"]
    world_size = ray.train.get_context().get_world_size()  # total number of workers in the job
    batch_size = global_batch_size // world_size  # split the global batch evenly
    print(f"{world_size=}\n{batch_size=}")

    # The DataLoader is wrapped with prepare_data_loader()
    # → applies DistributedSampler (shards data across workers)
    # → ensures batches are automatically moved to the correct device
    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    # ----------------------- Training loop ----------------------- #
    for epoch in range(config["num_epochs"]):
        # Ensure each worker shuffles its shard differently every epoch
        data_loader.sampler.set_epoch(epoch)

        # Iterate over batches (sharded across workers)
        for images, labels in data_loader:
            outputs = model(images)            # forward pass
            loss = criterion(outputs, labels)  # compute loss
            optimizer.zero_grad()              # reset gradients
            loss.backward()                    # backward pass (grads averaged across workers via DDP)
            optimizer.step()                   # update model weights

        # After each epoch: report loss and log metrics
        metrics = print_metrics_ray_train(loss, epoch)

        # Save checkpoint (only the rank-0 worker persists the model)
        save_checkpoint_and_metrics_ray_train(model, metrics)
```
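For orientation, this per-worker loop is launched by handing it to `TorchTrainer`. The snippet below is only a sketch with placeholder values; the actual `config` and scaling setup for this walkthrough are defined in the sections that follow.

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Illustrative values only; the real config and worker count are defined later.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_ray_train,
    train_loop_config={"num_epochs": 2, "global_batch_size": 128},
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
result = trainer.fit()
```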
Main training loop
- `global_batch_size`: the total number of samples processed in a single training step across the entire training job. It's computed as: per-worker batch size × DDP workers × gradient accumulation steps (see the worked example after this list).
- Notice that `images` and `labels` are no longer manually moved to the device (e.g. `images.to("cuda")`); this is handled by `prepare_data_loader()`.
- The `config` passed here is defined below; it will be handed to Ray Train's `TorchTrainer`.
- `TrainContext` lets users get useful information about the training run, e.g. node rank, world size, world rank, and experiment name.
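To make the batch-size bookkeeping concrete, here is a small worked example. The numbers (four workers, a per-worker batch of 32, no gradient accumulation) are illustrative, not values from this walkthrough.

```python
# Worked example with illustrative numbers:
per_worker_batch = 32    # batch size each DDP worker processes
num_ddp_workers = 4      # workers launched by TorchTrainer
grad_accum_steps = 1     # no gradient accumulation in this demo
global_batch_size = per_worker_batch * num_ddp_workers * grad_accum_steps
print(global_batch_size)  # -> 128

# The training loop above performs the inverse split:
# per_worker_batch = global_batch_size // num_ddp_workers  -> 128 // 4 == 32
```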
`load_model_ray_train` and `build_data_loader_ray_train` are implemented below.
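As a rough preview of what those helpers do (their real implementations follow), here is a minimal sketch. The model architecture, dataset loading, and `_sketch` names are illustrative assumptions; the `prepare_model()` / `prepare_data_loader()` calls mirror what the code comments above describe.

```python
import torch.nn as nn
import ray.train.torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def load_model_ray_train_sketch():
    # Illustrative MNIST classifier; the real architecture is defined below.
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    # prepare_model() moves the model to the right device and wraps it in DDP.
    return ray.train.torch.prepare_model(model)


def build_data_loader_ray_train_sketch(batch_size: int):
    dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=transforms.ToTensor()
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # prepare_data_loader() adds a DistributedSampler and moves each batch
    # to the correct device automatically.
    return ray.train.torch.prepare_data_loader(loader)
```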