OnlineKNN

class stable_ssl.callbacks.OnlineKNN(name: str, input: str, target: str, queue_length: int, metrics: Dict, input_dim: Tuple[int, ...] | List[int] | int | None = None, target_dim: int | None = None, k: int = 5, temperature: float = 0.07, chunk_size: int = -1, distance_metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean')

Bases: Callback

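Weighted k-nearest-neighbor (KNN) online evaluator that discovers its reference data through queue callbacks.

During setup, this callback finds the OnlineQueue callbacks that track the required features and labels, creating them if none exist. During validation, it uses their contents as the reference set for KNN evaluation.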
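This page does not spell out the weighting rule. The pure-Python sketch below assumes one common scheme: each of the k nearest neighbors votes for its label with weight exp(-d / temperature), where d is its Euclidean distance to the query. `weighted_knn_predict` is a hypothetical helper for illustration, not part of stable_ssl.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def weighted_knn_predict(
    query: Sequence[float],
    bank_features: List[Sequence[float]],
    bank_labels: List[int],
    k: int = 5,
    temperature: float = 0.07,
) -> int:
    """Predict a label for `query` by a temperature-weighted vote of its k nearest neighbors."""
    # Euclidean distance from the query to every stored feature vector.
    dists = [math.dist(query, f) for f in bank_features]
    # Indices of the k closest entries (fewer if the bank is smaller than k).
    nearest = sorted(range(len(dists)), key=dists.__getitem__)[:k]
    # Assumed weighting: each neighbor votes with weight exp(-d / T), so closer
    # neighbors count far more, and a smaller temperature sharpens the weighting.
    scores: Dict[int, float] = defaultdict(float)
    for i in nearest:
        scores[bank_labels[i]] += math.exp(-dists[i] / temperature)
    # The label with the largest accumulated weight wins.
    return max(scores, key=scores.__getitem__)
```

With the bank `[[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0], [0.9, 1.1]]` and labels `[0, 0, 1, 1, 1]`, a query at `(0.05, 0.0)` with `k=5` is assigned label 0 even though three of its five neighbors carry label 1: the two label-0 points are so much closer that their weights dominate the vote.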

Parameters:
  • name – Unique identifier for this callback instance

  • input – Key in batch dict containing input features

  • target – Key in batch dict containing target labels

  • queue_length – Required queue length for both input and target

  • metrics – Dictionary of metrics to compute during validation

  • input_dim – Dimensionality of input features (None to accept any)

  • target_dim – Dimensionality of targets (None to accept any)

  • k – Number of nearest neighbors to consider

  • temperature – Temperature parameter for distance weighting

  • chunk_size – Batch size for memory-efficient distance computation (-1 for no chunking)

  • distance_metric – Distance metric to use for KNN computation: 'euclidean', 'squared_euclidean', 'cosine', or 'manhattan'
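The four accepted `distance_metric` values and the `chunk_size` option can be illustrated with the pure-Python sketch below. It assumes chunking is applied over the query batch, so that only `chunk_size` rows of the query-by-bank distance matrix are held at once, and it assumes 'cosine' means 1 minus cosine similarity; the library may chunk or define cosine differently. `knn_search` and `DISTANCES` are hypothetical names.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Vector = Sequence[float]


def squared_euclidean(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def euclidean(a: Vector, b: Vector) -> float:
    return math.sqrt(squared_euclidean(a, b))


def manhattan(a: Vector, b: Vector) -> float:
    return sum(abs(x - y) for x, y in zip(a, b))


def cosine(a: Vector, b: Vector) -> float:
    # Assumed definition: 1 - cosine similarity (0.0 for parallel vectors).
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / max(norm, 1e-12)  # guard against zero vectors


DISTANCES: Dict[str, Callable[[Vector, Vector], float]] = {
    "euclidean": euclidean,
    "squared_euclidean": squared_euclidean,
    "cosine": cosine,
    "manhattan": manhattan,
}


def knn_search(
    queries: List[Vector],
    bank: List[Vector],
    k: int = 5,
    distance_metric: str = "euclidean",
    chunk_size: int = -1,
) -> List[List[Tuple[int, float]]]:
    """For each query, return (bank index, distance) pairs of its k nearest bank entries."""
    dist_fn = DISTANCES[distance_metric]
    # A non-positive chunk_size (e.g. -1) means "all queries in one chunk".
    step = chunk_size if chunk_size > 0 else max(len(queries), 1)
    results: List[List[Tuple[int, float]]] = []
    for start in range(0, len(queries), step):
        # Only this chunk's rows of the distance matrix exist at any one time.
        rows = [[dist_fn(q, b) for b in bank] for q in queries[start:start + step]]
        for row in rows:
            # Keep just the k smallest distances per row and discard the rest.
            order = sorted(range(len(row)), key=row.__getitem__)[:k]
            results.append([(i, row[i]) for i in order])
    return results
```

Because only the top k entries of each row are kept, this sketch returns the same neighbors for any chunk size; the chunk size only trades peak memory against the number of loop iterations.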

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int, dataloader_idx: int = 0) → None

Compute KNN predictions during validation.
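The per-batch flow can be sketched as: read the features and labels from the batch dict under the `input` and `target` keys, predict a label for each feature from the queued reference data, and update every metric in `metrics`. The sketch is hypothetical throughout: `RunningAccuracy`, `nearest_label`, and `on_validation_batch` are illustrative stand-ins, it uses a 1-nearest-neighbor prediction for brevity rather than the weighted KNN vote, and skipping the update while the queue is empty is an assumed design choice.

```python
import math
from typing import Dict, List, Sequence


class RunningAccuracy:
    """Hypothetical metric that accumulates accuracy across validation batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], targets: Sequence[int]) -> None:
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


def nearest_label(x: Sequence[float], bank_feats: List[Sequence[float]], bank_labels: List[int]) -> int:
    # 1-NN for brevity: return the label of the single closest stored feature.
    j = min(range(len(bank_feats)), key=lambda i: math.dist(x, bank_feats[i]))
    return bank_labels[j]


def on_validation_batch(
    batch: Dict[str, list],
    input_key: str,
    target_key: str,
    bank_feats: List[Sequence[float]],
    bank_labels: List[int],
    metrics: Dict[str, RunningAccuracy],
) -> None:
    if not bank_feats:
        return  # Assumed behavior: nothing to compare against until the queue fills.
    preds = [nearest_label(x, bank_feats, bank_labels) for x in batch[input_key]]
    for metric in metrics.values():
        metric.update(preds, batch[target_key])
```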

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Find or create the required queue callbacks and set up the metrics.
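The find-or-create step can be sketched as a search over the registered callbacks for a queue that records the right batch key with the right length, falling back to creating and registering a new one. This is a hypothetical illustration: `QueueStub` stands in for OnlineQueue, whose constructor and matching rules are not documented on this page, and the plain list stands in for the trainer's callbacks.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class QueueStub:
    """Hypothetical stand-in for an OnlineQueue callback."""

    key: str                   # batch dict key the queue records
    queue_length: int          # maximum number of stored items
    dim: Optional[int] = None  # feature dimensionality, if known
    items: list = field(default_factory=list)


def find_or_create_queue(
    callbacks: List[object], key: str, queue_length: int, dim: Optional[int] = None
) -> QueueStub:
    for cb in callbacks:
        # Reuse a queue that records the same key with the same length; a requested
        # dimensionality must also match (dim=None accepts any queue).
        if (
            isinstance(cb, QueueStub)
            and cb.key == key
            and cb.queue_length == queue_length
            and (dim is None or cb.dim == dim)
        ):
            return cb
    # No compatible queue is registered: create one and register it.
    queue = QueueStub(key=key, queue_length=queue_length, dim=dim)
    callbacks.append(queue)
    return queue
```

The `dim is None` rule mirrors the `input_dim` and `target_dim` parameters, where None accepts any dimensionality.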