Source code for stable_ssl.backbone.convmixer
from torch import nn
[docs]
class ConvMixer(nn.Module):
    """ConvMixer model from :cite:`trockman2022patches`.

    The network is assembled from four stages:

    1. ``conv1``: a patch-embedding convolution whose kernel size and stride
       are both ``patch_size``, followed by batch norm and ReLU.
    2. ``blocks_a`` / ``blocks_b``: ``depth`` pairs of stages. Each pair is a
       depthwise convolution (``groups=dim``, ``padding="same"``) whose
       output is added back to its input, followed by a 1x1 convolution.
       Every convolution is followed by batch norm and ReLU.
    3. ``pool``: adaptive max pooling to a 1x1 map, then flattening.
    4. ``fc``: a linear layer mapping ``dim`` features to ``num_classes``.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input passed to the first convolution.
    num_classes : int
        Number of output features of the final linear layer.
    dim : int
        Channel width used by every convolution after the patch embedding.
    depth : int
        Number of depthwise/pointwise stage pairs.
    kernel_size : int
        Kernel size of the depthwise convolutions.
    patch_size : int
        Kernel size and stride of the patch-embedding convolution.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=10,
        dim=64,
        depth=6,
        kernel_size=9,
        patch_size=7,
    ):
        super().__init__()

        def conv_bn_relu(conv):
            # Every convolutional stage shares the same conv -> BN -> ReLU tail.
            return nn.Sequential(conv, nn.BatchNorm2d(dim), nn.ReLU())

        # Patch embedding: non-overlapping patches (stride equals kernel size).
        patch_embed = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size
        )
        self.conv1 = conv_bn_relu(patch_embed)

        # Depthwise stages (one filter per channel); spatial size is preserved.
        depthwise_stages = []
        for _ in range(depth):
            depthwise = nn.Conv2d(
                dim, dim, kernel_size, groups=dim, padding="same"
            )
            depthwise_stages.append(conv_bn_relu(depthwise))
        self.blocks_a = nn.ModuleList(depthwise_stages)

        # Pointwise (1x1) stages that mix information across channels.
        pointwise_stages = []
        for _ in range(depth):
            pointwise = nn.Conv2d(dim, dim, kernel_size=1)
            pointwise_stages.append(conv_bn_relu(pointwise))
        self.blocks_b = nn.ModuleList(pointwise_stages)

        # Collapse the spatial dimensions, then project to the output size.
        self.pool = nn.Sequential(nn.AdaptiveMaxPool2d(1), nn.Flatten())
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, xb):
        """Run the network on ``xb`` and return the output of ``fc``."""
        features = self.conv1(xb)
        for depthwise, pointwise in zip(self.blocks_a, self.blocks_b):
            # Residual connection wraps only the depthwise stage.
            mixed = depthwise(features)
            features = pointwise(features + mixed)
        pooled = self.pool(features)
        return self.fc(pooled)