Source code for stable_pretraining.callbacks.lidar

"""LiDAR (Linear Discriminant Analysis Rank) callback for monitoring representation quality.

Based on:
    Thilak et al. "LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures"
    arXiv:2312.04000 (2023)
"""

from typing import Iterable, Optional, Union

import torch
from lightning.pytorch import Callback, LightningModule, Trainer
from loguru import logger as logging

from .queue import find_or_create_queue_callback


class LiDAR(Callback):
    """LiDAR (Linear Discriminant Analysis Rank) monitor using queue discovery.

    LiDAR measures the effective rank of learned representations using Linear
    Discriminant Analysis (LDA). It computes the exponential of the entropy of
    the eigenvalue distribution from the LDA transformation, providing a metric
    between 1 and min(d, n_classes - 1), where d is the feature dimension,
    indicating how many dimensions are effectively being used.

    This implementation is based on Thilak et al. "LiDAR: Sensing Linear Probing
    Performance in Joint Embedding SSL Architectures" (arXiv:2312.04000).

    IMPORTANT: Surrogate Class Formation Requirement
    -------------------------------------------------
    The LiDAR paper requires that each "surrogate class" consist of q augmented
    views of the same clean sample. The current implementation chunks the queue
    sequentially into groups of size samples_per_class. For faithful
    reproduction of the paper:

    - Ensure the upstream queue pushes q contiguous augmentations of each clean
      sample
    - OR implement ID-based grouping to ensure each group contains views of the
      same sample

    Without proper grouping, the metric may not accurately reflect the paper's
    methodology.

    The metric helps detect:

    - Dimensional collapse in self-supervised learning
    - Loss of representational capacity
    - Over-regularization effects

    Args:
        name: Unique identifier for this callback instance
        target: Key in the batch dict containing the feature embeddings to monitor
        queue_length: Size of the circular buffer for caching embeddings
        target_shape: Shape of the target embeddings (e.g., 768 for 768-dim features)
        n_classes: Number of surrogate classes (clean samples) for LDA computation
        samples_per_class: Number of augmented samples per class
        delta: Regularization constant added to within-class covariance (default: 1e-4)
        epsilon: Small constant for numerical stability (default: 1e-8)
    """

    def __init__(
        self,
        name: str,
        target: str,
        queue_length: int,
        target_shape: Union[int, Iterable[int]],
        n_classes: int = 100,
        samples_per_class: int = 10,
        delta: float = 1e-4,
        epsilon: float = 1e-8,
    ) -> None:
        super().__init__()

        # Convert target_shape to int if needed
        if isinstance(target_shape, (list, tuple)):
            if len(target_shape) == 1:
                target_shape = target_shape[0]
            else:
                target_shape = int(torch.prod(torch.tensor(target_shape)))

        self.name = name
        self.target = target
        self.queue_length = queue_length
        self.target_shape = target_shape
        self.n_classes = n_classes
        self.samples_per_class = samples_per_class
        self.delta = delta
        self.epsilon = epsilon
        self._target_queue = None

        # Validate queue length adequacy
        min_required_samples = n_classes * samples_per_class
        if queue_length < min_required_samples:
            logging.warning(
                f"{name}: Queue length ({queue_length}) is less than required "
                f"samples ({min_required_samples} = {n_classes} classes × {samples_per_class} samples/class). "
                f"LiDAR computation may use fewer classes than specified."
            )

        logging.info(f"Initialized LiDAR callback: {name}")
        logging.info(f" - Target: {target}")
        logging.info(f" - Queue length: {queue_length}")
        logging.info(f" - Feature dimension: {target_shape}")
        logging.info(
            f" - N classes: {n_classes}, Samples per class: {samples_per_class}"
        )

    @property
    def state_key(self) -> str:
        """Unique identifier for this callback's state during checkpointing."""
        return f"LiDAR[name={self.name}]"
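    # Worked illustration of the metric range (added note, not part of the
    # original class): with d = 768-dimensional features and n_classes = 100
    # surrogate classes, the LiDAR value ranges from 1 (complete dimensional
    # collapse, a single effective direction) up to min(768, 100 - 1) = 99
    # (all discriminative LDA directions carry equal weight).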
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Find or create the queue callback for target features."""
        if self._target_queue is None:
            self._target_queue = find_or_create_queue_callback(
                trainer,
                self.target,
                self.queue_length,
                self.target_shape,
                torch.float32,
                gather_distributed=True,
                create_if_missing=True,
            )
            logging.info(f"{self.name}: Using queue for target '{self.target}'")
    def on_validation_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: dict,
        batch: dict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Compute LiDAR metric on the first validation batch only."""
        if batch_idx > 0:
            return

        logging.info(f"{self.name}: Computing LiDAR on first validation batch")

        embeddings = self._target_queue.data
        if embeddings is None:
            logging.warning(f"{self.name}: Queue data not available")
            return
        if embeddings.numel() == 0:
            logging.warning(f"{self.name}: Queue data is empty")
            return

        # The queue already handles gathering across GPUs if gather_distributed=True,
        # so embeddings here already contains data from all GPUs
        lidar_value = self._compute_lidar(embeddings)

        if lidar_value is not None:
            pl_module.log(
                self.name,
                lidar_value,
                rank_zero_only=True,  # Only log from rank 0 to avoid duplicates
                sync_dist=False,  # No need to sync since we compute the same value on all ranks
            )
            if trainer.global_rank == 0:
                logging.info(f"{self.name}: LiDAR = {lidar_value:.4f}")
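    # Expected queue layout for faithful surrogate classes (illustrative; this
    # ordering is assumed by the sequential chunking in _compute_lidar below).
    # With q = samples_per_class augmented views per clean sample x_i, the
    # queue should hold
    #
    #   [x_1^(1), ..., x_1^(q), x_2^(1), ..., x_2^(q), ..., x_C^(1), ..., x_C^(q)]
    #
    # so that reshaping to (n_classes, samples_per_class, d) groups views of
    # the same clean sample into one surrogate class, as the paper requires.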
    def _compute_lidar(self, embeddings: torch.Tensor) -> Optional[float]:
        """Compute the LiDAR metric from embeddings.

        Args:
            embeddings: Tensor of shape (n_samples, feature_dim)

        Returns:
            LiDAR value or None if computation fails
        """
        n_samples, d = embeddings.shape

        # Determine how many classes we can form
        actual_n_classes = min(self.n_classes, n_samples // self.samples_per_class)
        if actual_n_classes < 2:
            logging.warning(
                f"{self.name}: Not enough samples for LiDAR computation. "
                f"Need at least {2 * self.samples_per_class} samples, got {n_samples}"
            )
            return None

        # Reshape embeddings to (n_classes, samples_per_class, feature_dim).
        # WARNING: This assumes the queue contains contiguous groups of augmentations
        # from the same clean sample. If the queue mixes samples randomly, the
        # surrogate classes won't match the paper's methodology.
        # Take only the samples we need
        n_used = actual_n_classes * self.samples_per_class
        embeddings = embeddings[:n_used].view(
            actual_n_classes, self.samples_per_class, d
        )

        with torch.no_grad():
            class_means = embeddings.mean(dim=1)  # (n_classes, d)
            grand_mean = class_means.mean(dim=0)  # (d,)
            device = embeddings.device

            # Sb = sum((mu_i - mu) @ (mu_i - mu)^T) / (n_classes - 1)
            centered_means = class_means - grand_mean.unsqueeze(0)
            Sb = (centered_means.T @ centered_means) / (actual_n_classes - 1)

            # First center all samples by their class means
            class_means_expanded = class_means.unsqueeze(1).expand_as(embeddings)
            centered_samples = embeddings - class_means_expanded
            centered_samples_flat = centered_samples.reshape(-1, d)

            # Sw = sum((x_ij - mu_i) @ (x_ij - mu_i)^T) / (n_classes * (samples_per_class - 1))
            # This is the unbiased estimate of the class-averaged within-class covariance
            # as described in the LiDAR paper (arXiv:2312.04000)
            Sw = (centered_samples_flat.T @ centered_samples_flat) / (
                actual_n_classes * (self.samples_per_class - 1)
            )

            # Add regularization to the within-class covariance
            Sw = Sw + self.delta * torch.eye(d, device=device)

            # Compute Sw^(-1/2) using eigendecomposition
            eigvals_w, eigvecs_w = torch.linalg.eigh(Sw)
            eigvals_w = torch.clamp(eigvals_w, min=self.epsilon)

            # Sw^(-1/2) = V * D^(-1/2) * V^T
            sqrt_inv_eigvals = 1.0 / torch.sqrt(eigvals_w)
            Sw_invsqrt = (eigvecs_w * sqrt_inv_eigvals.unsqueeze(0)) @ eigvecs_w.T

            # Compute LiDAR matrix: Sigma_lidar = Sw^(-1/2) * Sb * Sw^(-1/2)
            Sigma_lidar = Sw_invsqrt @ Sb @ Sw_invsqrt

            # Handle numerical errors by ensuring symmetry
            Sigma_lidar = 0.5 * (Sigma_lidar + Sigma_lidar.T)

            # Compute eigenvalues of the LiDAR matrix
            eigvals_lidar = torch.linalg.eigvalsh(Sigma_lidar)
            eigvals_lidar = torch.clamp(eigvals_lidar, min=0.0)

            # Normalize eigenvalues to get a probability distribution.
            # Following the paper: p_i = (lambda_i + epsilon) / sum_j(lambda_j + epsilon)
            eigvals_with_eps = eigvals_lidar + self.epsilon
            eigvals_sum = eigvals_with_eps.sum()
            if eigvals_sum <= 0:
                logging.warning(f"{self.name}: All eigenvalues are zero or negative")
                return 1.0  # Return minimum rank

            p = eigvals_with_eps / eigvals_sum

            # Compute entropy and the LiDAR metric
            entropy = -(p * torch.log(p)).sum()
            lidar = torch.exp(entropy).item()

        return lidar
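
# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). It assumes a
# LightningModule whose validation batches expose an "embedding" key that the
# shared queue callback can cache; `model` and `datamodule` are placeholders.
#
#   import lightning.pytorch as pl
#   from stable_pretraining.callbacks.lidar import LiDAR
#
#   lidar_cb = LiDAR(
#       name="lidar_embedding",
#       target="embedding",     # key in the batch dict holding the features
#       queue_length=1000,      # >= n_classes * samples_per_class
#       target_shape=768,       # feature dimension
#       n_classes=100,
#       samples_per_class=10,
#   )
#   trainer = pl.Trainer(max_epochs=10, callbacks=[lidar_cb])
#   trainer.fit(model, datamodule=datamodule)
#
# The metric is logged under the callback's `name` ("lidar_embedding" here) and
# is computed once per validation run, on the first validation batch.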