05 · Define the Ray Train Loop (DDP per-worker)

This is the per-worker training function that Ray executes on each process/GPU. It keeps your PyTorch code intact while Ray handles process launch, device placement, and data sharding.

Key points:

  • Inputs via config: we pass hyperparameters like num_epochs and a global_batch_size.

  • Model & optimizer: load_model_ray_train() returns a model already wrapped by Ray Train (DDP + correct device); see the sketch after this list. We use Adam and CrossEntropyLoss for MNIST.

  • Batch sizing: we split the global batch across workers:
    per_worker_batch = global_batch_size // world_size.

  • Data sharding: build_data_loader_ray_train(...) returns a DataLoader wrapped with a DistributedSampler; each worker sees a disjoint shard.

  • Epoch control: data_loader.sampler.set_epoch(epoch) ensures proper shuffling across epochs in distributed mode.

  • Training step: standard PyTorch loop—forward → loss → zero_grad → backward → step.

  • Metrics & checkpointing: print_metrics_ray_train(...) logs loss; save_checkpoint_and_metrics_ray_train(...) calls ray.train.report(...) (rank-0 saves the checkpoint).
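
The model and data-loader helpers named above are implemented later in this notebook. As a rough preview of the pattern they follow, here is a minimal sketch built around Ray's prepare_model() and prepare_data_loader() utilities; the network architecture and dataset transforms are illustrative placeholders, not the tutorial's exact ones.

# Illustrative sketch only; the notebook's real helpers appear later.
import torch.nn as nn
import ray.train.torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def load_model_ray_train() -> nn.Module:
    # Placeholder MNIST classifier; the actual architecture may differ.
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    # prepare_model() moves the model to the correct device and wraps it in DDP.
    return ray.train.torch.prepare_model(model)


def build_data_loader_ray_train(batch_size: int) -> DataLoader:
    # Placeholder dataset and transforms; the actual pipeline may differ.
    dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=transforms.ToTensor()
    )
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # prepare_data_loader() adds a DistributedSampler (one shard per worker)
    # and moves each batch to the worker's device automatically.
    return ray.train.torch.prepare_data_loader(data_loader)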

This function is passed to TorchTrainer, which runs it concurrently on all workers.
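
The TorchTrainer setup itself (including the ScalingConfig and the config dict) is defined later in the notebook. As a rough preview with placeholder values, the launch typically looks like this:

# Illustrative sketch only; the real TorchTrainer configuration comes later.
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_ray_train,  # the per-worker training function defined below
    train_loop_config={"num_epochs": 2, "global_batch_size": 128},  # placeholder values
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),      # placeholder values
)
result = trainer.fit()  # runs train_loop_ray_train concurrently on all workers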

Let’s see what this data-parallel training loop looks like with Ray Train and PyTorch.

# 05. Define the Ray Train per-worker training loop

import ray.train
from torch.nn import CrossEntropyLoss
from torch.optim import Adam


def train_loop_ray_train(config: dict):  # pass in hyperparameters in config
    # config holds hyperparameters passed from TorchTrainer (e.g. num_epochs, global_batch_size)

    # Define loss function for MNIST classification
    criterion = CrossEntropyLoss()

    # Build and prepare the model for distributed training.
    # load_model_ray_train() calls ray.train.torch.prepare_model()
    # → moves model to GPU and wraps it in DistributedDataParallel (DDP).
    model = load_model_ray_train()

    # Standard optimizer (learning rate fixed for demo)
    optimizer = Adam(model.parameters(), lr=1e-5)

    # Calculate the batch size for each worker
    global_batch_size = config["global_batch_size"]
    world_size = ray.train.get_context().get_world_size()  # total # of workers in the job
    batch_size = global_batch_size // world_size  # split global batch evenly
    print(f"{world_size=}\n{batch_size=}")

    # Wrap DataLoader with prepare_data_loader()
    # → applies DistributedSampler (shards data across workers)
    # → ensures batches are automatically moved to correct device
    data_loader = build_data_loader_ray_train(batch_size=batch_size)

    # ----------------------- Training loop ----------------------- #
    for epoch in range(config["num_epochs"]):

        # Ensure each worker shuffles its shard differently every epoch
        data_loader.sampler.set_epoch(epoch)

        # Iterate over batches (sharded across workers)
        for images, labels in data_loader:
            outputs = model(images)            # forward pass
            loss = criterion(outputs, labels)  # compute loss
            optimizer.zero_grad()              # reset gradients

            loss.backward()   # backward pass (grads averaged across workers via DDP)
            optimizer.step()  # update model weights

        # After each epoch: report loss and log metrics
        metrics = print_metrics_ray_train(loss, epoch)

        # Save checkpoint (only rank-0 worker persists the model)
        save_checkpoint_and_metrics_ray_train(model, metrics)
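
Both end-of-epoch helpers are implemented later in the notebook. As a rough sketch of the standard Ray Train reporting pattern they follow (metric keys and details here are illustrative):

# Illustrative sketch only; the notebook's real helpers appear later.
import os
import tempfile

import torch
import ray.train
from ray.train import Checkpoint


def print_metrics_ray_train(loss, epoch: int) -> dict:
    # Log the last batch loss of the epoch and return it as a metrics dict.
    metrics = {"loss": loss.item(), "epoch": epoch}
    rank = ray.train.get_context().get_world_rank()
    print(f"[rank {rank}] {metrics}")
    return metrics


def save_checkpoint_and_metrics_ray_train(model, metrics: dict) -> None:
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint = None
        # Only rank 0 persists the weights, so a single checkpoint is saved per epoch.
        if ray.train.get_context().get_world_rank() == 0:
            torch.save(model.module.state_dict(), os.path.join(tmp_dir, "model.pt"))
            checkpoint = Checkpoint.from_directory(tmp_dir)
        # report() must be called by every worker; only rank 0 attaches a checkpoint.
        ray.train.report(metrics, checkpoint=checkpoint)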

Main training loop

  • global_batch_size: the total number of samples processed in a single training step across the entire training job.
    • It is computed as: per-worker batch size × number of DDP workers × gradient accumulation steps.
  • Notice that images and labels are no longer manually moved to the device (e.g. images.to("cuda")); prepare_data_loader() handles device placement automatically.
  • The config dict consumed here is defined below and passed to Ray Train's TorchTrainer.
  • The TrainContext exposes useful information about the run, e.g. node rank, world size, world rank, and experiment name (see the sketch after this list).
  • load_model_ray_train and build_data_loader_ray_train are implemented below.
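
As an illustration of the TrainContext accessors and the batch-size arithmetic (all values hypothetical):

# Run inside a training worker; values shown are examples only.
import ray.train

ctx = ray.train.get_context()
print(ctx.get_world_size())       # total number of workers, e.g. 4
print(ctx.get_world_rank())       # this worker's global rank: 0..world_size-1
print(ctx.get_local_rank())       # rank within the current node
print(ctx.get_node_rank())        # index of the node this worker runs on
print(ctx.get_experiment_name())  # name of the Ray Train experiment

# Batch-size arithmetic (no gradient accumulation):
# global_batch_size = 128, world_size = 4  →  per-worker batch size = 128 // 4 = 32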