Source code for stable_ssl.joint_embedding.vicreg

# -*- coding: utf-8 -*-
"""VICReg model."""
#
# Author: Randall Balestriero <randallbalestriero@gmail.com>
#         Hugues Van Assel <vanasselhugues@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import torch

from stable_ssl.utils import FullGatherLayer, off_diagonal
from .base import JEConfig, JETrainer


class VICReg(JETrainer):
    """VICReg joint-embedding trainer, following [BPL21]_.

    The self-supervised objective is a weighted sum of three terms computed
    from the two embedding views: an invariance (mean-squared error) term,
    a variance hinge term and a covariance penalty.

    Reference
    ---------
    .. [BPL21] Bardes, A., Ponce, J., & LeCun, Y. (2021).
            VICReg: Variance-Invariance-Covariance Regularization
            For Self-Supervised Learning.
            International Conference on Learning Representations (ICLR).
    """

    def compute_ssl_loss(self, z1, z2):
        """Compute the VICReg loss for a pair of embedding views.

        Parameters
        ----------
        z1, z2 : torch.Tensor
            Embeddings of the two views. The code reduces over ``dim=0`` and
            reads ``size(1)`` as the feature count, so they are assumed to be
            2-D ``(samples, features)`` tensors of identical shape
            -- TODO confirm against the caller.

        Returns
        -------
        torch.Tensor
            ``sim_coeff * invariance + std_coeff * variance
            + cov_coeff * covariance``, with the coefficients and ``epsilon``
            read from ``self.config.model``.
        """
        functional = torch.nn.functional

        # Invariance term: computed on the local (un-gathered) embeddings.
        invariance_term = functional.mse_loss(z1, z2)

        # With more than one process, the variance/covariance statistics are
        # computed on the concatenation of what FullGatherLayer returns
        # (presumably the embeddings of every process -- verify in utils).
        if self.config.hardware.world_size > 1:
            view_a = torch.cat(FullGatherLayer.apply(z1), dim=0)
            view_b = torch.cat(FullGatherLayer.apply(z2), dim=0)
        else:
            view_a, view_b = z1, z2

        # Center every feature along the sample axis.
        view_a = view_a - view_a.mean(dim=0)
        view_b = view_b - view_b.mean(dim=0)

        model_cfg = self.config.model

        def _half_hinge(centered):
            # epsilon keeps the square root away from an exactly zero variance.
            std = torch.sqrt(centered.var(dim=0) + model_cfg.epsilon)
            # Hinge at 1: only features whose std is below 1 are penalised.
            return torch.mean(functional.relu(1 - std)) / 2

        variance_term = _half_hinge(view_a) + _half_hinge(view_b)

        # Both views share the same shape, so the sizes of the first view are
        # used to normalise both covariance matrices.
        num_samples = view_a.size(0)
        num_features = view_a.size(1)
        cov_a = (view_a.T @ view_a) / (num_samples - 1)
        cov_b = (view_b.T @ view_b) / (num_samples - 1)

        # Sum of squared off-diagonal covariance entries, scaled by the
        # number of features.
        penalty_a = off_diagonal(cov_a).pow_(2).sum().div(num_features)
        penalty_b = off_diagonal(cov_b).pow_(2).sum().div(num_features)
        covariance_term = penalty_a + penalty_b

        return (
            model_cfg.sim_coeff * invariance_term
            + model_cfg.std_coeff * variance_term
            + model_cfg.cov_coeff * covariance_term
        )
@dataclass
class VICRegConfig(JEConfig):
    """Hyper-parameters of the VICReg objective.

    ``VICReg.compute_ssl_loss`` reads attributes with these names from
    ``self.config.model`` (presumably an instance of this class -- confirm
    in the configuration plumbing).

    Parameters
    ----------
    sim_coeff : float, default 25
        Weight of the invariance (mean-squared error) term.
    std_coeff : float, default 25
        Weight of the variance hinge term.
    cov_coeff : float, default 1
        Weight of the covariance penalty.
    epsilon : float, default 0.0001
        Constant added to the per-feature variance before the square root.
    """

    sim_coeff: float = 25
    std_coeff: float = 25
    cov_coeff: float = 1
    epsilon: float = 1e-4

    def trainer(self):
        """Return the trainer class paired with this configuration."""
        return VICReg