Source code for stable_ssl.backbone.convmixer

from torch import nn


class ConvMixer(nn.Module):
    """ConvMixer model from :cite:`trockman2022patches`.

    Pipeline: a strided patch-embedding convolution, then ``depth`` mixer
    stages, then global max pooling and a linear classification head.

    Each mixer stage applies a depthwise ``kernel_size`` convolution
    (``groups=dim``, ``padding="same"``) wrapped in a residual connection.
    A pointwise 1x1 convolution follows it. Every convolution in the network
    is followed by ``BatchNorm2d`` and then ``ReLU``.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the output of the final linear layer.
        dim: Channel width used throughout the network.
        depth: Number of mixer stages.
        kernel_size: Spatial kernel size of the depthwise convolutions.
        patch_size: Kernel size and stride of the patch-embedding conv.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=10,
        dim=64,
        depth=6,
        kernel_size=9,
        patch_size=7,
    ):
        super().__init__()
        # Patch embedding: kernel == stride, so patches do not overlap.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        )
        # All depthwise stages are built before any pointwise stage. This
        # keeps the parameter-creation (and thus seeded init) order stable.
        spatial_mixers = []
        for _ in range(depth):
            spatial_mixers.append(self._depthwise_stage(dim, kernel_size))
        self.blocks_a = nn.ModuleList(spatial_mixers)
        self.blocks_b = nn.ModuleList(
            self._pointwise_stage(dim) for _ in range(depth)
        )
        # AdaptiveMaxPool2d(1) collapses H x W to 1 x 1. Flatten then
        # yields (N, dim) for a 4-D batch.
        self.pool = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten())
        self.fc = nn.Linear(dim, num_classes)

    @staticmethod
    def _depthwise_stage(dim, kernel_size):
        """Build a per-channel spatial conv that keeps the feature-map size."""
        return nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        )

    @staticmethod
    def _pointwise_stage(dim):
        """Build a 1x1 conv that mixes information across channels."""
        return nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
        )

    def forward(self, xb):
        """Forward pass.

        Args:
            xb: Batch of images; assumed (N, in_channels, H, W).

        Returns:
            Logits of shape (N, num_classes).
        """
        features = self.conv1(xb)
        for depthwise, pointwise in zip(self.blocks_a, self.blocks_b):
            # The residual connection wraps only the depthwise stage.
            features = pointwise(features + depthwise(features))
        return self.fc(self.pool(features))