OnlineKNN
- class stable_ssl.callbacks.OnlineKNN(name: str, input: str, target: str, queue_length: int, metrics: Dict, input_dim: Tuple[int, ...] | List[int] | int | None = None, target_dim: int | None = None, k: int = 5, temperature: float = 0.07, chunk_size: int = -1, distance_metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean')
Bases: Callback
Weighted KNN online evaluator using queue discovery.
This callback finds the OnlineQueue callbacks that track the required input features and target labels, then uses the queued data as the reference set for KNN evaluation during validation.
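The weighted vote at the core of a weighted KNN evaluator can be sketched in pure Python. This is a hedged illustration, not the library's code: `weighted_knn_vote` is a hypothetical helper, and weighting each of the `k` nearest queued samples by `exp(-distance / temperature)` is an assumption about how the `temperature` parameter is applied.

```python
import math
from collections import defaultdict
from typing import Dict, Sequence


def weighted_knn_vote(
    distances: Sequence[float],
    labels: Sequence[int],
    k: int = 5,
    temperature: float = 0.07,
) -> Dict[int, float]:
    """Return normalized class scores from the k nearest queued samples.

    Hypothetical helper. Assumption: each neighbor votes with weight
    exp(-distance / temperature); the library may weight differently.
    """
    if not distances or len(distances) != len(labels):
        raise ValueError("need one label per distance and at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1")
    # Indices of the k smallest distances, i.e. the nearest neighbors.
    nearest = sorted(range(len(distances)), key=lambda i: distances[i])[:k]
    # Shifting by the smallest distance keeps exp() from underflowing; the
    # common factor cancels when the scores are normalized below.
    d_min = distances[nearest[0]]
    scores: Dict[int, float] = defaultdict(float)
    for i in nearest:
        scores[labels[i]] += math.exp(-(distances[i] - d_min) / temperature)
    total = sum(scores.values())
    return {label: score / total for label, score in scores.items()}


# Distances from one validation sample to four queued samples, plus their labels.
scores = weighted_knn_vote([0.1, 0.2, 0.9, 1.5], [0, 1, 1, 0], k=3, temperature=0.5)
prediction = max(scores, key=scores.get)
```

With `temperature=0.5`, the two label-1 neighbors together outweigh the single, slightly closer label-0 neighbor, so the prediction is 1. Lowering `temperature` to `0.01` concentrates almost all of the weight on the closest neighbor, and label 0 wins instead.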
- Parameters:
name – Unique identifier for this callback instance
input – Key in batch dict containing input features
target – Key in batch dict containing target labels
queue_length – Required queue length for both input and target
metrics – Dictionary of metrics to compute during validation
input_dim – Dimensionality of input features (None to accept any)
target_dim – Dimensionality of targets (None to accept any)
k – Number of nearest neighbors to consider
temperature – Temperature parameter for distance weighting
chunk_size – Batch size for memory-efficient distance computation (-1 for no chunking)
distance_metric – Distance metric for the KNN computation: 'euclidean', 'squared_euclidean', 'cosine', or 'manhattan'
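The `chunk_size` and `distance_metric` parameters can be pictured with the following pure-Python sketch of a chunked nearest-neighbor search. It is not the library's implementation: `chunked_topk` and `DISTANCES` are hypothetical names, chunking over the validation queries (rather than over the queued bank) is an assumption, and cosine distance is taken as 1 minus cosine similarity.

```python
import heapq
import math
from typing import Callable, Dict, List, Sequence, Tuple

Vector = Sequence[float]


def _cosine_distance(a: Vector, b: Vector) -> float:
    # Cosine distance = 1 - cosine similarity; the floor avoids dividing by zero.
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / max(norms, 1e-12)


# Hypothetical lookup table covering the four accepted metric names.
DISTANCES: Dict[str, Callable[[Vector, Vector], float]] = {
    "euclidean": lambda a, b: math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b))),
    "squared_euclidean": lambda a, b: sum((x - y) ** 2 for x, y in zip(a, b)),
    "cosine": _cosine_distance,
    "manhattan": lambda a, b: sum(abs(x - y) for x, y in zip(a, b)),
}


def chunked_topk(
    queries: Sequence[Vector],
    bank: Sequence[Vector],
    k: int = 5,
    chunk_size: int = -1,
    distance_metric: str = "euclidean",
) -> List[List[Tuple[float, int]]]:
    """Return the k nearest (distance, bank_index) pairs for each query.

    Queries are processed chunk_size at a time (-1 means all at once), so at
    most chunk_size x len(bank) distances exist at any moment.
    """
    if chunk_size != -1 and chunk_size <= 0:
        raise ValueError("chunk_size must be positive or -1")
    dist = DISTANCES[distance_metric]
    step = max(len(queries), 1) if chunk_size == -1 else chunk_size
    results: List[List[Tuple[float, int]]] = []
    for start in range(0, len(queries), step):
        # Distance block for this chunk only: len(chunk) x len(bank).
        block = [[dist(q, b) for b in bank] for q in queries[start:start + step]]
        for row in block:
            # Reduce each row to its k smallest entries before the next chunk.
            results.append(heapq.nsmallest(k, ((d, j) for j, d in enumerate(row))))
    return results
```

Chunking changes only peak memory, not the result: any positive `chunk_size` returns the same neighbors as `-1`. Note also that 'squared_euclidean' skips the square root, so it yields the same neighbor ranking as 'euclidean' while producing different raw distances, which matters for any temperature-based weighting applied afterwards.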