01 · Imports#
Start by importing all the libraries you’ll need for this tutorial.
Standard utilities: os, datetime, tempfile, csv, shutil, and gc help with file paths, checkpointing, cleanup, and general housekeeping.
Data and visualization: pandas, numpy, matplotlib, and PIL are used for inspecting the dataset and plotting sample images.
PyTorch: core deep learning components (torch, CrossEntropyLoss, Adam) plus torchvision for loading MNIST and building a ResNet-18 model.
Ray Train: the key imports for distributed training are ScalingConfig, RunConfig, and TorchTrainer. These handle cluster scaling, experiment output storage, and execution of your training loop across GPUs.
This notebook assumes Ray is already running (for example, inside an Anyscale cluster), so you don’t need to call ray.init() manually.
# 01. Imports
# --- Standard library: file IO, paths, timestamps, temp dirs, cleanup ---
import csv # Simple CSV logging for metrics in single-GPU section
import datetime # Timestamps for run directories / filenames
import os # Filesystem utilities (paths, env vars)
import tempfile # Ephemeral dirs for checkpoint staging with ray.train.report()
import shutil # Cleanup of artifacts (later cells)
import gc # Manual garbage collection to cleanup after inference
from pathlib import Path # Convenient, cross-platform path handling
# --- Visualization & data wrangling ---
import matplotlib.pyplot as plt # Plot sample digits and metrics curves
from PIL import Image # Image utilities for inspection/debug
import numpy as np # Numeric helpers (random sampling, arrays)
import pandas as pd # Read metrics.csv into a DataFrame
# --- PyTorch & TorchVision (model + dataset) ---
import torch
from torch.nn import CrossEntropyLoss # Classification loss for MNIST
from torch.optim import Adam # Optimizer
from torchvision.models import resnet18 # Baseline CNN (we’ll adapt for 1-channel input)
from torchvision.datasets import MNIST # Dataset
from torchvision.transforms import ToTensor, Normalize, Compose # Preprocessing pipeline
# --- Ray Train (distributed orchestration) ---
import ray
from ray.train import ScalingConfig, RunConfig # Configure scale and storage
from ray.train.torch import TorchTrainer # Multi-GPU PyTorch trainer (DDP/FSDP)
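Several of these standard-library imports combine into one pattern later on: tempfile stages an ephemeral checkpoint directory that ray.train.report() hands off to Ray. A stdlib-only sketch of that staging pattern, with the Ray call shown only as a comment (stage_checkpoint is a hypothetical helper and the state dict is placeholder data, not real training output):

```python
import json
import os
import tempfile

# Stdlib-only sketch of the checkpoint-staging pattern used later with
# ray.train.report(). stage_checkpoint is a hypothetical helper; the
# state dict is placeholder data, not real training output.
def stage_checkpoint(state: dict) -> str:
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "checkpoint.json")
        with open(path, "w") as f:
            json.dump(state, f)
        # In the real training loop, this is where the directory goes to Ray:
        # ray.train.report(metrics, checkpoint=Checkpoint.from_directory(tmp_dir))
        assert os.path.exists(path)  # the file exists while the context is open
        return path

path = stage_checkpoint({"epoch": 1, "loss": 0.42})
print(os.path.exists(path))  # False: the ephemeral directory is deleted on exit
```

The point of the context manager is that cleanup is automatic: Ray copies the checkpoint contents to your configured storage before the temporary directory disappears.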
02 · Download MNIST Dataset#
Next, download the MNIST dataset using torchvision.datasets.MNIST.
This automatically fetches the dataset (if not already present) into the directory passed as root.
MNIST consists of 60,000 grayscale images of handwritten digits (0–9), each sized 28×28 pixels.
Setting train=True loads the training split of the dataset.
Once downloaded, we'll later wrap this dataset in a DataLoader and apply normalization so it can be used for model training.
# 02. Download MNIST Dataset
dataset = MNIST(root="/mnt/cluster_storage/data", train=True, download=True)
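The normalization mentioned above maps each pixel through (x − mean) / std after scaling to [0, 1]. A numpy-only sketch of that transform, assuming the commonly used MNIST statistics (mean 0.1307, std 0.3081) — the actual torchvision pipeline uses ToTensor and Normalize, and the all-zero array here is just a stand-in for a real image:

```python
import numpy as np

# Numpy-only sketch of what torchvision's ToTensor + Normalize do to one
# MNIST image, assuming the commonly used MNIST statistics.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def to_normalized_tensor(img_uint8: np.ndarray) -> np.ndarray:
    scaled = img_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD       # Normalize: (x - mean) / std

fake_digit = np.zeros((28, 28), dtype=np.uint8)    # stand-in for a real 28x28 image
out = to_normalized_tensor(fake_digit)
print(out.shape)  # (28, 28); all-zero input maps to -mean/std everywhere
```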
Note about Anyscale storage options
In this example, the MNIST dataset is stored under /mnt/cluster_storage/, which is Anyscale’s persistent cluster storage.
Unlike node-local NVMe volumes, cluster storage is shared across nodes in your cluster.
Data written here will persist across cluster restarts, making it a safe place for datasets, checkpoints, and results.
This is the recommended location for training data and artifacts you want to reuse.
Anyscale also provides each node with its own volume and disk, which aren't shared with other nodes.
Local storage is very fast: Anyscale supports the Non-Volatile Memory Express (NVMe) interface.
However, local storage isn't persistent: Anyscale deletes its contents when the instance is terminated.
Read more about available storage options.
03 · Visualize Sample Digits#
Before training, let’s take a quick look at the dataset.
We’ll randomly sample 9 images from the MNIST training set.
Each image is a 28×28 grayscale digit, with its ground-truth label shown above each subplot.
This simple visualization is a good sanity check to confirm that the dataset downloaded correctly and that labels align with the images.
# 03. Visualize Sample Digits
# Create a square figure for plotting 9 samples (3x3 grid)
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3

# Loop through grid slots and plot a random digit each time
for i in range(1, cols * rows + 1):
    # Randomly select an index from the dataset
    sample_idx = np.random.randint(0, len(dataset))
    img, label = dataset[sample_idx]  # image (PIL) and its digit label

    # Add subplot to the figure
    figure.add_subplot(rows, cols, i)
    plt.title(label)              # show the digit label above each subplot
    plt.axis("off")               # remove axes for cleaner visualization
    plt.imshow(img, cmap="gray")  # display as grayscale image

plt.show()
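One caveat about the sampling above: np.random.randint draws from global state, so the grid shows different digits on every run. For a repeatable sanity check you can draw the indices from a seeded generator instead (a sketch; the seed value is arbitrary):

```python
import numpy as np

# Seeded generator: the same 9 indices on every run (seed choice is arbitrary).
rng = np.random.default_rng(seed=0)
sample_indices = rng.integers(0, 60_000, size=9)  # MNIST train split has 60,000 images

# Re-creating the generator with the same seed reproduces the draw exactly.
rng2 = np.random.default_rng(seed=0)
assert (sample_indices == rng2.integers(0, 60_000, size=9)).all()
```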