stable_pretraining.backbone package#
Submodules#
stable_pretraining.backbone.aggregator module#
Modular tensor aggregation module for feeding multi-scale/multi-layer features to MLPs.
Commonly used for:
- SSL linear probes using multiple transformer layers
- Multi-scale feature fusion
- Combining features from different network stages
- class stable_pretraining.backbone.aggregator.TensorAggregator(input_spec: str | List[str] | Dict[str, str], adaptive_pool_size: int = 1)[source]#
Bases: Module
Aggregates multi-dimensional tensors into 2D format for MLP input.
Pure aggregation module with NO trainable parameters. Handles various input formats and aggregation strategies.
- Parameters:
input_spec – Specification of input format and aggregation modes:
- str: Single aggregation mode for all tensors (e.g., "mean")
- List[str]: Per-tensor aggregation modes for list inputs
- Dict[str, str]: Per-key aggregation modes for dict inputs
adaptive_pool_size – Output size for adaptive pooling (default: 1)
- Aggregation Modes:
“mean”: Spatial/temporal mean pooling
“max”: Spatial/temporal max pooling
“cls”: Take first token (for transformers with [CLS] token)
“flatten”: Flatten all dimensions after batch
“adaptive”: Adaptive average pooling to fixed size
Examples
>>> # Single tensor with mean pooling
>>> agg = TensorAggregator("mean")
>>> x = torch.randn(4, 768, 14, 14)
>>> out = agg(x)  # Shape: (4, 768)
>>> # SSL: Last 4 transformer layers with CLS token
>>> agg = TensorAggregator(["cls", "cls", "cls", "cls"])
>>> layers = [torch.randn(4, 197, 768) for _ in range(4)]
>>> out = agg(layers)  # Shape: (4, 3072) = 768 * 4
>>> # Multi-scale features
>>> agg = TensorAggregator({"layer1": "cls", "layer2": "mean", "conv": "mean"})
>>> out = agg(
...     {
...         "layer1": torch.randn(4, 197, 768),
...         "layer2": torch.randn(4, 197, 768),
...         "conv": torch.randn(4, 512, 14, 14),
...     }
... )  # Shape: (4, 2048)
- compute_output_dim(input_shapes: tuple | List[tuple] | Dict[str, tuple]) int[source]#
Compute the output dimension given input shapes.
- Parameters:
input_shapes – Shape(s) of input tensor(s) (excluding batch dim)
- Returns:
Total output features
Examples
>>> agg = TensorAggregator(["cls", "mean"])
>>> agg.compute_output_dim([(197, 768), (197, 768)])
1536
>>> agg = TensorAggregator({"l1": "cls", "conv": "mean"})
>>> agg.compute_output_dim({"l1": (197, 768), "conv": (512, 14, 14)})
1280
stable_pretraining.backbone.convmixer module#
- class stable_pretraining.backbone.convmixer.ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7)[source]#
Bases: Module
ConvMixer model.
A simple and efficient convolutional architecture that operates directly on patches.
- Parameters:
in_channels (int, optional) – Number of input channels. Defaults to 3.
num_classes (int, optional) – Number of output classes. Defaults to 10.
dim (int, optional) – Hidden dimension size. Defaults to 64.
depth (int, optional) – Number of ConvMixer blocks. Defaults to 6.
kernel_size (int, optional) – Kernel size for depthwise convolution. Defaults to 9.
patch_size (int, optional) – Patch embedding size. Defaults to 7.
Note
Introduced in [Trockman and Kolter, 2022].
- forward(xb)[source]#
Forward pass through the ConvMixer model.
- Parameters:
xb (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
- Returns:
Output logits of shape (batch_size, num_classes).
- Return type:
torch.Tensor
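A minimal usage sketch (not from the library's docstrings; assumes "same" padding in the mixer blocks, as in the original ConvMixer):
import torch
from stable_pretraining.backbone.convmixer import ConvMixer

model = ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6)
x = torch.randn(8, 3, 32, 32)   # CIFAR-sized batch
logits = model(x)               # expected shape: (8, 10)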
stable_pretraining.backbone.mae module#
- class stable_pretraining.backbone.mae.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, norm_pix_loss=False)[source]#
Bases: Module
Masked Autoencoder with VisionTransformer backbone.
- forward(imgs, mask_ratio=0.75)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- patchify(imgs)[source]#
Convert images to patches.
- Parameters:
imgs – (N, 3, H, W)
- Returns:
(N, L, patch_size**2 * 3)
- Return type:
x
- random_masking(x, mask_ratio)[source]#
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
- Parameters:
x – [N, L, D], sequence
mask_ratio – ratio of patches to mask
- Returns:
x_masked: masked sequence
mask: binary mask
ids_restore: indices to restore the original order
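A hedged usage sketch (constructor arguments below are illustrative; the reference MAE implementation returns a (loss, pred, mask) tuple, so verify against this class's forward before unpacking):
import torch
from stable_pretraining.backbone.mae import MaskedAutoencoderViT

mae = MaskedAutoencoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
imgs = torch.randn(2, 3, 224, 224)
out = mae(imgs, mask_ratio=0.75)  # typically (loss, pred, mask) in MAE-style models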
- stable_pretraining.backbone.mae.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#
Get 1D sinusoidal positional embedding from grid.
- Parameters:
embed_dim – output dimension for each position
pos – a list of positions to be encoded: size (M,)
- Returns:
(M, D)
- Return type:
out
- stable_pretraining.backbone.mae.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Get 2D sinusoidal positional embedding.
- Parameters:
embed_dim – embedding dimension
grid_size – int of the grid height and width
cls_token – whether to include class token
- Returns:
[grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- Return type:
pos_embed
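A short sketch of how this positional-embedding helper is typically used (the array type, numpy vs. torch, follows the reference MAE code and should be verified here):
from stable_pretraining.backbone.mae import get_2d_sincos_pos_embed

# 14x14 patch grid (224 / 16) with an extra slot for the [CLS] token
pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
# expected shape: (1 + 14 * 14, 768)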
- stable_pretraining.backbone.mae.vit_base_patch16(**kwargs)#
- stable_pretraining.backbone.mae.vit_huge_patch14(**kwargs)#
- stable_pretraining.backbone.mae.vit_large_patch16(**kwargs)#
stable_pretraining.backbone.mlp module#
- class stable_pretraining.backbone.mlp.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#
Bases: Sequential
This block implements the multi-layer perceptron (MLP) module.
- Parameters:
in_channels (int) – Number of channels of the input
hidden_channels (List[int]) – List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None, this layer won't be used. Default: None
activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None, this layer won't be used. Default: torch.nn.ReLU
inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.
bias (bool) – Whether to use bias in the linear layer. Default: True
dropout (float) – The probability for the dropout layer. Default: 0.0
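A minimal sketch of building a projection head with this block (illustrative values, not from the docstring):
import torch
from stable_pretraining.backbone.mlp import MLP

head = MLP(in_channels=512, hidden_channels=[1024, 256], dropout=0.1)
x = torch.randn(32, 512)
z = head(x)  # expected shape: (32, 256)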
stable_pretraining.backbone.probe module#
- class stable_pretraining.backbone.probe.AutoLinearClassifier(name, embedding_dim, num_classes, pooling=None, weight_decay=[0], lr_scaling=[1], normalization=['none', 'norm', 'bn'], dropout=[0, 0.5], label_smoothing=[0, 1])[source]#
Bases: Module
Linear classifier using either the CLS token or mean pooling, with a configurable normalization layer.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.
norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.
- norm#
Instantiated normalization layer, or None.
- Type:
nn.Module or None
- fc#
Linear layer mapping pooled representation to class logits.
- Type:
nn.Linear
- Forward Args:
- x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D). If 3D, pooling and normalization are applied. If 2D, the input is used directly (no pooling or normalization).
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
- forward(x, y=None, pl_module=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.probe.AutoTuneMLP(in_features: int, out_features: int, hidden_features: List[int] | List[List[int]], name: str, loss_fn: Callable, additional_weight_decay: float | List[float] = [0], lr_scaling: float | List[float] = [1], normalization: str | List[str] = ['none'], dropout: float | List[float] = [0], activation: str | List[str] = ['relu'])[source]#
Bases: Module
Automatically creates multiple MLP variants with different hyperparameter combinations.
This module creates a grid of MLPs with different configurations (dropout, normalization, learning rates, architectures, etc.) to enable parallel hyperparameter tuning.
- Parameters:
in_features – Number of input features
out_features – Number of output features
hidden_features – Architecture specification. Can be:
- List[int]: Single architecture, e.g., [256, 128]
- List[List[int]]: Multiple architectures, e.g., [[256, 128], [512, 256, 128]]
- []: Empty list for a linear model (no hidden layers)
name – Base name for this AutoTuneMLP instance
loss_fn – Loss function to compute loss
additional_weight_decay – List of weight decay values to try
lr_scaling – List of learning rate scaling factors to try
normalization – List of normalization types [‘none’, ‘norm’, ‘bn’]
dropout – List of dropout rates to try
activation – List of activation functions [‘relu’, ‘leaky_relu’, ‘tanh’]
Examples
>>> # Single architecture
>>> model = AutoTuneMLP(128, 10, [256, 128], "clf", nn.CrossEntropyLoss())
>>> # Multiple architectures
>>> model = AutoTuneMLP(
...     128, 10, [[256], [256, 128], [512, 256]], "clf", nn.CrossEntropyLoss()
... )
>>> # Linear model (no hidden layers)
>>> model = AutoTuneMLP(128, 10, [], "linear_clf", nn.CrossEntropyLoss())
- forward(x: Tensor, y: Tensor | None = None) Dict[str, Tensor][source]#
Forward pass through all MLP variants.
- Parameters:
x – Input tensor of shape (batch_size, in_features)
y – Optional target tensor for loss computation
- Returns:
Dictionary with predictions and losses for each variant. Format: {'pred/{variant_id}': tensor, 'loss/{variant_id}': tensor}
- get_best_variant(metric_dict: Dict[str, float], lower_is_better: bool = True) str[source]#
Get the best performing variant based on metrics.
- Parameters:
metric_dict – Dictionary mapping variant_id to metric values
lower_is_better – If True, lower metric is better (e.g., loss). If False, higher is better (e.g., accuracy)
- Returns:
ID of the best performing variant
- get_variant(key: str) Module[source]#
Get a specific MLP variant by key.
- Parameters:
key – Variant ID
- Returns:
The MLP module
- Raises:
KeyError – If key doesn’t exist
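A hedged end-to-end sketch tying forward, get_best_variant, and get_variant together (the 'loss/{variant_id}' key layout is taken from the forward documentation above; everything else is illustrative):
import torch
from torch import nn
from stable_pretraining.backbone.probe import AutoTuneMLP

model = AutoTuneMLP(128, 10, [[256], [256, 128]], "clf", nn.CrossEntropyLoss())
x, y = torch.randn(32, 128), torch.randint(0, 10, (32,))
outputs = model(x, y)  # {'pred/<variant_id>': ..., 'loss/<variant_id>': ...}
losses = {k.split("/", 1)[1]: v.item() for k, v in outputs.items() if k.startswith("loss/")}
best_id = model.get_best_variant(losses, lower_is_better=True)
best_mlp = model.get_variant(best_id)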
- class stable_pretraining.backbone.probe.LinearProbe(embedding_dim, num_classes, pooling='cls', norm_layer=None)[source]#
Bases: Module
Linear classifier using either the CLS token or mean pooling, with a configurable normalization layer.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.
norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.
- norm#
Instantiated normalization layer, or None.
- Type:
nn.Module or None
- fc#
Linear layer mapping pooled representation to class logits.
- Type:
nn.Linear
- Forward Args:
- x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D). If 3D, pooling and normalization are applied. If 2D, the input is used directly (no pooling or normalization).
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.probe.MultiHeadAttentiveProbe(embedding_dim: int, num_classes: int, num_heads: int = 4)[source]#
Bases: Module
A multi-head attentive probe for sequence representations.
This module applies multiple attention heads to a sequence of embeddings, pools the sequence into a fixed-size representation per head, concatenates the results, and projects to a set of output classes.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
num_heads (int, optional) – Number of attention heads. Default: 4.
- ln#
Layer normalization applied to the input.
- Type:
torch.nn.LayerNorm
- attn_vectors#
Learnable attention vectors for each head, shape (num_heads, embedding_dim).
- Type:
torch.nn.Parameter
- fc#
Final linear layer mapping concatenated head outputs to class logits.
- Type:
torch.nn.Linear
- Forward Args:
- x (torch.Tensor): Input tensor of shape (N, T, D), where N = batch size, T = sequence length, D = embedding_dim.
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = MultiHeadAttentiveProbe(
...     embedding_dim=128, num_classes=10, num_heads=4
... )
>>> x = torch.randn(32, 20, 128)  # batch of 32, sequence length 20
>>> logits = probe(x)  # shape: (32, 10)
- forward(x: Tensor)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
stable_pretraining.backbone.resnet9 module#
- class stable_pretraining.backbone.resnet9.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#
Bases: Sequential
This block implements the multi-layer perceptron (MLP) module.
- Parameters:
in_channels (int) – Number of channels of the input
hidden_channels (List[int]) – List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None, this layer won't be used. Default: None
activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None, this layer won't be used. Default: torch.nn.ReLU
inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.
bias (bool) – Whether to use bias in the linear layer. Default: True
dropout (float) – The probability for the dropout layer. Default: 0.0
- class stable_pretraining.backbone.resnet9.ResidualBlock(in_channels, out_channels, kernel_size, padding, stride)[source]#
Bases: Module
A residual block as defined by He et al.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.resnet9.Resnet9(num_classes, num_channels, *args, **kwargs)[source]#
Bases: Module
A residual network.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
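A minimal sketch (treating num_channels as the number of input channels is an assumption based on the constructor name):
import torch
from stable_pretraining.backbone.resnet9 import Resnet9

model = Resnet9(num_classes=10, num_channels=3)
x = torch.randn(8, 3, 32, 32)   # CIFAR-sized batch
logits = model(x)               # expected shape: (8, 10)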
stable_pretraining.backbone.utils module#
- class stable_pretraining.backbone.utils.EfficientMaskedTimmViT(vit: Module)[source]#
Bases: Module
Optimized Vision Transformer wrapper that efficiently handles NaN patches.
This module is designed to work with timm ViT models and provides:
- Per-sample NaN masking (different NaN patterns per image in batch)
- Fast path for same masking pattern across batch
- Support for class tokens (cls_token), distillation tokens (dist_token), and register tokens
- Compatibility with various timm ViT architectures (vit_*, deit_*, beit_*, etc.)
- Minimal overhead when no masking is present
Key Optimizations:
- Early exit when no NaN patches detected
- Simpler indexing for same masking patterns
- Cached batch indices for repeated operations
- Zero-copy operations where possible
- Parameters:
vit – A timm Vision Transformer model instance
- Raises:
ValueError – If samples have different numbers of NaN patches
ValueError – If all patches are NaN
RuntimeError – If the model structure is incompatible
Example
>>> import timm
>>> vit = timm.create_model(
...     "vit_base_patch16_224", pretrained=False, reg_tokens=4
... )
>>> masked_vit = EfficientMaskedTimmViT(vit)
>>>
>>> # Create input with some NaN patches
>>> x = torch.randn(4, 3, 224, 224)
>>> output = masked_vit(x)
- Performance:
Same pattern masking: ~0-5% overhead vs different patterns
No masking: <2% overhead vs original model
50% masking: ~1.5x speedup
90% masking: ~2.5-3x speedup
Note
All samples in a batch must have the same NUMBER of NaN patches, but the LOCATION of NaN patches can differ per sample.
Register tokens (DINOv2 style) do NOT receive positional embeddings.
- clear_cache()[source]#
Clear the cached batch indices.
Useful if you want to free memory after processing different batch sizes. The cache will be rebuilt as needed during forward passes.
- forward(x: Tensor) Tensor[source]#
Forward pass through the masked ViT.
This method implements an optimized forward pass with the following features:
- Early exit for inputs without NaN patches (fast path)
- Optimized indexing for same masking patterns across batch
- Per-sample masking support with advanced indexing
- Automatic NaN replacement for partial NaN patches
- Support for register tokens (DINOv2 style)
- Parameters:
x – Input tensor, either:
- Raw images: shape (B, C, H, W)
- Pre-patchified: shape (B, N, D), where N is the number of patches
- Returns:
Model output (logits if head exists, features otherwise)
- Return type:
Tensor
- Raises:
ValueError – If samples have different numbers of NaN patches
ValueError – If all patches are NaN
- Performance Notes:
No NaN patches: Uses fast path with <2% overhead
Same pattern: Optimized indexing, ~0-5% overhead vs different patterns
Different patterns: Uses advanced indexing, ~10-35% slower at high masking
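A hedged sketch of per-sample NaN masking (the patch-aligned NaN regions below are illustrative; every sample must hide the same number of patches, as noted above):
import timm
import torch
from stable_pretraining.backbone.utils import EfficientMaskedTimmViT

vit = timm.create_model("vit_base_patch16_224", pretrained=False)
masked_vit = EfficientMaskedTimmViT(vit)
x = torch.randn(2, 3, 224, 224)
x[:, :, 0:16, 0:16] = float("nan")    # NaN-out one full 16x16 patch per sample
x[:, :, 16:32, 0:16] = float("nan")   # and a second one (same count per sample)
output = masked_vit(x)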
- class stable_pretraining.backbone.utils.EvalOnly(backbone: Module)[source]#
Bases: Module
Wrapper that forces a module to remain in evaluation mode.
- forward(*args, **kwargs)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
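A minimal sketch (assuming the wrapper keeps the wrapped module in eval mode even when train() is called on it):
import torch
import torchvision
from stable_pretraining.backbone.utils import EvalOnly

frozen = EvalOnly(torchvision.models.resnet18(weights=None))
frozen.train()                                # wrapped backbone should remain in eval mode
feats = frozen(torch.randn(2, 3, 224, 224))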
- class stable_pretraining.backbone.utils.FeaturesConcat(agg: callable, names: str | Iterable[str] = None)[source]#
Bases: Module
Aggregates and concatenates features from a dictionary input, then classifies.
- Parameters:
names (List[str]) – Keys to extract from the input dictionary. If not given, everything in the dict/list is aggregated.
- forward(inputs: dict | Iterable)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.utils.ReturnEmbedding(backbone: Module, module_names: list[str])[source]#
Bases: Module
Caches the embeddings of submodules given their names.
Example:
stable_pretraining.backbone.utils.ReturnEmbedding(
    torchvision.models.swin_v2_s(),
    stable_pretraining.static.EMBEDDINGS["swin_v2_s"],
)
- Parameters:
module_names (list of str) – List of module names to hook (e.g., ["layer1", "encoder.block1"]).
add_to_forward_output (bool) – If True, enables merging cached outputs into the dict returned by forward.
- forward(*args, **kwargs)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.utils.TeacherStudentWrapper(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#
Bases: Module
Backbone wrapper that implements teacher-student distillation via EMA.
This is a wrapper for backbones that creates a teacher model as an exponential moving average (EMA) of the student model. It should be passed as the backbone to stable_pretraining.Module and accessed via forward_student() and forward_teacher() methods in your custom forward function.
The teacher model is updated by taking a running average of the student’s parameters and buffers. When ema_coefficient == 0.0, the teacher and student are literally the same object, saving memory but forward passes through the teacher will not produce any gradients.
- Usage example:
backbone = ResNet18()
wrapped_backbone = TeacherStudentWrapper(backbone)
module = ssl.Module(
    backbone=wrapped_backbone,
    projector=projector,
    forward=forward_with_teacher_student,
    ...
)
- Parameters:
student (torch.nn.Module) – The student model whose parameters will be tracked.
warm_init (bool, optional) – If True, performs an initialization step to match the student’s parameters immediately. Default is True.
base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.
final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.
- forward(*args, **kwargs)[source]#
Forward pass through either the student or teacher network.
You can choose which model to run in the default forward. Commonly the teacher is evaluated, so we default to that.
- forward_student(*args, **kwargs)[source]#
Forward pass through the student network. Gradients will flow normally.
- forward_teacher(*args, **kwargs)[source]#
Forward pass through the teacher network.
By default, the teacher network does not require grad. If ema_coefficient == 0, then teacher==student, so we wrap in torch.no_grad() to ensure no gradients flow.
- update_ema_coefficient(epoch: int, total_epochs: int)[source]#
Update the EMA coefficient following a cosine schedule.
- The EMA coefficient is updated following a cosine schedule:
ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))
- update_teacher()[source]#
Perform one EMA update step on the teacher’s parameters.
- The update rule is:
teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param
This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, but the student remains fully trainable.
Everything is updated, including buffers (e.g. batch norm running averages).
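A hedged sketch of the EMA bookkeeping across a training loop (the toy backbone and batch are placeholders; method names follow the documentation above):
import torch
from stable_pretraining.backbone.utils import TeacherStudentWrapper

backbone = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU())
wrapped = TeacherStudentWrapper(backbone, base_ema_coefficient=0.996, final_ema_coefficient=1.0)
total_epochs = 10
for epoch in range(total_epochs):
    wrapped.update_ema_coefficient(epoch, total_epochs)
    x = torch.randn(8, 32)                    # stand-in batch
    student_out = wrapped.forward_student(x)  # gradients flow
    teacher_out = wrapped.forward_teacher(x)  # no gradients by default
    # ... compute the distillation loss, backward, optimizer step ...
    wrapped.update_teacher()                  # one EMA step on the teacher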
- stable_pretraining.backbone.utils.from_huggingface(model_name, pretrained, attn_implementation='sdpa', **kwargs)[source]#
Loads a Hugging Face Transformers base model, optionally with pretrained weights, and returns the backbone model.
This function wraps the Hugging Face transformers library to load a model specified by model_name. It supports loading either pretrained weights or initializing from configuration only. The returned object is the model’s backbone (model.base_model), which is useful for extracting the core architecture without task-specific heads.
- Parameters:
model_name (str) – The Hugging Face model repository identifier or local path. Examples include “bert-base-uncased”, “facebook/opt-1.3b”, or a local directory containing model files.
pretrained (bool) – If True, loads pretrained weights via AutoModel.from_pretrained. If False, initializes the model from configuration only via AutoConfig.from_pretrained and AutoModel.from_config.
attn_implementation (str, optional) – The attention backend to use. Supported values include “sdpa” (default), “eager”, “flash_attention_2”, etc., as supported by the installed version of transformers and your hardware. This is forwarded to the underlying model constructor.
**kwargs – Additional keyword arguments forwarded to AutoModel.from_pretrained or AutoConfig.from_pretrained. Common options include:
- revision (str): Model version or branch to use.
- cache_dir (str): Directory to cache downloaded models.
- trust_remote_code (bool): Allow loading custom code from the model repo.
- torch_dtype (str or torch.dtype): Data type for model weights.
- device_map (str or dict): Device placement for model parameters.
- And others supported by Hugging Face Transformers.
- Returns:
The base (backbone) model instance, typically accessible via model.base_model. For some architectures, this may be the model itself.
- Return type:
transformers.PreTrainedModel
- Raises:
ImportError – If the transformers library is not installed.
OSError – If the model or configuration cannot be found or downloaded.
ValueError – If invalid arguments are provided.
Exception – Propagates any other exceptions raised by Hugging Face Transformers.
Notes
The returned base_model may differ depending on the architecture. For some models, base_model is the same as the full model.
The availability of certain attention implementations (e.g., “flash_attention_2”) depends on your hardware, installed libraries, and the version of transformers.
Ensure that your environment meets the requirements for the selected attention backend.
Examples
>>> # Load a pretrained BERT model with default attention
>>> model = from_huggingface("bert-base-uncased", pretrained=True)
>>> # Initialize a model from config only, specifying a revision and device
>>> model = from_huggingface(
...     "facebook/opt-1.3b",
...     pretrained=False,
...     revision="main",
...     device_map="auto",
... )
>>> # Load a pretrained model using flash attention (if supported)
>>> model = from_huggingface(
...     "meta-llama/Llama-2-7b-hf",
...     pretrained=True,
...     attn_implementation="flash_attention_2",
... )
- stable_pretraining.backbone.utils.from_torchvision(model_name, low_resolution=False, **kwargs)[source]#
Load a backbone model.
If num_classes is provided, the last layer is replaced by a linear layer of output size num_classes. Otherwise, the last layer is replaced by an identity layer.
- Parameters:
model_name (str) – Name of the backbone model. Supported models are:
- Any model from torchvision.models
- "Resnet9"
- "ConvMixer"
low_resolution (bool, optional) – Whether to adapt the resolution of the model (for CIFAR typically). By default False.
**kwargs –
Additional keyword arguments for the model. Special handling:
- in_channels (int): Number of input channels. If provided for ResNet models, the first conv layer will be modified to accept this many channels. Default is 3.
- Returns:
The neural network model.
- Return type:
torch.nn.Module
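A minimal sketch (assuming "resnet18" is available in torchvision.models and that num_classes is handled as described above):
import torch
from stable_pretraining.backbone.utils import from_torchvision

backbone = from_torchvision("resnet18", low_resolution=True, num_classes=10)
x = torch.randn(8, 3, 32, 32)
logits = backbone(x)  # expected shape: (8, 10)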
- stable_pretraining.backbone.utils.get_children_modules(model: Module, parent_name: str, L: int = 1, partial_match: bool = False) List[str][source]#
Extracts unique qualified module names located L levels below a matching parent_name.
- Parameters:
model – The root nn.Module.
parent_name – The string or path component to match (e.g., ‘blocks’).
L – Number of levels after the parent_name to include in the result.
partial_match – Whether to match path components exactly (==) or by substring containment (in).
- Returns:
Sorted list of unique qualified module names at depth L after the parent_name.
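An illustrative sketch (the returned names, e.g. "layer1.0" and "layer1.1" for a torchvision ResNet-18, are an assumption about how the qualified names are composed):
import torchvision
from stable_pretraining.backbone.utils import get_children_modules

model = torchvision.models.resnet18(weights=None)
names = get_children_modules(model, parent_name="layer1", L=1)
# expected: ["layer1.0", "layer1.1"]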
- stable_pretraining.backbone.utils.get_output_shape(model: Module, *inputs, **kwargs) Any[source]#
Infers the output shapes of a PyTorch nn.Module by forwarding fake inputs on the ‘meta’ device using FakeTensorMode.
Handles arbitrary nested output structures (lists, dicts, tuples, sets, namedtuples, dataclasses), preserving their structure but replacing torch.Tensor objects with their .shape. This function temporarily replaces the model’s parameters and buffers with fake tensors on the ‘meta’ device, converts all tensor inputs and keyword arguments to ‘meta’, and runs the forward pass under FakeTensorMode. After execution, the original parameters and buffers are restored. No real computation or memory allocation occurs.
- Parameters:
model (torch.nn.Module) – The PyTorch module to evaluate. Must be on a real device (e.g., CPU).
*inputs – Positional arguments to pass to the model’s forward method. All torch.Tensor inputs are converted to ‘meta’.
**kwargs – Keyword arguments to pass to the model’s forward method. All torch.Tensor values are converted to ‘meta’.
- Returns:
The output structure from the model's forward pass, with all torch.Tensor objects replaced by their .shape. Non-tensor objects are left unchanged.
- Return type:
Any
Notes
Supports nested output structures: dict, list, tuple, set, namedtuple, and dataclasses.
No real memory is allocated; all tensors are on the ‘meta’ device.
Not thread-safe: concurrent calls may interfere with parameter/buffer swapping.
Requires PyTorch 1.11+ for FakeTensorMode.
If the model contains custom buffers or state, ensure they are handled appropriately.
Raises exceptions if model forward fails or if parameters/buffers cannot be swapped.
Non-tensor outputs are returned unchanged.
Example
shapes = get_output_shape(model, input1, input2, key1=kwarg1)
# shapes has the same structure as the model's output, but with torch.Size in place of tensors.
- stable_pretraining.backbone.utils.register_lr_scale_hook(module, lr_scale, weight_decay=0.0)[source]#
Registers a hook that scales gradients and applies weight decay during backward pass.
- Parameters:
module – PyTorch module/layer
lr_scale – Scaling factor for the learning rate (scales gradients)
weight_decay – L2 penalty coefficient (default: 0.0)
- Returns:
The same module (for chaining)
- Return type:
module
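A minimal sketch: scale a probe head's gradients and add weight decay through the hook (illustrative values):
import torch
from stable_pretraining.backbone.utils import register_lr_scale_hook

head = torch.nn.Linear(768, 10)
head = register_lr_scale_hook(head, lr_scale=10.0, weight_decay=1e-4)
# gradients of head's parameters are scaled (and decayed) during the backward pass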
- stable_pretraining.backbone.utils.set_embedding_dim(module, dim, bias=True, expected_input_shape: tuple | list | None = None, expected_output_shape: tuple | list | None = None)[source]#
- stable_pretraining.backbone.utils.vit_hf(size: str = 'tiny', patch_size: int = 16, image_size: int = 224, pretrained: bool = False, use_mask_token: bool = True, **kwargs) Module[source]#
Create a Vision Transformer using HuggingFace transformers.
This provides a clean, well-maintained ViT implementation with native support for:
- Masking via the bool_masked_pos parameter
- Learnable mask token
- Easy access to CLS and patch tokens
- Parameters:
size – Model size - “tiny”, “small”, “base”, or “large”
patch_size – Patch size (default: 16)
image_size – Input image size (default: 224)
pretrained – Load pretrained weights from HuggingFace Hub
use_mask_token – Whether to include learnable mask token (needed for iBOT)
**kwargs – Additional ViTConfig parameters
- Returns:
HuggingFace ViTModel
Example
>>> backbone = vit_hf("tiny", use_mask_token=True)
>>> x = torch.randn(2, 3, 224, 224)
>>>
>>> # Without masking
>>> output = backbone(x)
>>> cls_token = output.last_hidden_state[:, 0, :]
>>> patch_tokens = output.last_hidden_state[:, 1:, :]
>>>
>>> # With masking (for iBOT student)
>>> masks = torch.zeros(2, 196, dtype=torch.bool)
>>> masks[:, :59] = True  # Mask 30%
>>> output = backbone(x, bool_masked_pos=masks)