"""Source code for stable_ssl.callbacks.image_retrieval."""

import types
from typing import List, Optional, Union

import numpy as np
import torch
from lightning.pytorch import Callback, LightningModule, Trainer
from loguru import logger as logging
from torchmetrics.retrieval.base import RetrievalMetric

from .utils import format_metrics_as_dict


def wrap_validation_step(fn, input, name):
    """Wrap a module's `validation_step` to cache l2-normalized embeddings.

    `fn`, `name`, and `input` are bound as default arguments so the closure
    keeps calling the original step after it is rebound on the module.
    """

    def ffn(
        self,
        batch,
        batch_idx,
        fn=fn,
        name=name,
        input=input,
    ):
        # Run the original validation step first.
        batch = fn(batch, batch_idx)

        # Normalize the requested representation with this callback's module,
        # then l2-normalize it so dot products become cosine similarities.
        with torch.no_grad():
            norm = self._callbacks_modules[name]["normalizer"](batch[input])
            norm = torch.nn.functional.normalize(norm, dim=1, p=2)

        # Gather the embeddings and their dataset indices from all ranks.
        idx = self.all_gather(batch["sample_idx"])
        norm = self.all_gather(norm)

        # Only rank 0 holds the full buffer (see `on_validation_epoch_start`).
        if self.local_rank == 0:
            self.embeds[idx] = norm

        return batch

    return ffn
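
# A minimal, runnable sketch of the rebinding pattern used above (the `wrap`
# and `Model` names are illustrative, not part of stable_ssl): the default
# argument freezes the original bound method at wrap time, and
# `types.MethodType` re-attaches the wrapper as the instance's own step.
#
#     import types
#
#     def wrap(fn):
#         def step(self, batch, batch_idx, fn=fn):
#             out = fn(batch, batch_idx)   # original step, already bound
#             print("cached", batch_idx)   # extra behavior goes here
#             return out
#         return step
#
#     class Model:
#         def validation_step(self, batch, batch_idx):
#             return batch
#
#     m = Model()
#     m.validation_step = types.MethodType(wrap(m.validation_step), m)
#     m.validation_step("batch", 0)  # prints "cached 0"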


class ImageRetrieval(Callback):
    """Image Retrieval evaluator for self-supervised learning.

    The implementation follows:
    1. https://github.com/facebookresearch/dino/blob/main/eval_image_retrieval.py
    """

    NAME = "ImageRetrieval"

    def __init__(
        self,
        pl_module,
        name: str,
        input: str,
        query_col: str,
        retrieval_col: Union[str, List[str]],
        metrics,
        features_dim: Union[tuple[int, ...], list[int], int],
        normalizer: Optional[str] = None,
    ) -> None:
        logging.info(f"Setting up callback ({self.NAME})")
        logging.info(f"\t- {input=}")
        logging.info(f"\t- {query_col=}")
        logging.info("\t- caching modules into `_callbacks_modules`")
        if name in pl_module._callbacks_modules:
            raise ValueError(f"{name=} already used in callbacks")

        # Collapse a multi-axis feature shape into a single flat dimension.
        if isinstance(features_dim, (list, tuple)):
            features_dim = int(np.prod(features_dim))

        if normalizer is not None and normalizer not in ["batch_norm", "layer_norm"]:
            raise ValueError(
                "`normalizer` has to be one of `batch_norm` or `layer_norm`"
            )
        if normalizer == "batch_norm":
            normalizer = torch.nn.BatchNorm1d(features_dim, affine=False)
        elif normalizer == "layer_norm":
            normalizer = torch.nn.LayerNorm(
                features_dim, elementwise_affine=False, bias=False
            )
        else:
            normalizer = torch.nn.Identity()
        pl_module._callbacks_modules[name] = torch.nn.ModuleDict(
            {
                "normalizer": normalizer,
            }
        )
        logging.info(
            f"`_callbacks_modules` now contains "
            f"({list(pl_module._callbacks_modules.keys())})"
        )

        if not isinstance(retrieval_col, list):
            retrieval_col = [retrieval_col]

        for k, metric in metrics.items():
            if not isinstance(metric, RetrievalMetric):
                raise ValueError(
                    f"Only `RetrievalMetric` is supported for {self.NAME} "
                    f"callback, but got {metric} for {k}"
                )
        logging.info("\t- caching metrics into `_callbacks_metrics`")
        pl_module._callbacks_metrics[name] = format_metrics_as_dict(metrics)

        logging.info("\t- wrapping the `validation_step`")
        fn = wrap_validation_step(pl_module.validation_step, input, name)
        pl_module.validation_step = types.MethodType(fn, pl_module)

        self.name = name
        self.features_dim = features_dim
        self.query_col = query_col
        self.retrieval_col = retrieval_col
        # pl_module.embeds = None
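
    # A hedged construction sketch (the `model` object, metric choices, and
    # column names below are illustrative assumptions, not confirmed defaults):
    #
    #     from torchmetrics.retrieval import RetrievalMAP, RetrievalRecall
    #
    #     callback = ImageRetrieval(
    #         model,                           # LightningModule with `_callbacks_modules`
    #         name="retrieval_eval",
    #         input="embedding",               # key produced by `validation_step`
    #         query_col="is_query",            # marks query vs. gallery images
    #         retrieval_col=["easy", "hard"],  # columns listing relevant gallery indices
    #         metrics={"map": RetrievalMAP(), "r5": RetrievalRecall(top_k=5)},
    #         features_dim=2048,
    #         normalizer="layer_norm",
    #     )
    #     trainer = Trainer(callbacks=[callback])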

    def on_validation_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # Register the full-dataset embedding buffer on rank 0; the wrapped
        # `validation_step` gathers every rank's embeddings into it.
        val_dataset = pl_module.trainer.datamodule.val.dataset
        dataset_size = len(val_dataset)
        if pl_module.local_rank == 0:
            device = pl_module.device
            pl_module.embeds = torch.zeros(
                (dataset_size, self.features_dim), device=device
            )
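
    # How the rank-0 buffer is filled (a runnable shape sketch with made-up
    # values): Lightning's `all_gather` stacks a leading world-size axis, and
    # advanced indexing scatters each gathered row into its dataset slot.
    #
    #     import torch
    #
    #     embeds = torch.zeros(6, 4)
    #     idx = torch.tensor([[0, 2], [1, 5]])  # (world_size=2, batch=2) sample indices
    #     norm = torch.ones(2, 2, 4)            # (world_size, batch, features_dim)
    #     embeds[idx] = norm                    # rows 0, 1, 2 and 5 are filled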

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        if pl_module.local_rank == 0:
            logging.info(f"Computing results for {self.name} callback")
            val_dataset = pl_module.trainer.datamodule.val.dataset.dataset
            if len(pl_module.embeds) != len(val_dataset):
                logging.warning(
                    f"Expected {len(val_dataset)} embeddings, but got "
                    f"{len(pl_module.embeds)}. Skipping evaluation."
                )
                return
            # Split the cached embeddings into query and gallery sets.
            is_query = torch.tensor(
                val_dataset[self.query_col], device=pl_module.device
            ).squeeze()
            query_idx = torch.nonzero(is_query)
            query = pl_module.embeds[is_query]
            gallery = pl_module.embeds[~is_query]
            # Embeddings are l2-normalized, so this is cosine similarity.
            score = query @ gallery.t()
            # ranks = torch.argsort(-score, dim=1)
            preds = []
            targets = []
            indexes = []
            for idx, q_idx in enumerate(query_idx):
                # add query idx to the indexes
                indexes.append(q_idx.repeat(len(gallery)))
                # build target for query
                target = torch.zeros(
                    len(gallery), dtype=torch.bool, device=pl_module.device
                )
                for col in self.retrieval_col:
                    ret_idx = val_dataset[q_idx][col]
                    if ret_idx:
                        target[ret_idx] = True
                targets.append(target)
                preds.append(score[idx])
            preds = torch.cat(preds)
            targets = torch.cat(targets)
            indexes = torch.cat(indexes)
            logs = {}
            for k, metric in pl_module._callbacks_metrics[self.name]["_val"].items():
                res = metric(preds, targets, indexes=indexes)
                logs[f"eval/{self.name}_{k}"] = res.item() * 100
            # `Callback` does not expose `log_dict`; log through the module.
            pl_module.log_dict(logs, on_epoch=True, rank_zero_only=True)
            logging.info(f"Finished computing results for {self.name} callback")
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        pl_module.embeds = None
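
# A small runnable illustration of the `indexes` convention consumed above:
# each (pred, target) pair is grouped by its query id, and the metric is
# averaged over queries. Scores and labels here are made up for the example.
#
#     import torch
#     from torchmetrics.retrieval import RetrievalMAP
#
#     preds = torch.tensor([0.9, 0.2, 0.8, 0.1])        # similarity scores
#     target = torch.tensor([True, False, False, True])
#     indexes = torch.tensor([0, 0, 1, 1])              # two queries, two docs each
#
#     # Query 0 ranks its relevant doc first (AP = 1.0); query 1 ranks its
#     # relevant doc second (AP = 0.5), so the mean is 0.75.
#     print(RetrievalMAP()(preds, target, indexes=indexes))  # tensor(0.7500)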