18 · Load a Checkpoint for Inference#

After training, we often want to reload the model and use it for predictions.
Here we define a Ray actor (ModelWorker) that loads the checkpointed ResNet-18 onto a GPU and serves inference requests.

  • Initialization (__init__):

    • Materializes the checkpoint as a local directory via checkpoint.as_directory() and loads the saved state dict from model.pt on CPU (a training-side sketch of how that file is written follows this list).

    • Loads the model weights into a fresh ResNet-18.

    • Moves the model to GPU and sets it to evaluation mode.

  • Prediction (predict):

    • Accepts either a single image ([C,H,W]) or a batch ([B,C,H,W]).

    • Adds a batch dimension if needed and moves the tensor to the GPU.

    • Runs the forward pass under torch.inference_mode(), which disables autograd tracking for faster inference.

    • Returns the predicted class indices as a Python list.
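
Before looking at the actor itself, it helps to see where model.pt comes from. The sketch below shows the common Ray Train pattern for writing such a checkpoint inside the training function; the metrics dict and helper name are hypothetical, not this tutorial's exact training code.

# Training-side sketch (hypothetical helper): how a checkpoint containing
# "model.pt" is commonly reported from inside a Ray Train training function.
import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint

def report_checkpoint(model, metrics):
    # Must be called from within the per-worker training loop.
    with tempfile.TemporaryDirectory() as tmp_dir:
        torch.save(model.state_dict(), os.path.join(tmp_dir, "model.pt"))
        # Attach the directory as a checkpoint; result.checkpoint later
        # resolves to one of the checkpoints reported this way.
        train.report(metrics, checkpoint=Checkpoint.from_directory(tmp_dir))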

Finally, we launch the actor with ModelWorker.remote(result.checkpoint).
This spawns a dedicated process with 1 GPU attached that can serve predictions using the trained model.

# 18. Define a Ray actor to load the trained model and run inference

@ray.remote(num_gpus=1)  # allocate 1 GPU to this actor
class ModelWorker:
    def __init__(self, checkpoint):
        # Load model weights from the Ray checkpoint (on CPU first)
        with checkpoint.as_directory() as ckpt_dir:
            model_path = os.path.join(ckpt_dir, "model.pt")
            state_dict = torch.load(
                model_path,
                map_location=torch.device("cpu"),
                weights_only=True,
            )
        # Rebuild the model, load weights, move to GPU, and set to eval mode
        self.model = build_resnet18()
        self.model.load_state_dict(state_dict)
        self.model.to("cuda")
        self.model.eval()

    @torch.inference_mode()  # disable autograd for faster inference
    def predict(self, batch):
        """
        batch: torch.Tensor or numpy array with shape [B,C,H,W] or [C,H,W]
        returns: list[int] predicted class indices
        """
        x = torch.as_tensor(batch)
        if x.ndim == 3:          # single image → add batch dimension
            x = x.unsqueeze(0)   # shape becomes [1,C,H,W]
        x = x.to("cuda", non_blocking=True)

        logits = self.model(x)
        preds = torch.argmax(logits, dim=1)
        return preds.detach().cpu().tolist()

# Launch the actor; Ray schedules it on a node with a free GPU
worker = ModelWorker.remote(result.checkpoint)
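
As a quick smoke test, you can push a dummy tensor through the actor before touching real data. The input shape below is an assumption (a single 1×28×28 MNIST-style image); adjust it to whatever build_resnet18() actually expects.

# Optional smoke test with a dummy input (shape is an assumption).
dummy = torch.randn(1, 1, 28, 28)             # [B,C,H,W]
print(ray.get(worker.predict.remote(dummy)))  # → a list with one class index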

19 · Run Inference and Visualize Predictions#

With the ModelWorker actor running on GPU, we can now generate predictions on random samples from the MNIST dataset and plot them.

Steps in this cell:

  1. Normalization on CPU

    • Convert each image to a tensor with ToTensor().

    • Apply per-channel normalization with mean 0.5 and std 0.5 (one channel for grayscale, three for RGB).

    • Keep this preprocessing on the CPU so the GPU actor only has to run the forward pass.

  2. Prediction on GPU via Actor

    • Each normalized image is expanded to shape [1, C, H, W].

    • The tensor is sent to the remote ModelWorker for inference.

    • ray.get(worker.predict.remote(x)) retrieves the predicted class index.

  3. Plot Results

    • Display a 3×3 grid of random MNIST samples.

    • Each subplot shows the true label and the predicted label from the trained ResNet-18.

This demonstrates a simple but practical workflow: CPU-based preprocessing + GPU-based inference in a Ray actor.

# 19. CPU preprocessing + GPU inference via Ray actor

to_tensor = ToTensor()

def normalize_cpu(img):
    # Convert image (PIL) to tensor on CPU → shape [C,H,W]
    t = to_tensor(img)                # [C,H,W] on CPU
    C = t.shape[0]
    # Apply channel-wise normalization (grayscale vs RGB)
    if C == 3:
        norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    else:
        norm = Normalize((0.5,), (0.5,))
    return norm(t)

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

# Plot a 3x3 grid of random MNIST samples with predictions
for i in range(1, cols * rows + 1):
    idx = np.random.randint(0, len(dataset))
    img, label = dataset[idx]

    # Preprocess on CPU, add batch dim → [1,C,H,W]
    x = normalize_cpu(img).unsqueeze(0)    

    # Run inference on GPU via Ray actor, fetch result   
    pred = ray.get(worker.predict.remote(x))[0]  # int
    
    # Plot image with true label and predicted label
    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(pred)}")
    plt.axis("off")
    arr = np.array(img)
    plt.imshow(arr, cmap="gray" if arr.ndim == 2 else None)

plt.tight_layout()
plt.show()
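
One caveat about the loop above: it issues a separate actor call per image, so every prediction pays a serialization round trip. For larger sample counts, stacking the preprocessed tensors and calling predict once is usually faster. A minimal sketch, reusing the names defined above:

# Batched alternative (sketch): one remote call instead of nine.
sample_idxs = np.random.randint(0, len(dataset), size=cols * rows)
batch = torch.stack([normalize_cpu(dataset[int(j)][0]) for j in sample_idxs])
batch_preds = ray.get(worker.predict.remote(batch))   # list of 9 ints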

20 · Clean Up the Ray Actor#

Once you’re done running inference, it’s a good practice to free up resources:

  • ray.kill(worker, no_restart=True) → stops the ModelWorker actor and releases its GPU.

  • del worker → drops the local handle; a follow-up gc.collect() is optional, since ray.kill() has already released the cluster-side resources.

This ensures the GPU is no longer pinned by the actor and can be reused for other jobs.

# 20. Stop the actor and release its GPU

# stop the actor process and free its GPU
ray.kill(worker, no_restart=True)     

# drop local references so nothing pins it
del worker

# Forcing garbage collection is optional:
# - Cluster resources are already freed by ray.kill()
# - Python will clean up the local handle eventually
# - gc.collect() is usually unnecessary unless debugging memory issues
gc.collect()
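
If you want to verify that the GPU actually returned to the shared pool, you can inspect Ray's resource accounting; the exact numbers depend on your cluster, and the view refreshes asynchronously, so allow a moment after ray.kill().

# Optional: confirm the GPU is back in the pool after the actor is gone.
print(ray.available_resources().get("GPU", 0))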