14 · Create the TorchTrainer#
Now we bring everything together with a TorchTrainer.
The TorchTrainer is the high-level Ray Train class that:
- Launches the per-worker training loop (train_loop_ray_train) across the cluster.
- Applies the scaling setup from scaling_config (number of workers, GPUs/CPUs).
- Uses run_config to decide where results and checkpoints are stored.
- Passes train_loop_config (hyperparameters like num_epochs and global_batch_size) into the training loop.
This object encapsulates the distributed orchestration, so you can start training with a simple call to trainer.fit().
# 14. Set up the TorchTrainer
trainer = TorchTrainer(
    train_loop_ray_train,                 # training loop to run on each worker
    scaling_config=scaling_config,        # number of workers and resource config
    run_config=run_config,                # storage path + run name for artifacts
    train_loop_config=train_loop_config,  # hyperparameters passed to the loop
)
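If you are jumping in at this step, here is a minimal sketch of what the three config objects passed above might look like. The worker count, run name, storage path, and hyperparameter values below are placeholder assumptions, not the tutorial's actual settings; use the values defined in the earlier sections.

# Sketch only: illustrative config objects for the TorchTrainer above.
# Worker count, storage path, run name, and hyperparameters are assumptions.
from ray.train import RunConfig, ScalingConfig

scaling_config = ScalingConfig(
    num_workers=2,   # number of distributed training workers (placeholder)
    use_gpu=True,    # give each worker one GPU
)

run_config = RunConfig(
    name="torch-trainer-demo",            # run name for the artifact directory (placeholder)
    storage_path="/mnt/cluster_storage",  # shared location for results and checkpoints (placeholder)
)

train_loop_config = {
    "num_epochs": 2,           # forwarded to train_loop_ray_train via its config dict
    "global_batch_size": 128,  # split across workers inside the training loop
}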
15 · Launch Training with trainer.fit()#
Calling trainer.fit() starts the distributed training job and blocks until it completes.
When the job launches, you’ll see logs that confirm:
- Process group setup → Ray initializes a distributed worker group and assigns ranks (e.g., world_rank=0 and world_rank=1).
- Worker placement → Each worker is launched on a specific node and device. The logs show IP addresses, process IDs, and rank assignments.
- Model preparation → Each worker moves the model to its GPU (cuda:0) and wraps it in DistributedDataParallel (DDP).
These logs are a quick sanity check that Ray Train is correctly orchestrating multi-GPU training across your cluster.
# 15. Launch distributed training
# trainer.fit() starts the training job:
# - Spawns workers according to scaling_config
# - Runs train_loop_ray_train() on each worker
# - Collects metrics and checkpoints into result
result = trainer.fit()
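Once fit() returns, the Result object gives you programmatic access to what the workers reported. A quick sketch of inspecting it; the specific metric keys depend on what your training loop reported, so "loss" here is an assumption:

# Sketch: inspecting the Result returned by trainer.fit()
print(result.metrics)                   # last metrics dict reported from the training loop
print(result.metrics.get("loss"))       # "loss" is an assumed key; use whatever your loop reports
print(result.checkpoint)                # latest checkpoint reported by the workers, if any
print(result.path)                      # where metrics and checkpoints were written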
