# Source code for stable_pretraining.losses.reconstruction
"""Reconstruction-based SSL losses.
This module contains reconstruction-based self-supervised learning losses
such as Masked Autoencoder (MAE).
"""
# [docs]
def mae(target, pred, mask, norm_pix_loss=False):
    """Compute the masked autoencoder (MAE) reconstruction loss.

    The squared error between ``pred`` and ``target`` is averaged over the
    last dimension to give one value per patch, and those per-patch values
    are then averaged over the patches selected by ``mask``.

    Args:
        target: [N, L, p*p*3] target images
        pred: [N, L, p*p*3] predicted images
        mask: [N, L], 0 is keep, 1 is remove
        norm_pix_loss: if True, standardize each target patch with its own
            mean and variance (computed over the last dimension) before the
            error is measured

    Returns:
        loss: mean loss value over the removed patches. When ``mask``
        selects no patch at all the result is 0 rather than NaN.
    """
    if norm_pix_loss:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        # The small constant keeps the division finite for constant patches.
        target = (target - mean) / (var + 1.0e-6) ** 0.5

    per_patch = ((pred - target) ** 2).mean(dim=-1)  # [N, L], mean loss per patch

    # Guard against an all-zero mask: 0 / 0 would yield NaN and poison the
    # gradients. Adding the boolean ``denom == 0`` bumps a zero denominator
    # to one and leaves every non-zero denominator untouched.
    denom = mask.sum()
    denom = denom + (denom == 0)

    return (per_patch * mask).sum() / denom  # mean loss on removed patches