10 · Report Training Metrics
During training, it’s important to log metrics like loss values so you can monitor progress.
This helper function prints metrics from every worker. It:

- Collects the current loss and epoch into a dictionary.
- Uses `ray.train.get_context().get_world_rank()` to identify which worker is reporting.
- Prints the metrics along with the worker's rank for debugging and visibility.
```python
# 10. Report training metrics from each worker
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()  # report from all workers
    print(f"{metrics=} {world_rank=}")
    return metrics
```
If you want to log only from the rank 0 worker, use this variant instead:
```python
def print_metrics_ray_train(loss: torch.Tensor, epoch: int) -> dict:
    metrics = {"loss": loss.item(), "epoch": epoch}
    world_rank = ray.train.get_context().get_world_rank()
    if world_rank == 0:  # report only from the rank 0 worker
        print(f"{metrics=} {world_rank=}")
    return metrics
```
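If you need to switch between the two behaviors, the rank check can be factored into a small predicate that is easy to unit-test without starting a Ray cluster. This is a minimal sketch; the `should_report` name and `only_rank_zero` flag are illustrative helpers, not part of the Ray API:

```python
# Hypothetical helper: decide whether this worker should print metrics.
# In a real training function, `world_rank` would come from
# ray.train.get_context().get_world_rank().
def should_report(world_rank: int, only_rank_zero: bool = False) -> bool:
    """Return True if this worker should report metrics."""
    if only_rank_zero:
        return world_rank == 0  # only the rank 0 worker reports
    return True  # every worker reports

# By default, every worker reports.
print(should_report(world_rank=3))                       # True
# With only_rank_zero=True, only rank 0 reports.
print(should_report(world_rank=3, only_rank_zero=True))  # False
print(should_report(world_rank=0, only_rank_zero=True))  # True
```

Factoring the check out this way keeps the reporting policy in one place if you later add more conditions, such as reporting only every N epochs.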