OnlineKNN

class stable_pretraining.callbacks.OnlineKNN(name: str, input: str, target: str, queue_length: int, metrics: Dict, input_dim: Tuple[int, ...] | List[int] | int | None = None, target_dim: int | None = None, k: int = 5, temperature: float = 0.07, chunk_size: int = -1, distance_metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean')

Bases: Callback

Weighted K-Nearest Neighbors online evaluator using queue discovery.

This callback implements a weighted KNN classifier that evaluates the quality of learned representations during training. It automatically discovers or creates OnlineQueue callbacks to maintain circular buffers of features and labels, then uses this cached data to compute KNN predictions during validation.
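A minimal stdlib sketch of such a circular buffer follows. It assumes first-in-first-out eviction once queue_length entries are stored. FeatureQueue is a hypothetical name used only for illustration. It is not the library's OnlineQueue.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple


class FeatureQueue:
    """Hypothetical circular buffer of (feature, label) pairs.

    Once `queue_length` items are stored, each new item evicts the oldest,
    so memory stays bounded no matter how many batches are seen.
    """

    def __init__(self, queue_length: int) -> None:
        if queue_length <= 0:
            raise ValueError("queue_length must be positive")
        self._items: Deque[Tuple[Sequence[float], int]] = deque(maxlen=queue_length)

    def append_batch(self, features: List[Sequence[float]], labels: List[int]) -> None:
        # deque(maxlen=...) silently drops the oldest entry when full.
        for feat, label in zip(features, labels):
            self._items.append((feat, label))

    def snapshot(self) -> Tuple[List[Sequence[float]], List[int]]:
        # Cached features and labels, oldest first, for the KNN lookup.
        return [f for f, _ in self._items], [y for _, y in self._items]

    def __len__(self) -> int:
        return len(self._items)


queue = FeatureQueue(queue_length=3)
queue.append_batch([[0.0], [1.0]], [0, 1])
queue.append_batch([[2.0], [3.0]], [1, 0])
feats, labels = queue.snapshot()
print(feats, labels)  # [[1.0], [2.0], [3.0]] [1, 1, 0]
```

The fourth sample evicts the first one, so the buffer always holds the most recent queue_length samples.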

The KNN evaluation is performed by:

  1. Finding the k nearest neighbors in the feature space.

  2. Weighting neighbors by inverse distance with temperature scaling.

  3. Using weighted voting to produce class predictions.

  4. Computing the specified metrics on the predictions.
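The four steps can be sketched in plain Python. This is an illustration under stated assumptions and not the callback's tensor implementation. It uses Euclidean distance. It also assumes the weight exp(-d / temperature) as one concrete form of temperature-scaled inverse-distance weighting, and the callback's exact formula may differ. The function names are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def weighted_knn_predict(
    query: Sequence[float],
    bank_feats: List[Sequence[float]],
    bank_labels: List[int],
    k: int = 5,
    temperature: float = 0.07,
) -> int:
    # Step 1: k nearest neighbors of the query in the cached feature bank.
    pairs = sorted(
        ((euclidean(query, f), y) for f, y in zip(bank_feats, bank_labels)),
        key=lambda pair: pair[0],
    )[:k]
    d_min = pairs[0][0]
    votes: Dict[int, float] = defaultdict(float)
    for dist, label in pairs:
        # Step 2: assumed weight exp(-d / T). It is shifted by d_min so the
        # nearest neighbor always gets weight 1 instead of underflowing to 0.
        # The shift scales all weights equally and leaves the vote unchanged.
        weight = math.exp(-(dist - d_min) / temperature)
        # Step 3: weighted voting, summing the weights per class.
        votes[label] += weight
    return max(votes, key=votes.__getitem__)


def accuracy(preds: List[int], targets: List[int]) -> float:
    # Step 4: one example metric computed on the predictions.
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


bank = [[0.1, 0.0], [0.5, 0.0], [0.0, 0.5]]
labels = [0, 1, 1]
# One close neighbor of class 0 against two farther neighbors of class 1.
print(weighted_knn_predict([0.0, 0.0], bank, labels, k=3, temperature=0.07))  # 0
print(weighted_knn_predict([0.0, 0.0], bank, labels, k=3, temperature=10.0))  # 1
```

The two calls show why the temperature matters. At 0.07 the single close neighbor outvotes the two farther ones. At 10.0 the weights become nearly equal, and the result approaches a plain majority vote.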

Parameters:
  • name – Unique identifier for this callback instance. Used for logging and storing metrics.

  • input – Key in batch dict containing input features to evaluate.

  • target – Key in batch dict containing ground truth target labels.

  • queue_length – Size of the circular buffer for caching features and labels. Larger values give the neighbor search a more representative reference set but use more memory.

  • metrics – Dictionary of metrics to compute during validation. Keys are metric names and values are metric instances (e.g., torchmetrics.Accuracy).

  • input_dim – Expected dimensionality of input features: an int, a tuple/list of ints (flattened to their product), or None to accept any dimension.

  • target_dim – Expected dimensionality of targets, or None to accept any dimension.

  • k – Number of nearest neighbors to consider for voting. Default is 5.

  • temperature – Temperature parameter for distance weighting. Lower values give more weight to closer neighbors. Default is 0.07.

  • chunk_size – Chunk size for memory-efficient distance computation. Distances are computed in chunks of this size to bound peak memory. Set to -1 to compute all distances at once. Default is -1.

  • distance_metric – Distance metric for finding nearest neighbors. Options are ‘euclidean’, ‘squared_euclidean’, ‘cosine’, ‘manhattan’. Default is ‘euclidean’.
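The sketch below illustrates the distance_metric options and chunk_size. The four formulas are the standard definitions. Two choices are assumptions of this sketch: cosine is expressed as 1 minus cosine similarity, and chunking runs over the query batch rather than the queue. knn_indices is a hypothetical helper.

```python
import heapq
import math
from typing import Callable, Dict, List, Sequence

Vector = Sequence[float]


def _cosine_distance(a: Vector, b: Vector) -> float:
    # 1 - cosine similarity; the small floor avoids division by zero.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / max(norm, 1e-12)


DISTANCES: Dict[str, Callable[[Vector, Vector], float]] = {
    "euclidean": lambda a, b: math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b))),
    "squared_euclidean": lambda a, b: sum((x - y) ** 2 for x, y in zip(a, b)),
    "manhattan": lambda a, b: sum(abs(x - y) for x, y in zip(a, b)),
    "cosine": _cosine_distance,
}


def knn_indices(
    queries: List[Vector],
    bank: List[Vector],
    k: int,
    metric: str = "euclidean",
    chunk_size: int = -1,
) -> List[List[int]]:
    dist_fn = DISTANCES[metric]
    # chunk_size == -1 means a single chunk containing every query.
    step = max(len(queries), 1) if chunk_size == -1 else chunk_size
    result: List[List[int]] = []
    for start in range(0, len(queries), step):
        # Only this chunk's distance rows (step x len(bank)) are held at once.
        rows = [[dist_fn(q, b) for b in bank] for q in queries[start:start + step]]
        for row in rows:
            # Reduce each row to the indices of its k smallest distances.
            result.append(heapq.nsmallest(k, range(len(bank)), key=row.__getitem__))
    return result


query = [[10.0, 10.0]]
bank = [[1.0, 1.0], [9.0, 8.0]]
print(knn_indices(query, bank, k=1, metric="euclidean"))  # [[1]]
print(knn_indices(query, bank, k=1, metric="cosine"))     # [[0]]
```

Cosine ignores vector length, so it picks the far neighbor that points in the same direction, while euclidean picks the nearby one. With a positive chunk_size, each chunk's rows are reduced to k indices before the next chunk starts. A chunk_size of -1 gives up that memory bound in exchange for a single pass.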

Raises:

ValueError – If k <= 0, temperature <= 0, or chunk_size is invalid.
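The hypothetical check below combines the documented ValueError conditions with the input_dim flattening described under Parameters. It assumes that a valid chunk_size is either -1 or a positive integer, because the docs only say "invalid". validate_knn_args is an illustrative name and not part of the library.

```python
import math
from typing import List, Optional, Tuple, Union


def validate_knn_args(
    k: int,
    temperature: float,
    chunk_size: int,
    input_dim: Union[Tuple[int, ...], List[int], int, None] = None,
) -> Optional[int]:
    """Hypothetical argument checks; returns input_dim flattened to an int or None."""
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # Assumption: -1 disables chunking, any other value must be positive.
    if chunk_size != -1 and chunk_size <= 0:
        raise ValueError(f"chunk_size must be -1 or positive, got {chunk_size}")
    if isinstance(input_dim, (tuple, list)):
        # e.g. (3, 32, 32) -> 3072, the product of the dimensions.
        return math.prod(input_dim)
    return input_dim


print(validate_knn_args(5, 0.07, -1, (3, 32, 32)))  # 3072
```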

Note

  • The callback automatically handles distributed training by gathering data across processes

  • Mixed precision is supported through automatic dtype conversion

  • Predictions are stored in the batch dict under the key ‘{name}_preds’

  • Metrics are logged with the prefix ‘eval/{name}_’

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int, dataloader_idx: int = 0) → None

Compute KNN predictions during validation.
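The data flow of this hook can be sketched with plain dicts. validation_batch_end_sketch is hypothetical and takes no Trainer or LightningModule. To stay short, it replaces the weighted vote with a 1-nearest-neighbor lookup. It also assumes that each logged key is the ‘eval/{name}_’ prefix followed by the metric's dictionary key.

```python
import math
from typing import Callable, Dict, List, Sequence

Metric = Callable[[List[int], List[int]], float]


def validation_batch_end_sketch(
    batch: Dict[str, list],
    name: str,
    input_key: str,
    target_key: str,
    bank_feats: List[Sequence[float]],
    bank_labels: List[int],
    metrics: Dict[str, Metric],
) -> Dict[str, float]:
    preds: List[int] = []
    for query in batch[input_key]:
        # KNN step reduced to 1-nearest-neighbor under Euclidean distance.
        nearest = min(range(len(bank_feats)), key=lambda i: math.dist(query, bank_feats[i]))
        preds.append(bank_labels[nearest])
    # Store predictions in the batch dict under '{name}_preds'.
    batch[f"{name}_preds"] = preds
    # Metric values keyed as 'eval/{name}_<metric key>' (assumed suffix).
    return {f"eval/{name}_{key}": fn(preds, batch[target_key]) for key, fn in metrics.items()}


def acc(preds: List[int], targets: List[int]) -> float:
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


batch = {"embedding": [[0.0, 0.1], [1.0, 0.9]], "label": [0, 1]}
logs = validation_batch_end_sketch(
    batch, "knn", "embedding", "label", [[0.0, 0.0], [1.0, 1.0]], [0, 1], {"acc": acc}
)
print(batch["knn_preds"], logs)  # [0, 1] {'eval/knn_acc': 1.0}
```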

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Find or create queue callbacks and setup metrics.
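Queue discovery in setup might follow a find-or-create pattern like the one below. QueueSpec and find_or_create_queue are hypothetical stand-ins. The exact-match rule on key and length is an assumption. The real callback looks for OnlineQueue callbacks instead of these plain objects.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class QueueSpec:
    """Hypothetical stand-in for an OnlineQueue callback."""

    key: str
    queue_length: int


def find_or_create_queue(callbacks: List[object], key: str, queue_length: int) -> QueueSpec:
    # Reuse a queue that already caches this batch key with the same length.
    for cb in callbacks:
        if isinstance(cb, QueueSpec) and cb.key == key and cb.queue_length == queue_length:
            return cb
    # Otherwise create and register one so other evaluators can share it.
    queue = QueueSpec(key=key, queue_length=queue_length)
    callbacks.append(queue)
    return queue


callbacks: List[object] = []
feature_queue = find_or_create_queue(callbacks, "embedding", 2048)
label_queue = find_or_create_queue(callbacks, "label", 2048)
# A second evaluator on the same key and length reuses the existing queue.
print(find_or_create_queue(callbacks, "embedding", 2048) is feature_queue)  # True
print(len(callbacks))  # 2
```

One queue each for features and labels is enough for any number of evaluators that share the same keys and length.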

property state_key: str

Unique identifier for this callback’s state during checkpointing.
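In Lightning, a callback's state_key is the key under which its state is saved to and restored from the checkpoint. Two OnlineKNN instances therefore need distinct keys. The sketch below shows one way to achieve that by folding name into the key. KNNStateKeySketch and its key format are hypothetical and do not reproduce the callback's actual string.

```python
from typing import Dict


class KNNStateKeySketch:
    """Hypothetical class showing a per-instance checkpoint key."""

    def __init__(self, name: str) -> None:
        self.name = name

    @property
    def state_key(self) -> str:
        # Folding `name` into the key gives each instance its own checkpoint slot.
        return f"{type(self).__name__}[name={self.name}]"


evaluators = [KNNStateKeySketch("knn_k5"), KNNStateKeySketch("knn_k20")]
# Mimics a checkpoint's callback-state dict keyed by state_key.
checkpoint_callbacks: Dict[str, dict] = {cb.state_key: {"name": cb.name} for cb in evaluators}
print(len(checkpoint_callbacks))  # 2
```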