Batch embeddings

Batch embeddings

The previous section applied a mapping operation, using a function, to each row in the dataset. Now you’re ready to generate embeddings from the data, using Ray Data’s `map_batches` to apply an operation to each batch of data instead of each row. The operation takes the form of a callable: either a function or a class with a `__call__` method.

import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor
class EmbedImages:
    """Ray Data ``map_batches`` callable that attaches CLIP image embeddings.

    Each instance loads its own CLIP processor and model for ``model_id`` and
    keeps the model on ``device``. Calling the instance on a batch adds an
    ``"embedding"`` entry and hands back the same batch object.

    Args:
        model_id: Identifier passed to both ``CLIPProcessor.from_pretrained``
            and ``CLIPModel.from_pretrained``.
        device: Device that the model and each batch's processed inputs are
            moved to (for example ``"cuda"``).
    """

    def __init__(self, model_id, device):
        self.processor = CLIPProcessor.from_pretrained(model_id)
        clip = CLIPModel.from_pretrained(model_id)
        clip.to(device)  # nn.Module.to moves the parameters in place
        self.model = clip
        self.device = device

    def __call__(self, batch):
        """Embed ``batch["image"]`` and store the result in ``batch["embedding"]``.

        Args:
            batch: Dict-like batch whose ``"image"`` entry is an iterable of
                array-like images (presumably HxWxC pixel arrays; confirm
                against the dataset schema).

        Returns:
            The same ``batch`` object with a NumPy ``"embedding"`` entry added.
        """
        pil_images = list(map(self._to_rgb, batch["image"]))
        model_inputs = self.processor(
            images=pil_images, return_tensors="pt", padding=True
        ).to(self.device)

        # inference_mode disables autograd tracking for the forward pass.
        with torch.inference_mode():
            features = self.model.get_image_features(**model_inputs)
            batch["embedding"] = features.cpu().numpy()
        return batch

    @staticmethod
    def _to_rgb(pixels):
        """Cast one array-like image to uint8 and return it as an RGB PIL image."""
        return Image.fromarray(np.uint8(pixels)).convert("RGB")
Ray object store references

Instead of having every instance of the class above load the same model on its own, you can use a reference into Ray’s shared-memory object store. Load the model once, store it in the default object store with `ray.put`, and then have each instance of your class fetch it from that reference with `ray.get`.

# Load the model a single time. `load_model(...)` is a placeholder for however
# your project actually constructs or loads the model.
model = load_model(...)
# Put the model into Ray's object store; `model_ref` is the reference that
# gets handed to each class instance instead of the model itself.
model_ref = ray.put(model) 

class Foo:
    """Illustrative class that reuses a model stored in Ray's object store.

    Args:
        model_ref: Object reference produced by ``ray.put(model)``; it is
            resolved with ``ray.get`` when the instance is constructed.
    """

    def __init__(self, model_ref):
        resolved_model = ray.get(model_ref)
        self.model = resolved_model
        ...
# Run EmbedImages over the dataset in batches to produce CLIP embeddings.
embeddings_ds = ds.map_batches(
    EmbedImages,
    batch_size=64,  # rows per batch handed to EmbedImages.__call__
    concurrency=4,  # number of concurrent workers running EmbedImages
    num_gpus=1,  # GPUs reserved for each worker
    accelerator_type="T4",  # only schedule workers onto T4 GPUs
    # Keyword arguments forwarded to EmbedImages.__init__.
    fn_constructor_kwargs=dict(
        model_id="openai/clip-vit-base-patch32",
        device="cuda",
    ),
    # No extra keyword arguments are forwarded to EmbedImages.__call__.
    fn_kwargs=dict(),
)
# The raw pixel column is no longer needed once embeddings exist.
embeddings_ds = embeddings_ds.drop_columns(["image"])