# Source code for stable_ssl.joint_embedding.barlow_twins
# -*- coding: utf-8 -*-
"""BarlowTwins model."""
#
# Author: Randall Balestriero <randallbalestriero@gmail.com>
# Hugues Van Assel <vanasselhugues@gmail.com>
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
from .base import JEConfig, JETrainer
from stable_ssl.utils import off_diagonal
[docs]
class BarlowTwins(JETrainer):
    """BarlowTwins model from [ZJM+21]_.

    Reference
    ---------
    .. [ZJM+21] Zbontar, J., Jing, L., Misra, I., LeCun, Y., & Deny, S. (2021).
            Barlow Twins: Self-Supervised Learning via Redundancy Reduction.
            In International conference on machine learning (pp. 12310-12320). PMLR.
    """

    def compute_ssl_loss(self, z1, z2):
        """Compute the Barlow Twins redundancy-reduction loss.

        Parameters
        ----------
        z1 : torch.Tensor
            Embeddings of the first view; assumed to be 2-D (batch, features)
            since it is transposed and matrix-multiplied below.
        z2 : torch.Tensor
            Embeddings of the second view, with the same feature dimension as
            ``z1`` so that the cross-correlation matrix is square.

        Returns
        -------
        torch.Tensor
            Scalar loss: the sum of squared deviations of the cross-correlation
            diagonal from 1, plus ``config.model.lambd`` times the sum of the
            squared entries returned by ``off_diagonal``.
        """
        # Empirical cross-correlation matrix (features x features).
        # NOTE(review): ``self.bn`` is not defined in this class; it is assumed
        # to be a (non-affine) batch-norm provided elsewhere, e.g. by
        # JETrainer — confirm.
        c = self.bn(z1).T @ self.bn(z2)

        # Normalize by the batch size, then sum the cross-correlation matrix
        # over all processes. The all-reduce is only valid when a default
        # process group exists; calling it otherwise raises, which made
        # single-process training impossible.
        # NOTE(review): this reads ``self.args`` while the loss weight below
        # reads ``self.config``; presumably this should be the *global* batch
        # size so that the summed result is a mean — confirm.
        c.div_(self.args.batch_size)
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(c)

        # The in-place ops below modify ``c`` itself; presumably
        # ``off_diagonal`` excludes the diagonal, so shifting/squaring the
        # diagonal first does not leak into the off-diagonal term.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = on_diag + self.config.model.lambd * off_diag
        return loss
[docs]
@dataclass
class BarlowTwinsConfig(JEConfig):
    """Configuration for the BarlowTwins model parameters.

    Parameters
    ----------
    lambd : float
        Weight of the off-diagonal (redundancy-reduction) term in the loss.
        Default is 0.1.
    """

    # Annotated as float (it was wrongly ``str``): the value multiplies a
    # tensor in the loss, and type-driven config loaders would otherwise
    # coerce it to a string.
    lambd: float = 0.1

    def trainer(self):
        """Return the trainer class that consumes this configuration."""
        return BarlowTwins