08 · Wrap the Model with prepare_model()

Next, we define a helper function to build and prepare the model for Ray Train.

  • Start by constructing the ResNet-18 model adapted for MNIST using build_resnet18().

  • Instead of manually calling model.to("cuda") and wrapping it in DistributedDataParallel (DDP), we use ray.train.torch.prepare_model().

    • This automatically:

      • Moves the model to the correct device (GPU or CPU).

      • Wraps it in DDP (or FSDP, if requested).

      • Ensures gradients are synchronized across workers.

This means the same code works whether you’re training on 1 GPU or 100 GPUs — no manual device placement or DDP boilerplate required.

# 08. Build and prepare the model for Ray Train
import torch
import ray.train.torch


def load_model_ray_train() -> torch.nn.Module:
    model = build_resnet18()  # ResNet-18 adapted for MNIST
    # prepare_model() → move to correct device + wrap in DDP automatically
    model = ray.train.torch.prepare_model(model)
    return model
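
To make the scaling claim above concrete, here is a minimal sketch of how this helper could be called from a Ray Train worker function. The train_func body, the num_workers value, and the trainer setup are illustrative assumptions for this sketch, not the tutorial's actual training loop.

# Illustrative sketch: the same train_func runs unchanged whether
# ScalingConfig requests 1 worker or 100.
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config: dict):
    model = load_model_ray_train()  # device placement + DDP handled for us
    # ... build dataloaders, run the training loop, report metrics ...

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),  # assumed values
)
result = trainer.fit()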

prepare_model() also accepts additional parameters:
  • parallel_strategy: "ddp" or "fsdp" – wrap the model in DistributedDataParallel or FullyShardedDataParallel (defaults to "ddp")
  • parallel_strategy_kwargs: extra keyword arguments forwarded to the DistributedDataParallel or FullyShardedDataParallel constructor
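
For illustration only, the snippet below shows how these parameters might be passed. The find_unused_parameters flag is a standard DistributedDataParallel option shown purely as an example; nothing in this tutorial's ResNet-18 / MNIST setup requires it.

# Illustrative only: forwarding extra options through prepare_model().
model = ray.train.torch.prepare_model(
    build_resnet18(),
    parallel_strategy="ddp",
    parallel_strategy_kwargs={"find_unused_parameters": True},
)

# Or shard the model across workers instead of replicating it:
# model = ray.train.torch.prepare_model(build_resnet18(), parallel_strategy="fsdp")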