Source code for stable_ssl.callbacks.lidar

import types
from typing import Iterable

import torch
import torch.distributed as dist
from loguru import logger as logging
from torch.distributed import broadcast, reduce

from .queue import OnlineQueue


def wrap_validation_step(fn, target, input, name):
    """Build a replacement ``validation_step`` intended to compute LiDAR.

    The returned function is meant to be bound to a module (e.g. with
    ``types.MethodType``). It is currently disabled: it raises
    ``NotImplementedError`` before doing any work, so every statement after
    the guard is unreachable and is kept only as the draft of the intended
    computation.

    Args:
        fn: The original ``validation_step``; the draft calls it as
            ``fn(batch, batch_idx)``.
        target: Captured as a default argument of the wrapper; the draft does
            not read it.
        input: Captured as a default argument of the wrapper; the draft does
            not read it.
        name: Used to look up the cached embeddings on the module through the
            attribute ``_cached_{name}_X``.

    Returns:
        A function ``ffn(self, batch, batch_idx, ...)``.

    Raises:
        NotImplementedError: Always, as soon as the returned function is
            called.
    """

    def ffn(self, batch, batch_idx, fn=fn, target=target, input=input, name=name):
        # Guard: the computation below is a draft and must not run yet.
        raise NotImplementedError(
            "LiDAR validation-step wrapper is not implemented yet"
        )

        # ----- draft of the intended computation (unreachable) -----
        batch = fn(batch, batch_idx)
        # Only the first validation batch triggers the metric.
        if batch_idx > 0:
            return batch
        embeddings = getattr(self, f"_cached_{name}_X")
        # NOTE(review): the collectives below sit inside a rank-0-only branch,
        # so in a multi-process run the other ranks would never join them.
        # The normalisers also use per-rank counts. Both must be settled
        # before the guard above is removed.
        if self.trainer.global_rank == 0:
            # The indexing below assumes embeddings is 3-D: (n, q, d).
            local_n, q, d = (
                embeddings.shape[0],
                embeddings.shape[1],
                embeddings.shape[-1],
            )
            device = embeddings.device
            n_total = local_n * q

            class_means = embeddings.mean(dim=1)  # (n, d)
            grand_mean_local = class_means.mean(dim=0)  # (d,)

            # Sum of outer products == X^T X, so two matmuls replace the
            # nested Python loops over classes and samples.
            between = class_means - grand_mean_local  # (n, d)
            within = (embeddings - class_means.unsqueeze(1)).reshape(-1, d)
            local_Sb = between.T @ between
            local_Sw = within.T @ within

            distributed = dist.is_available() and dist.is_initialized()
            if distributed:
                # `reduce` works in place and returns None when synchronous;
                # the destination rank is given with `dst`.
                reduce(local_Sb, dst=0, op=dist.ReduceOp.SUM)
                reduce(local_Sw, dst=0, op=dist.ReduceOp.SUM)
            S_b = local_Sb / (n_total - 1)
            S_w = local_Sw / (n_total * (q - 1))
            # Diagonal regularisation before the eigendecomposition.
            S_w += self.delta * torch.eye(d, device=device)

            eigvals_w, eigvecs_w = torch.linalg.eigh(S_w)
            eigvals_w = torch.clamp(eigvals_w, min=self.epsilon)

            # S_w^{-1/2} rebuilt from its eigendecomposition.
            invsqrt_w = (
                eigvecs_w * (1.0 / torch.sqrt(eigvals_w))
            ) @ eigvecs_w.transpose(-1, -2)
            Sigma_lidar = invsqrt_w @ S_b @ invsqrt_w

            lam, _ = torch.linalg.eigh(Sigma_lidar)
            lam = torch.clamp(lam, min=0.0)

            # Normalise the eigenvalues into a distribution.
            lam_sum = lam.sum() + self.epsilon
            p = lam / lam_sum

            p_log_p = p * torch.log(p + self.epsilon)

            # Exponential of the entropy of the normalised eigenvalues.
            lidar = float(torch.exp(-p_log_p.sum()))
            result = torch.tensor([lidar], device=device)
            if distributed:
                # `broadcast` also works in place; the source rank is `src`.
                broadcast(result, src=0)
            return result.item()

        return batch

    return ffn


class LiDAR(OnlineQueue):
    """LiDAR (Linear Discriminant Analysis Rank) monitor from :cite:`thilak2023lidar`.

    Construction does two things: it initialises the parent ``OnlineQueue``
    so that ``target`` is saved, and it replaces ``pl_module.validation_step``
    with the function returned by ``wrap_validation_step``.

    Args:
        pl_module: Module passed to ``OnlineQueue`` and whose
            ``validation_step`` gets replaced.
        name: Forwarded to ``OnlineQueue`` and to the validation-step wrapper.
        target: The single entry of ``to_save`` given to ``OnlineQueue``.
        queue_length: Forwarded to ``OnlineQueue`` as ``queue_length``.
        target_shape: The single entry of ``shapes`` given to ``OnlineQueue``.
    """

    def __init__(
        self,
        pl_module,
        name: str,
        target: str,
        queue_length: int,
        target_shape: Iterable[int],
    ) -> None:
        queue_options = dict(
            name=name,
            to_save=[target],
            queue_length=queue_length,
            shapes=[target_shape],
            dtypes=[torch.float],
        )
        super().__init__(pl_module, **queue_options)

        logging.info("\t- wrapping the `validation_step`")
        original_step = pl_module.validation_step
        # NOTE(review): this method has no `input` parameter, so `input` here
        # resolves to the Python builtin. Confirm what was meant to be passed.
        wrapped_step = wrap_validation_step(original_step, target, input, name)
        # Bind the wrapper so it receives `pl_module` as `self`.
        pl_module.validation_step = types.MethodType(wrapped_step, pl_module)