Source code for stable_ssl.monitors

"""Training/evaluation metrics that are computed at the end of each step."""

#
# Author: @sami-bg
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import deque

import torch
import torch.distributed as dist

from stable_ssl import (
    BaseTrainer,
    JointEmbeddingPredictiveTrainer,
    JointEmbeddingTrainer,
)
from stable_ssl.utils import broadcast, gather, reduce, warn_once


def get_num_devices() -> int:
    """Return the number of devices used in this run."""
    num_devices = 1
    if dist.is_available() and dist.is_initialized():
        num_devices = dist.get_world_size()
    return num_devices


[docs] class Monitor: """Base class for metrics that are monitored at the end of each step. Inheritors must implement a `compute` method, that calculates the metric, and a `name` attribute for logging. """ name: str = "monitor"
[docs] def compute(self, trainer: BaseTrainer): """Abstract method that calculates a score given a model.""" pass
[docs] class RankMe(Monitor): """RankMe (effective rank) monitor from :cite:`garrido2023rankme`.""" name = "rankme" def __init__(self, limit: int = 8, epsilon: float = 1e-7): super().__init__() self.global_limit = limit self.num_devices = get_num_devices() assert self.global_limit % self.num_devices == 0, ( f"RankMe {limit=} must be divisible by {self.num_devices=}" ) self.device_limit = self.global_limit // self.num_devices self.epsilon = epsilon self.bounded_queue = deque(maxlen=self.device_limit) def rankme(self, encoding: torch.Tensor, epsilon: float) -> float: (batch_size, *_), device = encoding.shape, encoding.device encoding = encoding.reshape(batch_size, -1) self.bounded_queue.append(encoding) encoding = torch.cat(list(self.bounded_queue), dim=0) encoding = gather(encoding, rank=0) # NOTE torch.linalg.svd only supports torch.float32 for now if encoding.dtype != torch.float32: warn_once( f"RankMe expected tensors of type {torch.float32}, " f"but received {encoding.dtype}, will convert " f"{encoding.dtype}->{torch.float32}" ) encoding = encoding.to(torch.float32) _u, s, _vh = torch.linalg.svd(encoding, full_matrices=False) p = (s / torch.sum(s, axis=0)) + epsilon entropy = -torch.sum(p * torch.log(p)) rankme = float(torch.exp(entropy)) return broadcast(torch.tensor([rankme], device=device), src_rank=0).item()
[docs] def compute(self, trainer: BaseTrainer) -> float: if not isinstance( trainer, (JointEmbeddingTrainer, JointEmbeddingPredictiveTrainer) ): raise NotImplementedError( f"RankMe only implemented for JointEmbeddingTrainer, JointEmbeddingPredictiveTrainer " f"and not yet implemented for type {type(trainer)}" ) encoding: list | torch.Tensor = trainer.latest_representations if isinstance(encoding, list): # assume a list is of views, where each view is batch_size on the 0th dim # (as per JointEmbeddng) return [self.compute(batch) for batch in encoding][-1] return self.rankme(encoding, self.epsilon)
[docs] class LiDAR(Monitor): """LiDAR (Linear Discriminant Analysis Rank) monitor from :cite`thilak2023lidar`.""" name = "LiDAR" def __init__(self, n: int = 1000, epsilon: float = 1e-7, delta: float = 1e-3): super().__init__() self.n = n self.global_limit = n self.epsilon = epsilon self.delta = delta self.num_devices = get_num_devices() self.queue = None self.device_limit = self.global_limit // self.num_devices def _init_bounded_queue(self, batch_size: int) -> None: # NOTE Dynamically create queue, rounding it's capacity to # the nearest batch size. self.device_limit = (self.device_limit // batch_size) * batch_size self.global_limit = self.device_limit * self.num_devices self.queue = deque(maxlen=self.device_limit) if self.global_limit != self.n: warn_once( f"Received n={self.n} but rounded to {self.global_limit}. " f"To avoid this, make sure n={self.n} and your batch size " f"({batch_size}) are divisible." ) logging.info(f"Initialized LiDAR with n={self.global_limit}") return def lidar(self, batch_embeddings: list[torch.Tensor]) -> float: if not self.queue: batch_size = len(batch_embeddings) self._init_bounded_queue(batch_size) self.queue.extend(batch_embeddings) embeddings: torch.Tensor = torch.stack(list(self.queue), dim=0) (local_n, q, d), device = embeddings.shape, embeddings.device n_total_tensor = torch.tensor([local_n], device=device) n_total_tensor = reduce(n_total_tensor, rank=0, op=dist.ReduceOp.SUM) n_total_tensor = broadcast(n_total_tensor, src_rank=0) if (n_total := n_total_tensor.item()) == 1: warn_once("LiDAR cannot compute within-class scatter with only one class!") return 1.0 class_means = embeddings.mean(dim=1) grand_mean_local = class_means.mean(dim=0) local_Sb = torch.zeros(d, d, device=device) local_Sw = torch.zeros(d, d, device=device) for i in range(local_n): diff_b = (class_means[i] - grand_mean_local).unsqueeze(1) local_Sb += diff_b @ diff_b.T for j in range(q): diff_w = (embeddings[i, j] - class_means[i]).unsqueeze(1) local_Sw += diff_w @ diff_w.T S_b = reduce(local_Sb, rank=0, op=dist.ReduceOp.SUM) / (n_total - 1) S_w = reduce(local_Sw, rank=0, op=dist.ReduceOp.SUM) / (n_total * (q - 1)) S_w += self.delta * torch.eye(d, device=device) eigvals_w, eigvecs_w = torch.linalg.eigh(S_w) eigvals_w = torch.clamp(eigvals_w, min=self.epsilon) invsqrt_w = (eigvecs_w * (1.0 / torch.sqrt(eigvals_w))) @ eigvecs_w.transpose( -1, -2 ) Sigma_lidar = invsqrt_w @ S_b @ invsqrt_w lam, _ = torch.linalg.eigh(Sigma_lidar) lam = torch.clamp(lam, min=0.0) lam_sum = lam.sum() + self.epsilon p = lam / lam_sum p_log_p = p * torch.log(p + self.epsilon) lidar = float(torch.exp(-p_log_p.sum())) return broadcast(torch.tensor([lidar], device=device), src_rank=0).item()
[docs] def compute(self, trainer: BaseTrainer) -> float: if not isinstance( trainer, (JointEmbeddingTrainer, JointEmbeddingPredictiveTrainer) ): raise NotImplementedError( f"LiDAR only implemented for JointEmbeddingTrainer, JointEmbeddingPredictiveTrainer " f"and not yet implemented for type {type(trainer)}" ) trainer: JointEmbeddingTrainer | JointEmbeddingPredictiveTrainer embeddings: list[torch.Tensor] = trainer.latest_embeddings return self.lidar(embeddings)