LiDAR
- class stable_pretraining.callbacks.LiDAR(name: str, target: str, queue_length: int, target_shape: int | Iterable[int], n_classes: int = 100, samples_per_class: int = 10, delta: float = 0.0001, epsilon: float = 1e-08)
Bases: Callback
LiDAR (Linear Discriminant Analysis Rank) monitor using queue discovery.
LiDAR measures the effective rank of learned representations using Linear Discriminant Analysis (LDA). It computes the exponential of the Shannon entropy of the normalized eigenvalue distribution of the LDA matrix (the between-class covariance whitened by the within-class covariance). The result lies between 1 and min(d, n_classes - 1), where d is the feature dimension, and indicates how many dimensions are effectively being used.
This implementation is based on Thilak et al. “LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures” (arXiv:2312.04000).
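The computation described above can be sketched in NumPy as follows. This is a minimal sketch following the paper's definition, assuming the embeddings are already grouped into an array of shape `(n_classes, samples_per_class, d)`. The function name `lidar_score` is hypothetical, and the callback's actual implementation may differ in details such as normalization or eigen-solver choice.

```python
import numpy as np


def lidar_score(z: np.ndarray, delta: float = 1e-4, epsilon: float = 1e-8) -> float:
    """LiDAR of embeddings grouped as (n_classes, samples_per_class, d). Hypothetical helper."""
    n, q, d = z.shape
    class_means = z.mean(axis=1)           # (n, d): mean of each surrogate class
    grand_mean = class_means.mean(axis=0)  # (d,): mean of the class means

    # Between-class covariance of the class means around the grand mean.
    diff_b = class_means - grand_mean
    sigma_b = diff_b.T @ diff_b / n

    # Within-class covariance of each view around its own class mean, plus delta * I.
    diff_w = (z - class_means[:, None, :]).reshape(n * q, d)
    sigma_w = diff_w.T @ diff_w / (n * q) + delta * np.eye(d)

    # Sigma_w^{-1/2} via eigendecomposition (sigma_w is symmetric positive definite).
    w_vals, w_vecs = np.linalg.eigh(sigma_w)
    inv_sqrt = (w_vecs / np.sqrt(w_vals)) @ w_vecs.T
    sigma_lidar = inv_sqrt @ sigma_b @ inv_sqrt

    # Eigenvalues of the symmetric LDA matrix; clip round-off negatives to zero.
    lam = np.clip(np.linalg.eigvalsh(sigma_lidar), 0.0, None)
    p = lam / lam.sum() + epsilon
    # Exponential of the Shannon entropy of the eigenvalue distribution.
    return float(np.exp(-np.sum(p * np.log(p))))


# 50 surrogate classes, 10 views each, 8-dim features with well-separated class means.
rng = np.random.default_rng(0)
views = rng.normal(scale=10.0, size=(50, 1, 8)) + rng.normal(size=(50, 10, 8))
print(lidar_score(views))  # bounded above by min(d, n_classes - 1) = 8, up to epsilon
```

Because the between-class covariance is built from `n_classes` centered means, its rank is at most `n_classes - 1`, which is where the upper bound on the score comes from.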
IMPORTANT: Surrogate Class Formation Requirement
The LiDAR paper requires that each “surrogate class” consists of q augmented views of the same clean sample. The current implementation chunks the queue sequentially into groups of size samples_per_class. For faithful reproduction of the paper:
- Ensure the upstream queue pushes q contiguous augmentations of each clean sample (with q equal to samples_per_class), or
- Implement ID-based grouping so that each group contains views of the same sample.
Without proper grouping, the metric may not accurately reflect the paper’s methodology.
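To illustrate why the grouping matters, the sketch below contrasts sequential chunking with ID-based grouping. It assumes, hypothetically, that the views are concatenated view-by-view (all first views of the clean samples, then all second views, and so on). `chunk_sequential` and `group_by_id` are illustrative helpers, not part of `stable_pretraining`.

```python
import numpy as np


def chunk_sequential(queue: np.ndarray, n_classes: int, samples_per_class: int) -> np.ndarray:
    """Split the first n_classes * samples_per_class rows into consecutive groups."""
    needed = n_classes * samples_per_class
    if queue.shape[0] < needed:
        raise ValueError(f"need {needed} embeddings, got {queue.shape[0]}")
    return queue[:needed].reshape(n_classes, samples_per_class, -1)


def group_by_id(
    embeddings: np.ndarray, sample_ids: np.ndarray, n_classes: int, samples_per_class: int
) -> np.ndarray:
    """Hypothetical ID-based grouping: each group holds views of one clean sample."""
    groups = []
    for sid in np.unique(sample_ids):
        views = embeddings[sample_ids == sid]
        if len(views) >= samples_per_class:  # skip samples with too few views
            groups.append(views[:samples_per_class])
        if len(groups) == n_classes:
            break
    if len(groups) < n_classes:
        raise ValueError("not enough clean samples with samples_per_class views each")
    return np.stack(groups)


# Views concatenated view-by-view: [view 0 of samples 0..2, view 1 of samples 0..2, ...].
n_classes, samples_per_class = 3, 4
sample_ids = np.tile(np.arange(n_classes), samples_per_class)  # 0 1 2 0 1 2 ...
embeddings = sample_ids[:, None].astype(float)  # each 1-d "embedding" stores its sample ID

mixed = chunk_sequential(embeddings, n_classes, samples_per_class)
clean = group_by_id(embeddings, sample_ids, n_classes, samples_per_class)
print(mixed[..., 0])  # each row mixes several sample IDs
print(clean[..., 0])  # each row holds views of a single sample
```

With this layout, sequential chunking puts views of different clean samples into the same surrogate class, so the within-class covariance no longer measures augmentation variability alone.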
The metric helps detect:
- Dimensional collapse in self-supervised learning
- Loss of representational capacity
- Over-regularization effects
- param name:
Unique identifier for this callback instance
- param target:
Key in batch dict containing the feature embeddings to monitor
- param queue_length:
Size of the circular buffer for caching embeddings
- param target_shape:
Shape of the target embeddings (e.g., 768 for 768-dim features)
- param n_classes:
Number of surrogate classes (clean samples) for LDA computation
- param samples_per_class:
Number of augmented samples per class
- param delta:
Regularization constant added to the diagonal of the within-class covariance (default: 1e-4)
- param epsilon:
Small constant for numerical stability (default: 1e-8)
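The queue parameters can be pictured with a hypothetical circular buffer: `queue_length` rows, each of shape `target_shape`, with the oldest rows overwritten once the buffer is full. The `EmbeddingQueue` class is an illustrative assumption, not the library's queue; it returns rows oldest-first so that views pushed contiguously stay adjacent for sequential chunking.

```python
from typing import Iterable, Union

import numpy as np


class EmbeddingQueue:
    """Hypothetical fixed-size circular buffer; not the library's implementation."""

    def __init__(self, queue_length: int, target_shape: Union[int, Iterable[int]]):
        shape = (target_shape,) if isinstance(target_shape, int) else tuple(target_shape)
        self.buffer = np.zeros((queue_length, *shape), dtype=np.float32)
        self.dim = int(np.prod(shape))  # flattened feature dimension d
        self.ptr = 0   # next slot to overwrite
        self.size = 0  # number of valid rows

    def push(self, batch: np.ndarray) -> None:
        for row in batch:  # once full, the oldest row is overwritten
            self.buffer[self.ptr] = row
            self.ptr = (self.ptr + 1) % len(self.buffer)
            self.size = min(self.size + 1, len(self.buffer))

    def get(self) -> np.ndarray:
        """Return valid rows oldest-first, flattened to (size, d)."""
        if self.size < len(self.buffer):
            data = self.buffer[: self.size]
        else:
            data = np.concatenate([self.buffer[self.ptr :], self.buffer[: self.ptr]])
        return data.reshape(self.size, self.dim)


# Assumption: if the callback reads n_classes * samples_per_class embeddings, the defaults
# (100 * 10) need 1000 rows, so queue_length would presumably have to be at least 1000.
queue = EmbeddingQueue(queue_length=1000, target_shape=768)
queue.push(np.random.default_rng(0).normal(size=(1200, 768)))
print(queue.get().shape)  # (1000, 768)
```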
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) → None
Compute the LiDAR metric on the first validation batch only.