# Source code for stable_pretraining.callbacks.clip_zero_shot

import torch
import torch.nn.functional as F
from typing import Optional, Callable

from lightning.pytorch import Callback, LightningModule, Trainer
from loguru import logger as logging
import torchmetrics

from ..utils import get_data_from_batch_or_outputs
from .utils import format_metrics_as_dict


class CLIPZeroShot(Callback):
    """Zero-shot classification evaluator for CLIP-style models.

    This callback computes zero-shot predictions by computing the similarity
    between the image embeddings and the class embeddings.

    Args:
        name: Unique identifier for this callback instance (used as log
            prefix and registry key).
        image_key: Key in batch or outputs containing input images or
            precomputed image features.
        class_key: Key in batch containing ground-truth class indices
            (0..C-1, aligned with ``class_names`` order).
        class_names: List of class names in index order.
        image_backbone: Module/callable to encode images into embeddings.
        text_backbone: Module/callable to encode tokenized text into
            embeddings.
        tokenizer_fn: Callable that maps str | list[str] -> tensor of
            shape (T,).
        metrics: Dict of torchmetrics to compute on validation
            (e.g., {"top1": MulticlassAccuracy(...)}).
    """

    def __init__(
        self,
        name: str,
        image_key: str,
        class_key: str,
        class_names: list[str],
        image_backbone: torch.nn.Module,
        text_backbone: torch.nn.Module,
        tokenizer_fn: Callable[[str | list[str]], torch.Tensor],
        metrics: Optional[dict | tuple | list | torchmetrics.Metric] = None,
    ) -> None:
        super().__init__()
        self.name = name
        self.image_key = image_key
        self.class_key = class_key
        self.class_names = class_names
        # Index -> human-readable class name, kept for downstream inspection.
        self.class_map = {i: c for i, c in enumerate(class_names)}
        self.image_backbone = image_backbone
        self.text_backbone = text_backbone
        self.tokenizer_fn = tokenizer_fn
        self._train_metrics = None
        self._val_metrics = None
        # Raw metrics config; materialized into metric objects in setup().
        self.metrics_config = metrics

        logging.info(f"Initialized CLIPZeroShot callback: {name}")
        logging.info(f" - Image key: {image_key}")
        logging.info(f" - Number of classes: {len(class_names)}")
        # Only mark the preview as truncated when it actually is (the original
        # unconditionally appended "..." even for <= 5 classes).
        preview = ", ".join(class_names[:5])
        ellipsis = "..." if len(class_names) > 5 else ""
        logging.info(f" - Class names: [{preview}{ellipsis}]")
        logging.info(f" - Image backbone: {image_backbone.__class__.__name__}")

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Initialize metrics, move backbones to the module device, and
        precompute the normalized class text embeddings.

        Args:
            trainer: The Lightning trainer driving this run.
            pl_module: The LightningModule; provides ``device`` and the
                ``callbacks_metrics`` registry this callback writes into.
            stage: Lightning stage name (e.g. "fit", "validate").
        """
        # Call parent setup for module/optimizer/scheduler.
        super().setup(trainer, pl_module, stage)

        # Materialize the metrics config into train/val metric dicts.
        logging.info(f"{self.name}: Setting up metrics")
        pl_module.callbacks_metrics[self.name] = format_metrics_as_dict(
            self.metrics_config
        )
        self.image_backbone = self.image_backbone.to(device=pl_module.device)
        self.text_backbone = self.text_backbone.to(device=pl_module.device)
        self._train_metrics = pl_module.callbacks_metrics[self.name]["_train"]
        self._val_metrics = pl_module.callbacks_metrics[self.name]["_val"]
        self.class_tokens = self.tokenizer_fn(self.class_names).to(
            device=pl_module.device
        )
        # The class prototypes are frozen for evaluation, so skip autograd
        # graph construction (the original computed these with grad enabled).
        with torch.no_grad():
            self.class_embeds = self.text_backbone(
                input_ids=self.class_tokens
            ).text_embeds
            self.class_embeds = F.normalize(self.class_embeds, dim=-1)

    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: dict,
        batch: dict,
        batch_idx: int,
    ) -> None:
        """Compute zero-shot logits for the batch and update validation metrics.

        Logits are cosine similarities between normalized image embeddings and
        the class text embeddings precomputed in ``setup``. Predictions are
        stashed into ``batch`` under ``f"{self.name}_preds"`` and the
        configured validation metrics are logged epoch-wise.
        """
        image = get_data_from_batch_or_outputs(
            self.image_key, batch, outputs, caller_name=self.name
        )
        classes = get_data_from_batch_or_outputs(
            self.class_key, batch, outputs, caller_name=self.name
        )
        # Bail out when either input is missing; the original only checked
        # `image` and would have crashed below on absent labels.
        if image is None or classes is None:
            return

        image = image.to(device=pl_module.device)
        with torch.no_grad():
            image_features = self.image_backbone(image).image_embeds
            image_features = F.normalize(image_features, dim=-1)
            # Cosine similarity against the precomputed class prototypes.
            logits = image_features @ self.class_embeds.T

        prediction_key = f"{self.name}_preds"
        if prediction_key not in batch:
            batch[prediction_key] = logits.detach()

        # List targets were created on CPU by the original; put the target on
        # the logits' device so the metric update never crosses devices.
        target = torch.tensor(classes) if isinstance(classes, list) else classes
        target = target.to(device=logits.device)

        logs = {}
        for metric_name, metric in pl_module.callbacks_metrics[self.name][
            "_val"
        ].items():
            metric(logits.detach(), target)
            logs[f"val/{self.name}_{metric_name}"] = metric
        pl_module.log_dict(logs, on_step=False, on_epoch=True, sync_dist=True)