## 18 · Load a Checkpoint for Inference
After training, we often want to reload the model and use it for predictions. Here we define a Ray actor (`ModelWorker`) that loads the checkpointed ResNet-18 onto a GPU and serves inference requests.

**Initialization (`__init__`):**

- Reads the checkpoint directory using `checkpoint.as_directory()`.
- Loads the model weights into a fresh ResNet-18.
- Moves the model to GPU and sets it to evaluation mode.

**Prediction (`predict`):**

- Accepts either a single image (`[C,H,W]`) or a batch (`[B,C,H,W]`).
- Ensures the tensor is correctly shaped and moved to GPU.
- Runs inference in `torch.inference_mode()` for efficiency.
- Returns the predicted class indices as a Python list.

Finally, we launch the actor with `ModelWorker.remote(result.checkpoint)`. This spawns a dedicated process with 1 GPU attached that can serve predictions using the trained model.
```python
# 18. Define a Ray actor to load the trained model and run inference
@ray.remote(num_gpus=1)  # allocate 1 GPU to this actor
class ModelWorker:
    def __init__(self, checkpoint):
        # Load model weights from the Ray checkpoint (on CPU first)
        with checkpoint.as_directory() as ckpt_dir:
            model_path = os.path.join(ckpt_dir, "model.pt")
            state_dict = torch.load(
                model_path,
                map_location=torch.device("cpu"),
                weights_only=True,
            )
        # Rebuild the model, load weights, move to GPU, and set to eval mode
        self.model = build_resnet18()
        self.model.load_state_dict(state_dict)
        self.model.to("cuda")
        self.model.eval()

    @torch.inference_mode()  # disable autograd for faster inference
    def predict(self, batch):
        """
        batch: torch.Tensor or numpy array with shape [B,C,H,W] or [C,H,W]
        returns: list[int] predicted class indices
        """
        x = torch.as_tensor(batch)
        if x.ndim == 3:  # single image → add batch dimension
            x = x.unsqueeze(0)  # shape becomes [1,C,H,W]
        x = x.to("cuda", non_blocking=True)
        logits = self.model(x)
        preds = torch.argmax(logits, dim=1)
        return preds.detach().cpu().tolist()

# Create a fresh actor instance (avoid naming conflicts)
worker = ModelWorker.remote(result.checkpoint)
```
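Before wiring up real data, a quick smoke test can confirm the actor is healthy. This is a minimal sketch: it assumes the `worker` handle above is live and that `build_resnet18()` accepts single-channel 28×28 MNIST inputs, as in the training pipeline.

```python
# Smoke test: push one random MNIST-shaped tensor through the actor.
# The values are noise, so the predicted class is meaningless; this only
# verifies the checkpoint loaded and inference runs on the actor's GPU.
dummy = torch.randn(1, 28, 28)  # [C,H,W]; predict() adds the batch dim
print(ray.get(worker.predict.remote(dummy)))  # → a one-element list of ints
```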
## 19 · Run Inference and Visualize Predictions
With the `ModelWorker` actor running on GPU, we can now generate predictions on random samples from the MNIST dataset and plot them.

Steps in this cell:

1. **Normalization on CPU**
   - Convert each image to a tensor with `ToTensor()`.
   - Apply channel-wise normalization (mean `0.5` / std `0.5`).
   - Keep this preprocessing on CPU for efficiency.
2. **Prediction on GPU via Actor**
   - Each normalized image is expanded to shape `[1, C, H, W]`.
   - The tensor is sent to the remote `ModelWorker` for inference.
   - `ray.get(worker.predict.remote(x))` retrieves the predicted class index.
3. **Plot Results**
   - Display a 3×3 grid of random MNIST samples.
   - Each subplot shows the true label and the predicted label from the trained ResNet-18.

This demonstrates a simple but practical workflow: CPU-based preprocessing + GPU-based inference in a Ray actor.
```python
# 19. CPU preprocessing + GPU inference via Ray actor
to_tensor = ToTensor()

def normalize_cpu(img):
    # Convert image (PIL) to tensor on CPU → shape [C,H,W]
    t = to_tensor(img)  # [C,H,W] on CPU
    C = t.shape[0]
    # Apply channel-wise normalization (grayscale vs RGB)
    if C == 3:
        norm = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    else:
        norm = Normalize((0.5,), (0.5,))
    return norm(t)

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

# Plot a 3x3 grid of random MNIST samples with predictions
for i in range(1, cols * rows + 1):
    idx = np.random.randint(0, len(dataset))
    img, label = dataset[idx]
    # Preprocess on CPU, add batch dim → [1,C,H,W]
    x = normalize_cpu(img).unsqueeze(0)
    # Run inference on GPU via Ray actor, fetch result
    pred = ray.get(worker.predict.remote(x))[0]  # int
    # Plot image with true label and predicted label
    figure.add_subplot(rows, cols, i)
    plt.title(f"label: {label}; pred: {int(pred)}")
    plt.axis("off")
    arr = np.array(img)
    plt.imshow(arr, cmap="gray" if arr.ndim == 2 else None)

plt.tight_layout()
plt.show()
```
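Note that the loop above pays one `predict.remote` round trip per image. Since `predict` already accepts `[B,C,H,W]` input, all nine samples can go through a single remote call. A minimal sketch, reusing `dataset`, `normalize_cpu`, and `worker` from above:

```python
# Batched variant: one remote call instead of nine.
indices = np.random.randint(0, len(dataset), size=cols * rows)
samples = [dataset[i] for i in indices]

# Stack the preprocessed images into a single [9, C, H, W] tensor.
batch = torch.stack([normalize_cpu(img) for img, _ in samples])
preds = ray.get(worker.predict.remote(batch))  # list of 9 class indices

for (img, label), pred in zip(samples, preds):
    print(f"label: {label}  pred: {pred}")
```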
## 20 · Clean Up the Ray Actor
Once you’re done running inference, it’s good practice to free up resources:

- `ray.kill(worker, no_restart=True)` → stops the `ModelWorker` actor and releases its GPU.
- `del worker` + `gc.collect()` → drop local references so Python’s garbage collector can clean up.

This ensures the GPU is no longer pinned by the actor and can be reused for other jobs.
```python
# 20. Stop the actor process and free its GPU
ray.kill(worker, no_restart=True)

# Drop local references so nothing pins it
del worker

# Forcing garbage collection is optional:
# - Cluster resources are already freed by ray.kill()
# - Python will clean up the local handle eventually
# - gc.collect() is usually unnecessary unless debugging memory issues
gc.collect()
```
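To confirm the GPU actually returned to the pool, you can compare Ray’s resource accounting before and after the kill. A minimal check; note that the accounting can take a moment to refresh, and the exact totals depend on your cluster:

```python
# After ray.kill(), the actor's GPU should be reported as available again.
print(ray.available_resources().get("GPU", 0))  # freed GPUs (e.g. 1.0)
print(ray.cluster_resources().get("GPU", 0))    # total GPUs, unchanged
```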