14 · Create the TorchTrainer

Now we bring everything together with a TorchTrainer.

The TorchTrainer is the high-level Ray Train class that:

  • Launches the per-worker training loop (train_loop_ray_train) across the cluster.

  • Applies the scaling setup from scaling_config (number of workers, GPUs/CPUs).

  • Uses run_config to decide where results and checkpoints are stored.

  • Passes train_loop_config (hyperparameters like num_epochs and global_batch_size) into the training loop.

This object encapsulates the distributed orchestration, so you can start training with a simple call to trainer.fit().
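The scaling_config, run_config, and train_loop_config objects were built in the earlier steps. For reference, here is a minimal sketch of what they might look like; the worker count, storage path, run name, and hyperparameter values are illustrative assumptions, not the tutorial's exact settings.

# Illustrative sketch of the config objects used below (values are assumptions;
# use the ones you defined in the earlier steps)
from ray.train import RunConfig, ScalingConfig

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)  # e.g. two GPU workers
run_config = RunConfig(
    name="distributed-training-run",       # assumed run name
    storage_path="/mnt/cluster_storage",   # assumed shared storage path
)
train_loop_config = {"num_epochs": 2, "global_batch_size": 128}  # read inside the loop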

# 14. Set up the TorchTrainer

trainer = TorchTrainer(
    train_loop_ray_train,          # training loop to run on each worker
    scaling_config=scaling_config, # number of workers and resource config
    run_config=run_config,         # storage path + run name for artifacts
    train_loop_config=train_loop_config,  # hyperparameters passed to the loop
)

15 · Launch Training with trainer.fit()

Calling trainer.fit() starts the distributed training job and blocks until it completes.

When the job launches, you’ll see logs that confirm:

  • Process group setup → Ray initializes a distributed worker group and assigns ranks (e.g., world_rank=0 and world_rank=1).

  • Worker placement → Each worker is launched on a specific node and device. The logs show IP addresses, process IDs, and rank assignments.

  • Model preparation → Each worker moves the model to its GPU (cuda:0) and wraps it in DistributedDataParallel (DDP).

These logs are a quick sanity check that Ray Train is correctly orchestrating multi-GPU training across your cluster.
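In Ray Train, the device placement and DDP wrapping described in the last bullet are typically handled by ray.train.torch.prepare_model() inside the training loop, and the rank assignments come from the worker's train context. As a minimal sketch, you could print them from inside the loop yourself (the print statement is illustrative and not part of train_loop_ray_train):

# Inside the per-worker training loop, the train context exposes the rank
# assignments that show up in these logs (illustrative sketch)
import ray.train

ctx = ray.train.get_context()
print(f"world_rank={ctx.get_world_rank()}, "
      f"local_rank={ctx.get_local_rank()}, "
      f"world_size={ctx.get_world_size()}")

# ray.train.torch.prepare_model(model) is what moves the model to its GPU
# and wraps it in DistributedDataParallel.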

# 15. Launch distributed training

# trainer.fit() starts the training job:
# - Spawns workers according to scaling_config
# - Runs train_loop_ray_train() on each worker
# - Collects metrics and checkpoints into result
result = trainer.fit()
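Once fit() returns, result is a ray.train.Result object holding the final metrics, the last checkpoint, and the run's storage location. A minimal sketch of how you might inspect it (the metric keys depend on what your training loop reported via ray.train.report):

# Inspect the Result returned by trainer.fit() (illustrative)
print(result.metrics)      # last metrics reported by the training loop
print(result.checkpoint)   # latest checkpoint saved during training
print(result.path)         # storage location of the run's artifacts
if result.error:
    raise result.error     # surface any failure from the workers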