08 · Wrap the Model with prepare_model()
Next, we define a helper function to build and prepare the model for Ray Train.
Start by constructing the ResNet-18 model adapted for MNIST using build_resnet18(). Instead of manually calling model.to("cuda") and wrapping it in DistributedDataParallel (DDP), we use ray.train.torch.prepare_model(). This automatically:
- Moves the model to the correct device (GPU or CPU).
- Wraps it in DistributedDataParallel (DDP) or FullyShardedDataParallel (FSDP).
- Ensures gradients are synchronized across workers.
This means the same code works whether you’re training on 1 GPU or 100 GPUs — no manual device placement or DDP boilerplate required.
# 08. Build and prepare the model for Ray Train
def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()
    # prepare_model() → move to correct device + wrap in DDP automatically
    model = ray.train.torch.prepare_model(model)
    return model
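For context, here is a minimal sketch of how this helper might be used from inside a Ray Train training loop. The train_loop_per_worker body and the ScalingConfig values are illustrative assumptions, not part of this tutorial's code; the actual training loop, data loading, and worker counts come later in the tutorial.

```python
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config: dict):
    # Each worker builds its own model replica; prepare_model() (called
    # inside load_model_ray_train) handles device placement and DDP wrapping.
    model = load_model_ray_train()
    # ... build dataloaders, run the training epochs, and report metrics
    # with ray.train.report(...) here.


# Illustrative only: 2 GPU workers; adjust to your cluster.
trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)
# result = trainer.fit()
```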
prepare_model() also accepts additional parameters:
- parallel_strategy: "ddp" or "fsdp" – wrap the model in DistributedDataParallel or FullyShardedDataParallel
- parallel_strategy_kwargs: pass additional arguments to the "ddp" or "fsdp" wrapper, as shown in the sketch below
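As a rough illustration, the same call could opt into FSDP or tweak the DDP wrapper as follows. The specific keyword values are example assumptions, and, like the helper above, these calls only make sense inside a Ray Train worker where the distributed process group has been set up.

```python
import ray.train.torch

# DDP with an extra DistributedDataParallel argument (example value).
model = ray.train.torch.prepare_model(
    build_resnet18(),
    parallel_strategy="ddp",
    parallel_strategy_kwargs={"find_unused_parameters": True},
)

# Or shard the model with FullyShardedDataParallel instead of DDP.
model = ray.train.torch.prepare_model(
    build_resnet18(),
    parallel_strategy="fsdp",
)
```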