LiDAR

class stable_pretraining.callbacks.LiDAR(name: str, target: str, queue_length: int, target_shape: int | Iterable[int], n_classes: int = 100, samples_per_class: int = 10, delta: float = 0.0001, epsilon: float = 1e-08)

Bases: Callback

LiDAR (Linear Discriminant Analysis Rank) monitor using queue discovery.

LiDAR measures the effective rank of learned representations using Linear Discriminant Analysis (LDA). It is the exponential of the entropy of the normalized eigenvalue distribution obtained from the LDA transformation. The value lies between 1 and min(d, n_classes - 1), where d is the feature dimension, and indicates how many dimensions the representation effectively uses.
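The "exponential of the entropy" step can be illustrated with a minimal pure-Python sketch. It assumes the eigenvalues have already been computed; the helper name effective_rank is hypothetical and is not part of stable_pretraining, and the callback's own implementation may differ in details.

```python
import math


def effective_rank(eigenvalues, epsilon=1e-8):
    """Return exp(entropy) of the normalized eigenvalue distribution.

    Hypothetical helper for illustration only; not the library's code.
    """
    # Clamp small negative values that can come from numerical error.
    lam = [max(float(v), 0.0) for v in eigenvalues]
    total = sum(lam)
    if total <= 0.0:
        raise ValueError("eigenvalues must have a positive sum")
    # Normalize to a distribution; epsilon keeps log() finite for zero entries.
    probs = [v / total + epsilon for v in lam]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)


# Four equally used directions versus a single dominant direction.
uniform = effective_rank([1.0, 1.0, 1.0, 1.0])
collapsed = effective_rank([1.0, 0.0, 0.0, 0.0])
print(uniform, collapsed)
```

A spectrum spread evenly over k eigenvalues gives a value close to k, while a spectrum dominated by one eigenvalue gives a value close to 1, which is the signature of dimensional collapse.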

This implementation is based on Thilak et al. “LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures” (arXiv:2312.04000).

IMPORTANT: Surrogate Class Formation Requirement

The LiDAR paper requires that each “surrogate class” consists of q augmented views of the same clean sample. The current implementation chunks the queue sequentially into groups of size samples_per_class. For faithful reproduction of the paper:

  • Ensure the upstream queue pushes q contiguous augmentations of each clean sample

  • OR implement ID-based grouping to ensure each group contains views of the same sample

Without proper grouping, the metric may not accurately reflect the paper’s methodology.
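The difference between the two grouping strategies can be sketched in plain Python. Both helpers (chunk_sequentially, group_by_id) and the sample_ids list are hypothetical names introduced for illustration; dropping incomplete groups is also an assumption, not documented behavior of the callback.

```python
from collections import defaultdict


def chunk_sequentially(queue, samples_per_class):
    """Split the queue into consecutive groups of samples_per_class items.

    Assumption: a trailing incomplete group is dropped.
    """
    n_full = len(queue) // samples_per_class
    return [
        queue[i * samples_per_class:(i + 1) * samples_per_class]
        for i in range(n_full)
    ]


def group_by_id(queue, sample_ids, samples_per_class):
    """Group items by the ID of the clean sample they were augmented from."""
    groups = defaultdict(list)
    for item, sid in zip(queue, sample_ids):
        groups[sid].append(item)
    # Keep only groups with enough views, truncated to a common size.
    return [
        views[:samples_per_class]
        for views in groups.values()
        if len(views) >= samples_per_class
    ]


# Views of samples "a" and "b" arrive interleaved rather than contiguously.
queue = ["a0", "b0", "a1", "b1"]
sample_ids = ["a", "b", "a", "b"]

mixed = chunk_sequentially(queue, 2)         # each group mixes views of a and b
faithful = group_by_id(queue, sample_ids, 2)  # each group holds one sample's views
print(mixed, faithful)
```

With contiguous pushes the two strategies produce the same groups. With interleaved pushes, sequential chunking builds surrogate classes from unrelated samples, which is the situation the warning above describes.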

The metric helps detect:

  • Dimensional collapse in self-supervised learning

  • Loss of representational capacity

  • Over-regularization effects

Parameters:

  • name (str): Unique identifier for this callback instance

  • target (str): Key in batch dict containing the feature embeddings to monitor

  • queue_length (int): Size of the circular buffer for caching embeddings

  • target_shape (int | Iterable[int]): Shape of the target embeddings (e.g., 768 for 768-dim features)

  • n_classes (int): Number of surrogate classes (clean samples) for LDA computation (default: 100)

  • samples_per_class (int): Number of augmented samples per class (default: 10)

  • delta (float): Regularization constant added to within-class covariance (default: 1e-4)

  • epsilon (float): Small constant for numerical stability (default: 1e-8)
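For orientation, the following sketches where delta and epsilon enter the computation, following our reading of the formulation in Thilak et al. The notation is ours: n stands for n_classes, q for samples_per_class, and e_{ij} for the j-th augmented embedding of surrogate class i. The normalization constants and the exact use of epsilon in the callback may differ.

```latex
\begin{aligned}
\mu_i &= \frac{1}{q}\sum_{j=1}^{q} e_{ij}, \qquad
\mu = \frac{1}{n}\sum_{i=1}^{n} \mu_i \\
\Sigma_b &= \frac{1}{n}\sum_{i=1}^{n} (\mu_i - \mu)(\mu_i - \mu)^{\top} \\
\Sigma_w &= \frac{1}{nq}\sum_{i=1}^{n}\sum_{j=1}^{q}
            (e_{ij} - \mu_i)(e_{ij} - \mu_i)^{\top} + \delta I \\
\Sigma_{\mathrm{lidar}} &= \Sigma_w^{-1/2}\, \Sigma_b\, \Sigma_w^{-1/2} \\
p_k &= \frac{\lambda_k}{\sum_{l} \lambda_l} + \epsilon,
      \qquad \lambda_k \text{ the eigenvalues of } \Sigma_{\mathrm{lidar}} \\
\mathrm{LiDAR} &= \exp\Bigl(-\sum_{k} p_k \log p_k\Bigr)
\end{aligned}
```

Here delta keeps the within-class covariance invertible when classes have few samples, and epsilon keeps the logarithm finite for vanishing eigenvalues.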

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) → None

Compute LiDAR metric on the first validation batch only.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Find or create the queue callback for target features.

property state_key: str

Unique identifier for this callback’s state during checkpointing.