Batch embeddings
The previous section applied a mapping operation to each row in the dataset using a function. Now you're ready to generate embeddings from the data by using Ray Data's map_batches to apply an operation across batches of the data. The operation is a callable: either a function or a class with a __call__ method. The class form below is useful here because it lets each worker load the CLIP model once and reuse it across batches.
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor


class EmbedImages(object):
    def __init__(self, model_id, device):
        # Load CLIP model and processor
        self.processor = CLIPProcessor.from_pretrained(model_id)
        self.model = CLIPModel.from_pretrained(model_id)
        self.model.to(device)
        self.device = device

    def __call__(self, batch):
        # Load and preprocess images
        images = [
            Image.fromarray(np.uint8(img)).convert("RGB") for img in batch["image"]
        ]
        inputs = self.processor(images=images, return_tensors="pt", padding=True).to(
            self.device
        )

        # Generate embeddings
        with torch.inference_mode():
            batch["embedding"] = self.model.get_image_features(**inputs).cpu().numpy()
        return batch
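For a stateless transformation that doesn't need to hold a model in memory, you can instead pass a plain function to map_batches. A minimal sketch, assuming the same ds dataset; the add_mean_pixel function is only illustrative and isn't part of this pipeline:

import numpy as np

def add_mean_pixel(batch):
    # Hypothetical stateless transform: record each image's mean pixel value
    batch["mean_pixel"] = np.array([np.mean(img) for img in batch["image"]])
    return batch

# A plain function needs no constructor kwargs and, in this case, no GPU
stats_ds = ds.map_batches(add_mean_pixel, batch_size=64)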
Ray object store references
Instead of initializing the same model in every instance of the class above, you can use references to Ray's shared-memory object store: load the model once, put it in the default object store, and have each instance of the class retrieve it from there.
import ray

model = load_model(...)  # load the model once on the driver
model_ref = ray.put(model)  # place it in the shared-memory object store


class Foo:
    def __init__(self, model_ref):
        # Each instance fetches the same shared copy instead of reloading it
        self.model = ray.get(model_ref)
    ...
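Applied to this pipeline, the constructor of a class like EmbedImages could accept an object reference instead of a model ID. A brief sketch, where EmbedImagesFromRef is hypothetical and its __call__ method would be identical to the one above:

import ray
from transformers import CLIPModel, CLIPProcessor

# Load the model once on the driver and share it through the object store
model_ref = ray.put(CLIPModel.from_pretrained("openai/clip-vit-base-patch32"))


class EmbedImagesFromRef:
    def __init__(self, model_ref, device):
        # Resolve the shared reference instead of downloading the weights again
        self.model = ray.get(model_ref).to(device)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.device = device

The reference can then be supplied through fn_constructor_kwargs, just as model_id is in the call below.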
# Generate batch embeddings
embeddings_ds = ds.map_batches(
    EmbedImages,
    fn_constructor_kwargs={
        "model_id": "openai/clip-vit-base-patch32",
        "device": "cuda",
    },  # class kwargs
    fn_kwargs={},  # __call__ kwargs
    concurrency=4,
    batch_size=64,
    num_gpus=1,
    accelerator_type="T4",
)
embeddings_ds = embeddings_ds.drop_columns(["image"]) # remove image column
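Ray Data evaluates these transformations lazily, so nothing executes until the dataset is consumed. A short sketch of how you might trigger execution and inspect the output; the output path is arbitrary:

# Trigger execution and inspect one embedded row
sample = embeddings_ds.take(1)[0]
print(sample["embedding"].shape)  # (512,) for clip-vit-base-patch32

# Persist the embeddings for downstream use
embeddings_ds.write_parquet("/tmp/image_embeddings")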