Evaluation

Evaluation#

This tutorial concludes by evaluating the trained model on the test dataset. Evaluation is essentially the same as the batch inference workload: you apply the model to batches of data and then calculate metrics by comparing the predictions against the true labels. Ray Data is highly optimized for throughput, so preserving row order isn't a priority. For evaluation, however, each prediction must stay paired with its true label. Achieve this by preserving the entire row and adding the predicted label as another column to each row.

from urllib.parse import urlparse
from sklearn.metrics import multilabel_confusion_matrix
class TorchPredictor:
    """Pairs a fitted preprocessor with a trained model for Ray Data batch inference."""

    def __init__(self, preprocessor, model):
        self.preprocessor = preprocessor
        self.model = model
        # Inference only: freeze dropout / batch-norm behavior.
        self.model.eval()

    def __call__(self, batch, device="cuda"):
        """Append a 'prediction' column (predicted label) to the batch and return it."""
        self.model.to(device)
        collated = collate_fn(batch)
        batch["prediction"] = self.model.predict(collated)
        return batch

    def predict_probabilities(self, batch, device="cuda"):
        """Append a 'probabilities' column mapping class name -> probability per row."""
        self.model.to(device)
        probs_per_row = self.model.predict_probabilities(collate_fn(batch))
        label_to_class = self.preprocessor.label_to_class  # hoist attribute lookup
        named_probs = []
        for row in probs_per_row:
            named_probs.append(
                {label_to_class[idx]: float(p) for idx, p in enumerate(row)}
            )
        batch["probabilities"] = named_probs
        return batch

    @classmethod
    def from_artifacts_dir(cls, artifacts_dir):
        """Rebuild a predictor from a saved artifacts directory (label map + weights)."""
        label_map_fp = os.path.join(artifacts_dir, "class_to_label.json")
        with open(label_map_fp, "r") as fp:
            class_to_label = json.load(fp)
        model = ClassificationModel.load(
            args_fp=os.path.join(artifacts_dir, "args.json"),
            state_dict_fp=os.path.join(artifacts_dir, "model.pt"),
        )
        return cls(preprocessor=Preprocessor(class_to_label=class_to_label), model=model)
# Load and preprocess the eval dataset.
# Resolve the local filesystem path from the best run's artifact URI.
artifacts_dir = urlparse(best_run.artifact_uri).path
predictor = TorchPredictor.from_artifacts_dir(artifacts_dir=artifacts_dir)
# include_paths=True keeps each image's S3 path so the true class can be derived from it.
test_ds = ray.data.read_images("s3://doggos-dataset/test", include_paths=True)
test_ds = test_ds.map(add_class)
test_ds = predictor.preprocessor.transform(ds=test_ds)
# y_pred (batch inference).
pred_ds = test_ds.map_batches(
    predictor,
    concurrency=4,  # NOTE(review): presumably 4 predictor replicas — confirm against Ray Data docs
    batch_size=64,
    num_gpus=1,  # one GPU per replica
    accelerator_type="T4",
)
# Peek at one row; this triggers execution of the pipeline above.
pred_ds.take(1)
(autoscaler +8m20s) [autoscaler] [1xT4:8CPU-32GB] Attempting to add 1 node to the cluster (increasing from 0 to 1).
(autoscaler +8m25s) [autoscaler] [1xT4:8CPU-32GB|g4dn.2xlarge] [us-west-2a] [on-demand] Launched 1 instance.
(autoscaler +8m25s) [autoscaler] [4xT4:48CPU-192GB] Attempting to add 1 node to the cluster (increasing from 1 to 2).
(autoscaler +8m30s) [autoscaler] [4xT4:48CPU-192GB|g4dn.12xlarge] [us-west-2a] [on-demand] Launched 1 instance.
[{'path': 'doggos-dataset/test/basset/basset_10005.jpg',
  'class': 'basset',
  'label': 2,
  'embedding': array([ 8.86104554e-02, -5.89382686e-02,  1.15464866e-01,  2.15815112e-01,
         -3.43266308e-01, -3.35150540e-01,  1.48883224e-01, -1.02369718e-01,
         -1.69915810e-01,  4.34856862e-03,  2.41593361e-01,  1.79200619e-01,
          4.34402555e-01,  4.59785998e-01,  1.59284808e-02,  4.16959971e-01,
          5.20779848e-01,  1.86366066e-01, -3.43496174e-01, -4.00813907e-01,
         -1.15213782e-01, -3.04853529e-01,  1.77998394e-01,  1.82090014e-01,
         -3.56360346e-01, -2.30711952e-01,  1.69025257e-01,  3.78455579e-01,
          8.37044120e-02, -4.81875241e-02,  3.17967087e-01, -1.40099749e-01,
         -2.15949178e-01, -4.72761095e-01, -3.01893711e-01,  7.59940967e-02,
         -2.64865339e-01,  5.89084566e-01, -3.75831634e-01,  3.11807573e-01,
         -3.82964134e-01, -1.86417520e-01,  1.07007243e-01,  4.81416702e-01,
         -3.70819569e-01,  9.12090182e-01,  3.13470632e-01, -3.69494259e-02,
         -2.21142501e-01,  3.32214013e-02,  8.51379186e-02,  3.64337176e-01,
         -3.90754700e-01,  4.39904258e-02,  5.39945886e-02, -5.02359867e-01,
         -4.76054996e-02,  3.87604594e-01, -3.71239424e-01, -8.79095644e-02,
          5.62141061e-01,  1.96927994e-01,  3.54419112e-01, -6.80974126e-03,
          2.86425143e-01, -3.24660867e-01, -4.56204057e-01,  6.41017914e-01,
         -1.67037442e-01, -2.29641497e-01,  4.71122622e-01,  5.03865302e-01,
         -9.06585157e-03, -1.23926058e-01, -3.32888782e-01,  1.59683321e-02,
         -5.00816345e-01, -3.53796408e-02, -1.60535276e-01, -2.88702995e-01,
          5.51706925e-02, -3.47863048e-01, -3.01085338e-02, -6.00592375e-01,
          2.04530790e-01, -1.17298350e-01,  8.88321698e-01, -3.18641007e-01,
          2.02193573e-01, -1.50856599e-01, -2.96603352e-01, -5.45758486e-01,
         -7.55531311e+00, -3.07271361e-01, -7.33374238e-01,  2.76708573e-01,
         -3.76666151e-02, -4.25825119e-01, -5.56892097e-01,  7.15545475e-01,
          1.02834240e-01, -1.19939610e-01,  1.94998607e-01, -2.46950224e-01,
          2.61530429e-01, -4.19263542e-01,  1.31001920e-01, -2.49398082e-01,
         -3.26750994e-01, -3.92482489e-01,  3.30219358e-01, -5.78646958e-01,
          1.53134540e-01, -3.10127169e-01, -3.67199332e-01, -7.94161111e-02,
         -2.93402106e-01,  2.62198240e-01,  2.91103810e-01,  1.32868871e-01,
         -5.78317158e-02, -4.26885992e-01,  2.99195677e-01,  4.23972368e-01,
          2.30407149e-01, -2.98300147e-01, -1.55886114e-01, -1.24661736e-01,
         -1.17139973e-01, -4.21351314e-01, -1.45010501e-02, -3.06388348e-01,
          2.89572328e-01,  9.73405361e-01, -5.52814901e-01,  2.36222595e-01,
         -2.13898420e-01, -1.00043082e+00, -3.57041806e-01, -1.50843680e-01,
          4.69288528e-02,  2.08646134e-01, -2.70194232e-01,  2.63797104e-01,
          1.31332219e-01,  2.82329589e-01,  2.69341841e-02, -1.21627375e-01,
          3.80910456e-01,  2.65330970e-01, -3.01948935e-01, -6.39178753e-02,
         -3.13922286e-01, -4.14075851e-01, -2.19056532e-01,  2.22424790e-01,
          8.13730657e-02, -3.03519934e-01,  9.32400897e-02, -3.76873404e-01,
          8.34950879e-02,  1.01878762e-01,  2.87054926e-01,  2.09415853e-02,
         -1.22204229e-01,  1.64302550e-02, -2.41174936e-01,  1.78844824e-01,
          9.15416703e-03,  1.66462481e-01, -1.45732313e-01, -5.85511327e-04,
          2.25536823e-01,  3.30472469e-01, -1.25101686e-01,  1.13093004e-01,
          1.52094781e-01,  4.37459409e-01,  3.22061956e-01,  1.37893021e-01,
         -2.53650725e-01, -1.94988877e-01, -2.72130489e-01, -2.57504702e-01,
          1.92389667e-01, -2.07393348e-01,  1.73574477e-01,  2.59756446e-02,
          2.20320046e-01,  6.48344308e-02,  3.96853566e-01,  1.11773282e-01,
         -4.38930988e-01, -5.10937572e-02,  5.92644155e-01,  6.10140711e-03,
         -3.97206768e-02,  7.65584633e-02, -7.68468618e-01,  1.23042464e-01,
          3.48037392e-01,  1.49242997e-01,  2.86662281e-02,  2.79642552e-01,
         -2.26151049e-01, -6.73239648e-01, -8.07924390e-01,  8.62701386e-02,
          4.94999364e-02,  1.61207989e-02, -1.30242959e-01,  1.77768275e-01,
          3.62961054e-01, -3.20745975e-01,  3.67820978e-01, -9.77848917e-02,
         -2.64019221e-01,  6.74475431e-01,  9.26629007e-01, -4.54470068e-02,
          9.59405363e-01,  3.02993000e-01, -5.81385851e-01,  3.98850322e-01,
          7.40434751e-02,  1.79926023e-01,  9.12196040e-02,  2.77938917e-02,
         -2.20950916e-02, -1.98561847e-01, -4.33019698e-01,  1.35872006e-01,
         -3.84440348e-02,  1.63487554e-01,  5.38927615e-02,  8.52212310e-01,
         -8.64772916e-01, -3.00439209e-01,  1.66039094e-02, -4.84181255e-01,
         -2.57156193e-01,  4.46582437e-01,  3.71635705e-02, -7.58354291e-02,
         -1.38248950e-02,  1.01295078e+00,  2.14489758e-01, -1.17217854e-01,
         -2.82662451e-01,  7.08411038e-01,  2.08262652e-01, -1.69240460e-02,
          1.02334268e-01,  4.20059741e-01,  1.07706316e-01, -3.89203757e-01,
         -5.91410846e-02, -1.77690476e-01, -1.26772380e+00,  1.75859511e-01,
         -2.49499828e-01,  1.60166726e-01,  8.72884393e-02, -4.53421593e-01,
          1.96858853e-01, -2.25365251e-01, -1.31235719e-02, -4.58204031e-01,
         -1.54087022e-01, -1.87472761e-01,  2.73187131e-01,  4.14693624e-01,
          6.00348413e-01,  5.16499318e-02, -2.52319247e-01, -2.08351701e-01,
         -3.85643661e-01, -6.44139796e-02, -2.70672083e-01, -5.09124994e-02,
         -1.17392734e-01, -1.16136428e-02, -1.69710606e-01,  2.30101690e-01,
         -6.31506741e-02,  2.20495850e-01,  4.81231391e-01,  3.76428038e-01,
         -2.14597031e-01, -4.70009223e-02,  4.38644290e-01,  2.72557199e-01,
         -1.89499091e-02,  6.36664629e-02, -4.86765429e-02, -6.02428794e-01,
          5.40002957e-02, -9.60005671e-02,  4.63560931e-02, -3.55034113e-01,
          2.27724269e-01, -1.30642965e-01, -5.17771959e-01,  7.08835796e-02,
         -2.57462114e-01, -4.82860744e-01,  1.13421358e-01,  9.88648832e-02,
          6.21988237e-01,  2.64641732e-01, -9.67874378e-03,  1.94528699e-01,
          9.72453296e-01, -4.36969042e-01, -5.50681949e-02,  1.42934144e-01,
          1.37221038e-01,  5.63952804e-01, -3.20022464e-01, -5.56031644e-01,
          9.09894407e-01,  1.02216589e+00, -2.79887915e-01,  1.69066399e-01,
          6.48921371e-01,  1.68456510e-02, -2.58911937e-01,  4.62736428e-01,
          8.00172612e-03,  1.66315883e-01, -5.30062854e-01, -3.96020412e-01,
          4.43380117e-01, -4.35658276e-01, -1.11912012e-01, -5.91614306e-01,
         -7.02220649e-02,  1.41544282e-01, -5.65246567e-02, -1.19229007e+00,
         -1.00026041e-01,  1.35173336e-01, -1.37986809e-01,  4.58395988e-01,
          2.99769610e-01,  1.13845997e-01, -3.23149785e-02,  4.82394725e-01,
         -6.13934547e-03,  3.68614852e-01, -4.91497517e-01, -4.97332066e-01,
          8.73729736e-02,  3.60586494e-01, -2.91166097e-01,  1.89481646e-01,
          2.87948608e-01,  1.90306157e-01,  4.15048778e-01,  3.93784940e-01,
          6.75817132e-02,  1.18251920e-01,  2.03508779e-01,  3.09830695e-01,
         -1.03927016e+00,  1.00612268e-01, -3.46988708e-01, -7.09752440e-01,
          2.20241398e-01, -3.74946982e-01, -1.48783788e-01, -1.31232068e-01,
          3.87498319e-01,  1.67044029e-01, -2.79640555e-01,  3.40543866e-01,
          1.28378880e+00,  4.47215438e-01, -5.00054121e-01,  6.85076341e-02,
          1.93691164e-01, -4.66935217e-01, -3.24348718e-01,  4.53348368e-01,
          6.36629641e-01, -5.52294970e-01, -3.59640062e-01,  2.45728597e-01,
          4.48195577e-01, -1.36022663e+00, -6.26060665e-01, -4.96963590e-01,
         -2.55071461e-01, -2.31453001e-01, -4.22013104e-01,  5.81141561e-02,
          1.66424632e-01, -1.81557357e-01, -2.85358205e-02, -1.10628068e+00,
         -2.42026821e-01, -4.49676067e-03,  5.53836450e-02,  4.92810488e-01,
          5.83105981e-01,  6.97781667e-02, -1.33217961e-01, -1.25093237e-01,
          1.17499933e-01, -5.19634366e-01,  1.42042309e-01,  2.34404474e-01,
         -2.55929470e-01,  3.23758684e-02, -2.34450802e-01, -7.54091814e-02,
          1.83672294e-01, -2.25883007e-01, -4.76478487e-02, -4.84889567e-01,
          1.12959743e-03,  1.80705532e-01, -5.87785244e-02,  4.82457250e-01,
         -1.88920692e-01,  1.47517592e-01,  1.10182568e-01, -2.28278339e-02,
          8.62778306e-01,  4.46689427e-02,  4.16403189e-02, -1.07179873e-01,
         -1.42522454e+00, -2.31161788e-02,  3.05959303e-02, -6.58722073e-02,
         -3.69132429e-01,  3.49290550e-01, -1.39178723e-01, -3.51127565e-01,
          5.00785351e-01,  2.31236637e-01,  6.77590072e-02, -3.59323025e-02,
          2.69076526e-01, -3.60533416e-01,  1.48107335e-01, -1.11518174e-01,
          1.65307403e-01, -1.74086124e-01,  6.01880312e-01, -5.95235109e-01,
          5.29538319e-02,  3.12422097e-01, -1.14403330e-01,  2.30422497e-01,
         -9.48345065e-02,  3.76421027e-02,  4.77573276e-02,  3.89954895e-01,
         -1.91829026e-01, -6.26232028e-01,  1.29549801e-01, -2.84714490e-01,
          2.88834363e-01,  6.25569642e-01, -2.44193405e-01,  3.08956832e-01,
         -4.79587227e-01,  1.59115836e-01, -1.07442781e-01,  1.57203451e-01,
         -8.51369202e-02, -1.20136715e-01, -2.91232206e-02,  1.08408488e-01,
         -5.97195402e-02, -1.21715315e-01, -5.79822421e-01,  3.90639007e-01,
         -2.83878148e-01, -2.72939146e-01,  3.87672335e-04, -2.62640566e-01,
         -1.67415068e-01,  1.97720259e-01,  3.60535234e-01, -1.85247302e-01,
         -2.80813038e-01,  3.32875013e-01, -3.98125350e-01, -3.53022516e-02,
          5.48863769e-01, -1.35882646e-01,  2.50048220e-01, -1.27448589e-01,
         -3.03174406e-01,  3.85489166e-02, -7.27320850e-01,  5.22592783e-01,
         -1.97360516e-01, -1.98229402e-01, -1.42074719e-01,  4.11824808e-02,
         -2.92105675e-01,  2.07964912e-01,  4.97746691e-02,  1.48062438e-01,
         -2.94304550e-01,  7.31720269e-01,  1.14105418e-02,  5.50758056e-02],
        dtype=float32),
  'prediction': 8}]
def batch_metric(batch):
    """Return per-class confusion-matrix counts for one batch of predictions."""
    # One 2x2 matrix per class, laid out as [[TN, FP], [FN, TP]].
    mcm = multilabel_confusion_matrix(batch["label"], batch["prediction"])
    return {
        "TN": [m[0, 0] for m in mcm],  # true negatives
        "FP": [m[0, 1] for m in mcm],  # false positives
        "FN": [m[1, 0] for m in mcm],  # false negatives
        "TP": [m[1, 1] for m in mcm],  # true positives
    }
# Aggregated metrics after processing all batches.
metrics_ds = pred_ds.map_batches(batch_metric)
aggregate_metrics = metrics_ds.sum(["TN", "FP", "FN", "TP"])

# Pull the summed confusion-matrix components out of the aggregate result.
tn, fp, fn, tp = (
    aggregate_metrics[f"sum({name})"] for name in ("TN", "FP", "FN", "TP")
)

# Micro-averaged metrics; guard each ratio against a zero denominator.
precision = tp / (tp + fp) if tp + fp > 0 else 0
recall = tp / (tp + fn) if tp + fn > 0 else 0
f1 = (2 * precision * recall / (precision + recall)) if precision + recall > 0 else 0
accuracy = (tp + tn) / (tp + tn + fp + fn)
(autoscaler +9m10s) [autoscaler] Cluster upscaled to {120 CPU, 9 GPU}.
(autoscaler +9m15s) [autoscaler] Cluster upscaled to {168 CPU, 13 GPU}.
# Report the final evaluation metrics, each rounded to two decimal places.
for metric_name, metric_value in [
    ("Precision", precision),
    ("Recall", recall),
    ("F1", f1),
    ("Accuracy", accuracy),
]:
    print(f"{metric_name}: {metric_value:.2f}")
Precision: 0.84
Recall: 0.84
F1: 0.84
Accuracy: 0.98
(autoscaler +13m0s) [autoscaler] Downscaling node i-0ffe5abae6e899f5a (node IP: 10.0.60.138) due to node idle termination.
(autoscaler +13m5s) [autoscaler] Cluster resized to {120 CPU, 9 GPU}.
(autoscaler +16m0s) [autoscaler] Downscaling node i-0aa72cef9b8921af5 (node IP: 10.0.31.199) due to node idle termination.
(autoscaler +16m5s) [autoscaler] Cluster resized to {112 CPU, 8 GPU}.

🚨 Note: Reset this notebook using the “🔄 Restart” button located in the notebook’s menu bar. This way you can free up all the variables, utils, etc. used in this notebook.