ConvMixer
- class stable_pretraining.backbone.ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7)
Bases: Module
ConvMixer model.
A simple and efficient convolutional architecture that operates directly on patches.
- Parameters:
in_channels (int, optional) – Number of input channels. Defaults to 3.
num_classes (int, optional) – Number of output classes. Defaults to 10.
dim (int, optional) – Hidden dimension size. Defaults to 64.
depth (int, optional) – Number of ConvMixer blocks. Defaults to 6.
kernel_size (int, optional) – Kernel size for depthwise convolution. Defaults to 9.
patch_size (int, optional) – Patch embedding size. Defaults to 7.
Note
Introduced in [Trockman and Kolter, 2022].
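The effect of patch_size, dim and num_classes on tensor shapes can be illustrated with a small shape trace. This is only a sketch: convmixer_shapes is a hypothetical helper, not part of stable_pretraining. It assumes the layout described by Trockman and Kolter: a patch embedding implemented as a convolution with kernel size and stride both equal to patch_size, depthwise convolutions with "same" padding, then global average pooling and a linear classifier. The library's actual implementation may differ in these details.

```python
def convmixer_shapes(batch_size, height, width, in_channels=3,
                     num_classes=10, dim=64, patch_size=7):
    """Trace tensor shapes through an assumed ConvMixer layout (hypothetical helper)."""
    if height < patch_size or width < patch_size:
        raise ValueError("input must span at least one patch in each dimension")
    # Assumption: patch embedding is a conv with kernel_size = stride = patch_size,
    # so each spatial dimension shrinks to floor((size - patch_size) / patch_size) + 1.
    grid_h = (height - patch_size) // patch_size + 1
    grid_w = (width - patch_size) // patch_size + 1
    return {
        "input": (batch_size, in_channels, height, width),
        "patch_embedding": (batch_size, dim, grid_h, grid_w),
        # Assumption: depthwise convs use "same" padding, so every block
        # preserves the patch grid regardless of kernel_size and depth.
        "mixer_blocks": (batch_size, dim, grid_h, grid_w),
        # Global average pooling removes the spatial dimensions.
        "pooled": (batch_size, dim),
        "logits": (batch_size, num_classes),
    }


shapes = convmixer_shapes(batch_size=8, height=28, width=28)
print(shapes["patch_embedding"])  # (8, 64, 4, 4)
print(shapes["logits"])           # (8, 10)
```

Under these assumptions, kernel_size and depth change the receptive field and the parameter count but not the shape of the feature map, which is fixed by patch_size alone.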
- forward(xb)
Forward pass through the ConvMixer model.
- Parameters:
xb (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
- Returns:
Output logits of shape (batch_size, num_classes).
- Return type: