Source code for stable_pretraining.backbone.utils

import copy
import math
from typing import Union, Iterable, List, Optional, Any, Dict

import torch
import torchvision
from loguru import logger as logging
from torch import nn

# Try to import optional dependencies
try:
    from timm.layers.classifier import ClassifierHead

    _TIMM_AVAILABLE = True
except ImportError:
    ClassifierHead = None
    _TIMM_AVAILABLE = False

try:
    from transformers import TimmWrapperModel, ViTConfig, ViTModel

    _TRANSFORMERS_AVAILABLE = True
except ImportError:
    TimmWrapperModel = None
    ViTConfig = None
    ViTModel = None
    _TRANSFORMERS_AVAILABLE = False


def register_lr_scale_hook(module, lr_scale, weight_decay=0.0):
    """Registers a hook that scales gradients and applies weight decay during backward pass.

    Args:
        module: PyTorch module/layer
        lr_scale: Scaling factor for the learning rate (scales gradients)
        weight_decay: L2 penalty coefficient (default: 0.0)

    Returns:
        module: The same module (for chaining)
    """

    def make_hook(param):
        def gradient_scaling_hook(grad):
            # Add weight decay (L2 regularization)
            if weight_decay != 0.0:
                grad = grad + weight_decay * param.data
            # Scale gradient (equivalent to scaling learning rate)
            return grad * lr_scale

        return gradient_scaling_hook

    for param in module.parameters():
        param.register_hook(make_hook(param))

    return module
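
# Hedged usage sketch (not part of the original source): applying a 0.1 learning-rate
# multiplier to one layer via its gradient hook. The layer sizes are arbitrary.
def _demo_register_lr_scale_hook():
    layer = nn.Linear(8, 4)
    register_lr_scale_hook(layer, lr_scale=0.1, weight_decay=1e-4)
    layer(torch.randn(2, 8)).sum().backward()
    # The stored gradient is now 0.1 * (grad + 1e-4 * param), which mimics a
    # per-layer learning rate without a dedicated optimizer param group.
    return layer.weight.grad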
def vit_hf(
    size: str = "tiny",
    patch_size: int = 16,
    image_size: int = 224,
    pretrained: bool = False,
    use_mask_token: bool = True,
    **kwargs,
) -> nn.Module:
    """Create a Vision Transformer using HuggingFace transformers.

    This provides a clean, well-maintained ViT implementation with native support for:
    - Masking via bool_masked_pos parameter
    - Learnable mask token
    - Easy access to CLS and patch tokens

    Args:
        size: Model size - "tiny", "small", "base", "large", or "huge"
        patch_size: Patch size (default: 16)
        image_size: Input image size (default: 224)
        pretrained: Load pretrained weights from HuggingFace Hub
        use_mask_token: Whether to include learnable mask token (needed for iBOT)
        **kwargs: Additional ViTConfig parameters

    Returns:
        HuggingFace ViTModel

    Example:
        >>> backbone = vit_hf("tiny", use_mask_token=True)
        >>> x = torch.randn(2, 3, 224, 224)
        >>>
        >>> # Without masking
        >>> output = backbone(x)
        >>> cls_token = output.last_hidden_state[:, 0, :]
        >>> patch_tokens = output.last_hidden_state[:, 1:, :]
        >>>
        >>> # With masking (for iBOT student)
        >>> masks = torch.zeros(2, 196, dtype=torch.bool)
        >>> masks[:, :59] = True  # Mask 30%
        >>> output = backbone(x, bool_masked_pos=masks)
    """
    if not _TRANSFORMERS_AVAILABLE:
        raise ImportError(
            "transformers library is required for vit_hf. "
            "Install with: pip install transformers"
        )

    # ViT size configurations (matching timm/DINOv3)
    size_configs = {
        "tiny": {"hidden_size": 192, "num_hidden_layers": 12, "num_attention_heads": 3},
        "small": {
            "hidden_size": 384,
            "num_hidden_layers": 12,
            "num_attention_heads": 6,
        },
        "base": {
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 12,
        },
        "large": {
            "hidden_size": 1024,
            "num_hidden_layers": 24,
            "num_attention_heads": 16,
        },
        "huge": {
            "hidden_size": 1280,
            "num_hidden_layers": 32,
            "num_attention_heads": 16,
        },
    }

    if size not in size_configs:
        raise ValueError(
            f"Invalid size '{size}'. Choose from {list(size_configs.keys())}"
        )

    config_params = size_configs[size]
    config_params["intermediate_size"] = config_params["hidden_size"] * 4
    config_params["image_size"] = image_size
    config_params["patch_size"] = patch_size
    config_params.update(kwargs)

    if pretrained:
        # Try to load pretrained model from HF Hub
        model_name = f"google/vit-{size}-patch{patch_size}-{image_size}"
        logging.info(f"Loading pretrained ViT from {model_name}")
        model = ViTModel.from_pretrained(
            model_name, add_pooling_layer=False, use_mask_token=use_mask_token
        )
    else:
        config = ViTConfig(**config_params)
        model = ViTModel(
            config, add_pooling_layer=False, use_mask_token=use_mask_token
        )
        logging.info(f"Created ViT-{size} from scratch with config: {config_params}")

    # IMPORTANT: Always interpolate position encodings so the model can process
    # dynamic input sizes (e.g., 224x224 global views alongside 96x96 local views).
    model.config.interpolate_pos_encoding = True

    return model
class EvalOnly(nn.Module):
    """Wrapper that forces a module to remain in evaluation mode."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.backbone.train(False)
        self.requires_grad_(False)
        assert not self.backbone.training

    def train(self, mode: bool = True):
        return self

    def forward(self, *args, **kwargs):
        if self.backbone.training:
            raise RuntimeError("EvalOnly module is in training mode")
        return self.backbone.forward(*args, **kwargs)
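
# Hedged usage sketch (not part of the original source): freezing a torchvision
# backbone for feature extraction. `resnet18` and the input size are arbitrary.
def _demo_eval_only():
    frozen = EvalOnly(torchvision.models.resnet18(num_classes=10))
    frozen.train()  # no-op: the wrapped backbone stays in eval mode
    with torch.no_grad():
        logits = frozen(torch.randn(2, 3, 224, 224))
    return logits.shape  # torch.Size([2, 10])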
class FeaturesConcat(nn.Module):
    """Aggregates and concatenates features from a dictionary or iterable input.

    Args:
        agg (callable): Aggregation applied to each tensor before concatenation.
        names (List[str]): Keys to extract from the input dictionary. If not given,
            everything in the dict/list is aggregated.
    """

    def __init__(self, agg: callable, names: Union[str, Iterable[str]] = None):
        super().__init__()
        if type(names) is str:
            names = [names]
        self.names = names
        self.agg = agg

    def forward(self, inputs: Union[dict, Iterable]):
        if type(inputs) is dict:
            assert self.names is not None
            tensors = [inputs[n] for n in self.names]
        else:
            tensors = inputs
        reps = []
        for t in tensors:
            reps.append(self.agg(t))
        concat = torch.cat(reps, dim=1)
        return concat

    @staticmethod
    def get_output_shape(
        agg: callable, shapes: Union[List[Iterable[int]], Dict[str, Iterable[int]]]
    ):
        """Given a list of shapes (tuples), returns the expected concatenated shape.

        Assumes all shapes have the same batch size (shapes[0][0]).

        Args:
            agg (callable): How to aggregate, can be None.
            shapes (List[Tuple[int]]): List of shapes after aggregation.

        Returns:
            Tuple[int]: The concatenated shape.
        """
        if not shapes:
            raise ValueError("Shape list is empty.")
        if type(shapes) is dict:
            shapes = list(shapes.values())
        x = [torch.empty(shape, device="meta") for shape in shapes]
        obj = FeaturesConcat(agg)
        out = obj(x)
        return out.shape
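
# Hedged usage sketch (not part of the original source): concatenating pooled
# feature maps from two hypothetical keys, "layer3" and "layer4".
def _demo_features_concat():
    feats = {
        "layer3": torch.randn(4, 256, 14, 14),
        "layer4": torch.randn(4, 512, 7, 7),
    }

    def pool(t):
        return t.flatten(2).mean(-1)  # global average pool to (B, C)

    concat = FeaturesConcat(agg=pool, names=["layer3", "layer4"])
    return concat(feats).shape  # torch.Size([4, 768])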
class ReturnEmbedding(nn.Module):
    """Cache embeddings from submodules given their names.

    Example:
        stable_pretraining.backbone.utils.ReturnEmbedding(
            torchvision.models.swin_v2_s(),
            stable_pretraining.static.EMBEDDINGS["swin_v2_s"]
        )

    Args:
        backbone (nn.Module): Model whose intermediate outputs should be cached.
        module_names (list of str): List of module names to hook
            (e.g., ['layer1', 'encoder.block1']).
    """

    def __init__(self, backbone: nn.Module, module_names: list[str]):
        super().__init__()
        logging.info("Init of ReturnEmbedding module")
        logging.info(f"\t - {len(module_names)} module names")
        self.backbone = backbone
        self.module_names = module_names
        self.hooks = []
        self.embedding_cache = {}
        for name in self.module_names:
            module = self._get_module_by_name(backbone, name)
            if module is None:
                raise ValueError(f"Module '{name}' not found in backbone.")
            hook = module.register_forward_hook(self._make_hook(name, backbone))
            self.hooks.append(hook)

    def forward(self, *args, **kwargs):
        return self.backbone(*args, **kwargs), self.embedding_cache

    def _make_hook(self, name, pl_module):
        def hook(module, input, output):
            self.embedding_cache[name] = output

        return hook

    def _get_module_by_name(self, pl_module, name):
        module = pl_module
        for attr in name.split("."):
            if not hasattr(module, attr):
                return None
            module = getattr(module, attr)
        return module
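
# Hedged usage sketch (not part of the original source): caching an intermediate
# ResNet stage. The hooked name "layer4" exists on torchvision ResNets.
def _demo_return_embedding():
    wrapped = ReturnEmbedding(torchvision.models.resnet18(), ["layer4"])
    out, cache = wrapped(torch.randn(2, 3, 224, 224))
    return out.shape, cache["layer4"].shape  # (2, 1000) and (2, 512, 7, 7)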
class TeacherStudentWrapper(nn.Module):
    """Backbone wrapper that implements teacher-student distillation via EMA.

    This is a wrapper for backbones that creates a teacher model as an exponential
    moving average (EMA) of the student model. It should be passed as the backbone to
    stable_pretraining.Module and accessed via forward_student() and forward_teacher()
    methods in your custom forward function.

    The teacher model is updated by taking a running average of the student's
    parameters and buffers. When `ema_coefficient == 0.0`, the teacher and student
    are literally the same object, saving memory but forward passes through the
    teacher will not produce any gradients.

    Usage example:
        backbone = ResNet18()
        wrapped_backbone = TeacherStudentWrapper(backbone)
        module = ssl.Module(
            backbone=wrapped_backbone,
            projector=projector,
            forward=forward_with_teacher_student,
            ...
        )

    Args:
        student (torch.nn.Module): The student model whose parameters will be tracked.
        warm_init (bool, optional): If True, performs an initialization step to match
            the student's parameters immediately. Default is True.
        base_ema_coefficient (float, optional): EMA decay factor at the start of training.
            This value will be updated following a cosine schedule.
            Should be in [0, 1]. A value of 0.0 means the teacher is fully updated
            to the student's parameters on every step, while a value of 1.0 means
            the teacher remains unchanged. Default is 0.996.
        final_ema_coefficient (float, optional): EMA decay factor at the end of training.
            Default is 1.
    """

    def __init__(
        self,
        student: nn.Module,
        warm_init: bool = True,
        base_ema_coefficient: float = 0.996,
        final_ema_coefficient: float = 1,
    ):
        if not (0.0 <= base_ema_coefficient <= 1.0) or not (
            0.0 <= final_ema_coefficient <= 1.0
        ):
            error_msg = (
                f"ema_coefficient must be in [0, 1]. Found: "
                f"base_ema_coefficient={base_ema_coefficient}, "
                f"final_ema_coefficient={final_ema_coefficient}."
            )
            logging.error(error_msg)
            raise ValueError(error_msg)

        super().__init__()
        self.student = student
        # Register EMA coefficients as buffers so they persist through checkpointing
        self.register_buffer(
            "base_ema_coefficient", torch.tensor(base_ema_coefficient)
        )
        self.register_buffer(
            "final_ema_coefficient", torch.tensor(final_ema_coefficient)
        )

        if self.base_ema_coefficient == 0.0 and self.final_ema_coefficient == 0.0:
            # No need to create a teacher network if the EMA coefficient is 0.0.
            self.teacher = student
            # Even when teacher == student, register the buffer for consistency
            self.register_buffer("ema_coefficient", self.base_ema_coefficient.clone())
        else:
            # Create a teacher network with the same architecture as the student.
            if isinstance(student, ReturnEmbedding):
                self.teacher = ReturnEmbedding(
                    copy.deepcopy(student.backbone), student.module_names
                )
            else:
                self.teacher = copy.deepcopy(student)
            self.teacher.requires_grad_(False)  # Teacher should not require gradients.

            if warm_init:
                # Initialization step to match the student's parameters.
                # Temporarily set ema_coefficient to 0 for warm init
                self.register_buffer("ema_coefficient", torch.zeros(()))
                self.update_teacher()
                # Now set to base value after warm init
                self.ema_coefficient.copy_(self.base_ema_coefficient)
            else:
                self.register_buffer(
                    "ema_coefficient", self.base_ema_coefficient.clone()
                )
    @torch.no_grad
    def update_teacher(self):
        """Perform one EMA update step on the teacher's parameters.

        The update rule is:
            teacher_param = ema_coefficient * teacher_param
                + (1 - ema_coefficient) * student_param

        This is done in a `no_grad` context to ensure the teacher's parameters do not
        accumulate gradients, but the student remains fully trainable.

        Everything is updated, including buffers (e.g. batch norm running averages).
        """
        if not self.training:
            return  # We don't update in eval
        elif self.ema_coefficient.item() == 0.0:
            return  # Nothing to update when the teacher is the student.
        elif self.ema_coefficient.item() == 1.0:
            return  # No need to update when the teacher is fixed.

        for teacher_group, student_group in [
            (self.teacher.parameters(), self.student.parameters()),
            (self.teacher.buffers(), self.student.buffers()),
        ]:
            for t, s in zip(teacher_group, student_group):
                ty = t.dtype
                t.mul_(self.ema_coefficient.to(dtype=ty))
                t.add_((1.0 - self.ema_coefficient).to(dtype=ty) * s)

    @torch.no_grad
    def update_ema_coefficient(self, epoch: int, total_epochs: int):
        """Update the EMA coefficient following a cosine schedule.

        The cosine schedule is:
            ema_coefficient = final_ema_coefficient - 0.5 *
                (final_ema_coefficient - base_ema_coefficient) *
                (1 + cos(epoch / total_epochs * pi))

        Args:
            epoch (int): Current epoch in the training loop.
            total_epochs (int): Total number of epochs in the training loop.
        """
        new_value = self.final_ema_coefficient - 0.5 * (
            self.final_ema_coefficient - self.base_ema_coefficient
        ) * (1 + math.cos(epoch / total_epochs * math.pi))
        # Update the buffer in-place to maintain persistence
        self.ema_coefficient.copy_(new_value)
    def forward_student(self, *args, **kwargs):
        """Forward pass through the student network. Gradients will flow normally."""
        return self.student(*args, **kwargs)

    def forward_teacher(self, *args, **kwargs):
        """Forward pass through the teacher network.

        By default, the teacher network does not require grad. If ema_coefficient == 0,
        then teacher == student, so we wrap in torch.no_grad() to ensure no gradients flow.
        """
        with torch.no_grad():
            return self.teacher(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Forward pass through either the student or teacher network.

        You can choose which model to run in the default forward.
        Commonly the teacher is evaluated, so we default to that.
        """
        return self.forward_teacher(*args, **kwargs)
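
# Hedged training-loop sketch (not part of the original source): how the EMA
# teacher is typically refreshed each step and its coefficient rescheduled per
# epoch. The backbone, optimizer, loss, and epoch count are placeholders.
def _demo_teacher_student(total_epochs: int = 2):
    wrapper = TeacherStudentWrapper(torchvision.models.resnet18(num_classes=128))
    opt = torch.optim.SGD(wrapper.student.parameters(), lr=0.1)
    for epoch in range(total_epochs):
        wrapper.update_ema_coefficient(epoch, total_epochs)
        x = torch.randn(4, 3, 224, 224)
        student_out = wrapper.forward_student(x)
        teacher_out = wrapper.forward_teacher(x)  # no gradients flow here
        loss = (student_out - teacher_out).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        wrapper.update_teacher()  # EMA update after the optimizer step
    return wrapper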
def from_torchvision(model_name, low_resolution=False, **kwargs):
    """Load a backbone model from torchvision.

    If num_classes is provided, the last layer is replaced by a linear layer of output
    size num_classes. Otherwise, the last layer is replaced by an identity layer.

    Args:
        model_name (str): Name of the backbone model. Any model from
            torchvision.models is supported.
        low_resolution (bool, optional): Whether to adapt the resolution of the model
            (for CIFAR typically). By default False.
        **kwargs: Additional keyword arguments for the model. Special handling:
            - in_channels (int): Number of input channels. If provided for ResNet
              models, the first conv layer will be modified to accept this many
              channels. Default is 3.

    Returns:
        torch.nn.Module: The neural network model.
    """
    # Extract in_channels before passing to torchvision (which doesn't accept it)
    in_channels = kwargs.pop("in_channels", 3)

    try:
        model = torchvision.models.__dict__[model_name](**kwargs)
    except KeyError:
        raise ValueError(f"Unknown model: {model_name}.")

    # Modify conv1 for custom number of input channels and/or low resolution
    if "resnet" in model_name and (in_channels != 3 or low_resolution):
        if low_resolution:
            # Low resolution: smaller kernel, stride=1, no maxpool (for CIFAR)
            model.conv1 = nn.Conv2d(
                in_channels,
                64,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            )
            model.maxpool = nn.Identity()
        else:
            # Full resolution: keep original kernel/stride, just change in_channels
            model.conv1 = nn.Conv2d(
                in_channels,
                64,
                kernel_size=(7, 7),
                stride=(2, 2),
                padding=(3, 3),
                bias=False,
            )
    elif low_resolution and "resnet" not in model_name:
        logging.warning(f"Cannot adapt resolution for model: {model_name}.")

    # Handle num_classes parameter as documented
    num_classes = kwargs.get("num_classes", None)
    if num_classes is not None:
        # Replace the last layer with a linear layer of the specified size
        if hasattr(model, "fc"):
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, num_classes)
        elif hasattr(model, "classifier"):
            if isinstance(model.classifier, (nn.ModuleList, nn.Sequential)):
                in_features = model.classifier[-1].in_features
                model.classifier[-1] = nn.Linear(in_features, num_classes)
            else:
                in_features = model.classifier.in_features
                model.classifier = nn.Linear(in_features, num_classes)
    else:
        # Replace the last layer with an identity layer for feature extraction
        if hasattr(model, "fc"):
            model.fc = nn.Identity()
        elif hasattr(model, "classifier"):
            if isinstance(model.classifier, (nn.ModuleList, nn.Sequential)):
                model.classifier[-1] = nn.Identity()
            else:
                model.classifier = nn.Identity()

    return model
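
# Hedged usage sketch (not part of the original source): a CIFAR-sized ResNet-18
# feature extractor (identity head) next to a 10-class classifier.
def _demo_from_torchvision():
    features = from_torchvision("resnet18", low_resolution=True)
    clf = from_torchvision("resnet18", low_resolution=True, num_classes=10)
    x = torch.randn(2, 3, 32, 32)
    return features(x).shape, clf(x).shape  # (2, 512) and (2, 10)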
def from_huggingface(model_name, pretrained, attn_implementation="sdpa", **kwargs):
    """Loads a Hugging Face Transformers base model, optionally with pretrained
    weights, and returns the backbone model.

    This function wraps the Hugging Face `transformers` library to load a model
    specified by `model_name`. It supports loading either pretrained weights or
    initializing from configuration only. The returned object is the model's backbone
    (`model.base_model`), which is useful for extracting the core architecture without
    task-specific heads.

    Args:
        model_name (str): The Hugging Face model repository identifier or local path.
            Examples include "bert-base-uncased", "facebook/opt-1.3b", or a local
            directory containing model files.
        pretrained (bool): If True, loads pretrained weights via
            `AutoModel.from_pretrained`. If False, initializes the model from
            configuration only via `AutoConfig.from_pretrained` and
            `AutoModel.from_config`.
        attn_implementation (str, optional): The attention backend to use. Supported
            values include "sdpa" (default), "eager", "flash_attention_2", etc., as
            supported by the installed version of `transformers` and your hardware.
            This is forwarded to the underlying model constructor.
        **kwargs: Additional keyword arguments forwarded to `AutoModel.from_pretrained`
            or `AutoConfig.from_pretrained`. Common options include:
            - `revision` (str): Model version or branch to use.
            - `cache_dir` (str): Directory to cache downloaded models.
            - `trust_remote_code` (bool): Allow loading custom code from model repo.
            - `torch_dtype` (str or torch.dtype): Data type for model weights.
            - `device_map` (str or dict): Device placement for model parameters.
            - And others supported by Hugging Face Transformers.

    Returns:
        transformers.PreTrainedModel: The base (backbone) model instance, typically
            accessible via `model.base_model`. For some architectures, this may be
            the model itself.

    Raises:
        ImportError: If the `transformers` library is not installed.
        OSError: If the model or configuration cannot be found or downloaded.
        ValueError: If invalid arguments are provided.
        Exception: Propagates any other exceptions raised by Hugging Face Transformers.

    Notes:
        - The returned `base_model` may differ depending on the architecture. For some
          models, `base_model` is the same as the full model.
        - The availability of certain attention implementations (e.g.,
          "flash_attention_2") depends on your hardware, installed libraries, and the
          version of `transformers`.
        - Ensure that your environment meets the requirements for the selected
          attention backend.

    Examples:
        >>> # Load a pretrained BERT model with default attention
        >>> model = from_huggingface("bert-base-uncased", pretrained=True)

        >>> # Initialize a model from config only, specifying a revision and device
        >>> model = from_huggingface(
        ...     "facebook/opt-1.3b",
        ...     pretrained=False,
        ...     revision="main",
        ...     device_map="auto",
        ... )

        >>> # Load a pretrained model using flash attention (if supported)
        >>> model = from_huggingface(
        ...     "meta-llama/Llama-2-7b-hf",
        ...     pretrained=True,
        ...     attn_implementation="flash_attention_2",
        ... )
    """
    from transformers import AutoModel, AutoConfig

    if pretrained:
        model = AutoModel.from_pretrained(
            model_name, attn_implementation=attn_implementation, **kwargs
        )
    else:
        config = AutoConfig.from_pretrained(model_name, **kwargs)
        model = AutoModel.from_config(
            config,
            attn_implementation=attn_implementation,
        )
    return model.base_model
def from_timm(model_name, low_resolution=False, **kwargs):
    import timm

    model = timm.create_model(model_name, **kwargs)
    if low_resolution:  # reduce resolution, for instance for CIFAR
        if "resnet" in model_name:
            in_channels = kwargs.get("in_channels", 3)
            model.conv1 = nn.Conv2d(
                in_channels,
                64,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            )
            model.maxpool = nn.Identity()
        else:
            logging.warning(f"Cannot adapt resolution for model: {model_name}.")
    return model
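
# Hedged usage sketch (not part of the original source): a headless timm ResNet for
# CIFAR-sized inputs; `num_classes=0` is timm's convention for pooled feature output.
# Requires timm at runtime.
def _demo_from_timm():
    backbone = from_timm("resnet18", low_resolution=True, num_classes=0)
    return backbone(torch.randn(2, 3, 32, 32)).shape  # (2, 512)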
def _map_shapes(obj: Any) -> Any:
    """Recursively maps a nested structure, replacing torch.Tensor objects with their .shape.

    We preserve the original structure for lists, tuples, dicts, sets, namedtuples,
    and dataclasses. Non-tensor objects are left unchanged.
    """
    import dataclasses

    if isinstance(obj, torch.Tensor):
        return obj.shape
    elif isinstance(obj, dict):
        return {k: _map_shapes(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_map_shapes(v) for v in obj]
    elif isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
        return type(obj)(*(_map_shapes(v) for v in obj))
    elif isinstance(obj, tuple):
        return tuple(_map_shapes(v) for v in obj)
    elif isinstance(obj, set):
        return {_map_shapes(v) for v in obj}
    elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return dataclasses.replace(
            obj,
            **{
                f.name: _map_shapes(getattr(obj, f.name))
                for f in dataclasses.fields(obj)
            },
        )
    else:
        return obj
def get_output_shape(model: torch.nn.Module, *inputs, **kwargs) -> Any:
    """Infers the output shapes of a PyTorch nn.Module by forwarding fake inputs on
    the 'meta' device using FakeTensorMode.

    Handles arbitrary nested output structures (lists, dicts, tuples, sets,
    namedtuples, dataclasses), preserving their structure but replacing torch.Tensor
    objects with their .shape.

    This function temporarily replaces the model's parameters and buffers with fake
    tensors on the 'meta' device, converts all tensor inputs and keyword arguments to
    'meta', and runs the forward pass under FakeTensorMode. After execution, the
    original parameters and buffers are restored. No real computation or memory
    allocation occurs.

    Args:
        model (torch.nn.Module): The PyTorch module to evaluate. Must be on a real
            device (e.g., CPU).
        *inputs: Positional arguments to pass to the model's forward method. All
            torch.Tensor inputs are converted to 'meta'.
        **kwargs: Keyword arguments to pass to the model's forward method. All
            torch.Tensor values are converted to 'meta'.

    Returns:
        Any: The output structure from the model's forward pass, with all
            torch.Tensor objects replaced by their .shape. Non-tensor objects are
            left unchanged.

    Notes:
        - Supports nested output structures: dict, list, tuple, set, namedtuple,
          and dataclasses.
        - No real memory is allocated; all tensors are on the 'meta' device.
        - Not thread-safe: concurrent calls may interfere with parameter/buffer
          swapping.
        - Requires PyTorch 1.11+ for FakeTensorMode.
        - If the model contains custom buffers or state, ensure they are handled
          appropriately.
        - Raises exceptions if model forward fails or if parameters/buffers cannot
          be swapped.
        - Non-tensor outputs are returned unchanged.

    Example:
        shapes = get_output_shape(model, input1, input2, key1=kwarg1)
        # shapes has the same structure as the model's output, but with torch.Size
        # in place of tensors.
    """
    from torch.func import functional_call
    import dataclasses

    # Try to use FakeTensorConverter if available (PyTorch 2.x+)
    try:
        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter

        fake_mode = FakeTensorMode()
        converter = FakeTensorConverter()

        def to_fake(t):
            return converter.from_real_tensor(fake_mode, t)

    except ImportError:
        # Fallback: just use .to('meta') inside FakeTensorMode
        from torch._subclasses.fake_tensor import FakeTensorMode

        fake_mode = FakeTensorMode()

        def to_fake(t):
            return t.to("meta")

    # Prepare fake params and buffers
    params_and_buffers = dict(model.named_parameters())
    params_and_buffers.update(model.named_buffers())
    fake_params_and_buffers = {k: to_fake(v) for k, v in params_and_buffers.items()}

    # Recursively convert all tensor inputs/kwargs to fake/meta
    def convert_inputs(obj):
        if isinstance(obj, torch.Tensor):
            return to_fake(obj)
        elif isinstance(obj, dict):
            return {k: convert_inputs(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [convert_inputs(v) for v in obj]
        elif isinstance(obj, tuple) and hasattr(obj, "_fields"):  # namedtuple
            return type(obj)(*(convert_inputs(v) for v in obj))
        elif isinstance(obj, tuple):
            return tuple(convert_inputs(v) for v in obj)
        elif isinstance(obj, set):
            return {convert_inputs(v) for v in obj}
        elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            return dataclasses.replace(
                obj,
                **{
                    f.name: convert_inputs(getattr(obj, f.name))
                    for f in dataclasses.fields(obj)
                },
            )
        else:
            return obj

    fake_inputs = [convert_inputs(inp) for inp in inputs]
    fake_kwargs = {k: convert_inputs(v) for k, v in kwargs.items()}

    with fake_mode:
        output = functional_call(
            model, fake_params_and_buffers, tuple(fake_inputs), fake_kwargs
        )
    return _map_shapes(output)
def set_embedding_dim(
    module,
    dim,
    bias=True,
    expected_input_shape: Optional[Union[tuple, list]] = None,
    expected_output_shape: Optional[Union[tuple, list]] = None,
):
    """Replace a backbone's classification head with a flatten + linear embedding of size `dim`."""
    if isinstance(module, TimmWrapperModel):
        module = module.timm_model

    def embedder(in_features):
        return nn.Sequential(
            nn.Flatten(), nn.Linear(in_features, out_features=dim, bias=bias)
        )

    # For models like ResNet.
    if hasattr(module, "fc"):
        in_features = module.fc.in_features
        module.fc = embedder(in_features)
    # For modules like VGG or AlexNet.
    elif hasattr(module, "classifier"):
        if isinstance(module.classifier, nn.ModuleList) or isinstance(
            module.classifier, nn.Sequential
        ):
            in_features = module.classifier[-1].in_features
            module.classifier[-1] = embedder(in_features)
        else:
            in_features = module.classifier.in_features
            module.classifier = embedder(in_features)
    # For modules like ViT.
    elif hasattr(module, "heads"):
        in_features = module.heads.head.in_features
        module.heads.head = embedder(in_features)
    # For modules like Swin Transformer.
    elif hasattr(module, "head") and (
        ClassifierHead is None or not isinstance(module.head, ClassifierHead)
    ):
        in_features = module.head.in_features
        module.head = embedder(in_features)
    else:
        logging.warning(
            f"Unknown module structure for: '{module}'.\n\n"
            "We will use the default's output and attach a "
            "linear module on top."
        )
        if expected_input_shape is None:
            logging.error("Can't do that without `expected_input_shape`")
            raise ValueError("Can't do that without `expected_input_shape`")
        test_input = torch.empty(expected_input_shape, device="meta")
        out = module.to("meta")(test_input)
        in_features = out.flatten(1).size(1)
        embedder = nn.Sequential(
            nn.Flatten(), nn.Linear(in_features, out_features=dim, bias=bias)
        )
        return nn.Sequential(module, embedder)

    if expected_input_shape is None:
        logging.warning(
            "No `expected_input_shape` provided, can't verify "
            "the behavior of `set_embedding_dim`."
        )
    else:
        assert expected_output_shape is not None
        x = torch.empty(expected_input_shape, device="meta")
        # Save original device before moving to meta
        original_device = next(module.parameters()).device
        out = module.to("meta")(x)
        if isinstance(out, tuple):
            assert out[0].shape == expected_output_shape
        elif hasattr(out, "logits"):
            assert out["logits"].shape == expected_output_shape
        else:
            assert out.shape == expected_output_shape
        # Move module back to original device.
        # Use to_empty() for meta tensors which have no data.
        module = module.to_empty(device=original_device)
    return module
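
# Hedged usage sketch (not part of the original source): giving a ResNet-18 a 128-d
# embedding head; `expected_input_shape`/`expected_output_shape` could additionally
# be passed to verify the new head on the meta device.
def _demo_set_embedding_dim():
    backbone = set_embedding_dim(torchvision.models.resnet18(), dim=128)
    return backbone(torch.randn(2, 3, 224, 224)).shape  # torch.Size([2, 128])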
def get_children_modules(
    model: nn.Module, parent_name: str, L: int = 1, partial_match: bool = False
) -> List[str]:
    """Extracts unique module names matching a given parent_name and L submodule levels.

    Args:
        model: The root nn.Module.
        parent_name: The string or path component to match (e.g., 'blocks').
        L: Number of levels after the parent_name to include in the result.
        partial_match: Whether to match path components with `in` (True) or `==` (False).

    Returns:
        List of unique qualified module names at depth L after the parent_name.
    """
    result: List[str] = []
    for name, _ in model.named_modules():
        parts = name.split(".")
        matches = [
            i
            for i, p in enumerate(parts)
            if (parent_name in p if partial_match else parent_name == p)
        ]
        if not matches:
            continue
        for idx in matches:
            target_idx = idx + L
            if target_idx < len(parts):
                truncated = ".".join(parts[: target_idx + 1])
                if truncated in result:
                    continue
                # Ensure this is a valid submodule
                try:
                    model.get_submodule(truncated)
                    result.append(truncated)
                except AttributeError:
                    continue
            elif L == 0:
                truncated = ".".join(parts[: idx + 1])
                try:
                    model.get_submodule(truncated)
                    result.append(truncated)
                except AttributeError:
                    continue
    return result
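
# Hedged usage sketch (not part of the original source): listing the transformer
# blocks of a timm-style ViT one level below "blocks". Requires timm at runtime.
def _demo_get_children_modules():
    import timm

    vit = timm.create_model("vit_tiny_patch16_224", pretrained=False)
    return get_children_modules(vit, "blocks", L=1)
    # e.g. ['blocks.0', 'blocks.1', ..., 'blocks.11']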
class EfficientMaskedTimmViT(nn.Module):
    """Optimized Vision Transformer wrapper that efficiently handles NaN patches.

    This module is designed to work with timm ViT models and provides:
    - Per-sample NaN masking (different NaN patterns per image in batch)
    - Fast path for same masking pattern across batch
    - Support for class tokens (cls_token), distillation tokens (dist_token),
      and register tokens
    - Compatibility with various timm ViT architectures (vit_*, deit_*, beit_*, etc.)
    - Minimal overhead when no masking is present

    Key Optimizations:
    - Early exit when no NaN patches detected
    - Simpler indexing for same masking patterns
    - Cached batch indices for repeated operations
    - Zero-copy operations where possible

    Args:
        vit: A timm Vision Transformer model instance

    Raises:
        ValueError: If samples have different numbers of NaN patches
        ValueError: If all patches are NaN
        RuntimeError: If the model structure is incompatible

    Example:
        >>> import timm
        >>> vit = timm.create_model(
        ...     "vit_base_patch16_224", pretrained=False, reg_tokens=4
        ... )
        >>> masked_vit = EfficientMaskedTimmViT(vit)
        >>>
        >>> # Create input with some NaN patches
        >>> x = torch.randn(4, 3, 224, 224)
        >>> output = masked_vit(x)

    Performance:
        - Same pattern masking: ~0-5% overhead vs different patterns
        - No masking: <2% overhead vs original model
        - 50% masking: ~1.5x speedup
        - 90% masking: ~2.5-3x speedup

    Note:
        All samples in a batch must have the same NUMBER of NaN patches,
        but the LOCATION of NaN patches can differ per sample.
        Register tokens (DINOv2 style) do NOT receive positional embeddings.
    """

    def __init__(self, vit: nn.Module):
        super().__init__()
        self.vit = vit
        # Cache for batch indices to avoid repeated allocation
        self._batch_indices_cache = {}

        # Validate model has required components
        if not hasattr(vit, "patch_embed"):
            raise RuntimeError(
                "Model must have 'patch_embed' attribute. "
                "This wrapper only supports patch-based ViT models."
            )
        if not hasattr(vit, "blocks"):
            raise RuntimeError(
                "Model must have 'blocks' attribute containing transformer blocks."
            )

        def nan_gradient_hook(grad):
            """Replace NaN gradients with zeros."""
            if torch.isnan(grad).any():
                return torch.nan_to_num(grad)
            return grad

        # Register hook for all patch-embedding parameters
        for name, param in self.vit.patch_embed.named_parameters():
            if param.requires_grad:
                param.register_hook(nan_gradient_hook)
                logging.debug(f"Registered NaN hook for: {name}")

    def _get_num_extra_tokens(self) -> int:
        """Determine the number of extra tokens (cls, dist, register) the model uses.

        Returns:
            int: Number of extra tokens (cls + dist + register)

        Note:
            This counts ALL extra tokens that occupy sequence positions.
            Register tokens don't receive positional embeddings but do occupy positions.
        """
        num_extra = 0
        # CLS token
        if hasattr(self.vit, "cls_token") and self.vit.cls_token is not None:
            num_extra += 1
        # Distillation token (DeiT)
        if hasattr(self.vit, "dist_token") and self.vit.dist_token is not None:
            num_extra += 1
        # Register tokens (DINOv2 style)
        if hasattr(self.vit, "reg_token") and self.vit.reg_token is not None:
            num_extra += self.vit.reg_token.shape[1]
        elif hasattr(self.vit, "num_reg_tokens"):
            num_extra += self.vit.num_reg_tokens
        return num_extra

    def _get_num_pos_tokens(self) -> int:
        """Get the number of tokens that RECEIVE positional embeddings.

        Returns:
            int: Number of tokens with positional embeddings

        Note:
            With timm's dynamic_img_size=True, register tokens ARE included in
            pos_embed. This method returns CLS + DIST (not register) for non-dynamic
            models, but we need to check pos_embed.shape to know the actual structure.
        """
        num_pos = 0
        # CLS token gets positional embedding
        if hasattr(self.vit, "cls_token") and self.vit.cls_token is not None:
            num_pos += 1
        # Distillation token gets positional embedding
        if hasattr(self.vit, "dist_token") and self.vit.dist_token is not None:
            num_pos += 1
        # Note: Register tokens may or may not be in pos_embed depending on timm
        # config. This is checked dynamically in _interpolate_pos_embed.
        return num_pos

    def _add_extra_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Add cls_token, dist_token, and/or register tokens to the sequence.

        Args:
            x: Input tensor of shape (B, N, D) containing patch embeddings

        Returns:
            torch.Tensor: Tensor with extra tokens prepended

        Note:
            Token order: [cls_token, dist_token (if present),
            register_tokens (if present), patches].
            This matches the timm convention for ViTs with register tokens.
        """
        B = x.shape[0]

        # Add cls_token if present
        if hasattr(self.vit, "cls_token") and self.vit.cls_token is not None:
            cls_tokens = self.vit.cls_token.expand(B, -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)

        # Add dist_token if present (for DeiT models)
        if hasattr(self.vit, "dist_token") and self.vit.dist_token is not None:
            dist_tokens = self.vit.dist_token.expand(B, -1, -1)
            if hasattr(self.vit, "cls_token") and self.vit.cls_token is not None:
                x = torch.cat([x[:, :1, :], dist_tokens, x[:, 1:, :]], dim=1)
            else:
                x = torch.cat([dist_tokens, x], dim=1)

        # Add register tokens if present (DINOv2 style)
        if hasattr(self.vit, "reg_token") and self.vit.reg_token is not None:
            reg_tokens = self.vit.reg_token.expand(B, -1, -1)
            # Register tokens come after cls/dist but before patches
            num_prefix = 0
            if hasattr(self.vit, "cls_token") and self.vit.cls_token is not None:
                num_prefix += 1
            if hasattr(self.vit, "dist_token") and self.vit.dist_token is not None:
                num_prefix += 1
            if num_prefix > 0:
                x = torch.cat(
                    [x[:, :num_prefix, :], reg_tokens, x[:, num_prefix:, :]], dim=1
                )
            else:
                x = torch.cat([reg_tokens, x], dim=1)

        return x

    def _get_batch_indices(
        self, B: int, num_keep: int, device: torch.device
    ) -> torch.Tensor:
        """Get or create cached batch indices for gathering operations.

        Args:
            B: Batch size
            num_keep: Number of patches to keep
            device: Device for the tensor

        Returns:
            torch.Tensor: Batch indices of shape (B, num_keep) for advanced indexing

        Note:
            Results are cached to avoid repeated allocations for common batch sizes.
        """
        key = (B, num_keep, device)
        if key not in self._batch_indices_cache:
            batch_idx = (
                torch.arange(B, device=device).unsqueeze(1).expand(-1, num_keep)
            )
            self._batch_indices_cache[key] = batch_idx
        return self._batch_indices_cache[key]

    def _subsample_pos_embed_same_pattern(
        self, keep_idx: torch.Tensor, B: int, N: int
    ) -> torch.Tensor:
        """Subsample positional embeddings when all samples have the same mask pattern."""
        pos_embed = self.vit.pos_embed
        num_pos_tokens = self._get_num_pos_tokens()

        # Check if model has register tokens
        num_register_tokens = 0
        if hasattr(self.vit, "reg_token") and self.vit.reg_token is not None:
            num_register_tokens = self.vit.reg_token.shape[1]
        elif hasattr(self.vit, "num_reg_tokens"):
            num_register_tokens = self.vit.num_reg_tokens

        # Interpolate if needed for dynamic image sizes
        pos_embed = self._interpolate_pos_embed(pos_embed, N)

        # Determine positional embedding structure.
        # With dynamic_img_size=True, pos_embed may include register tokens;
        # check both cases: with and without register tokens.
        if pos_embed.shape[1] == N + num_pos_tokens + num_register_tokens:
            # pos_embed includes register tokens: [CLS, REG, PATCHES]
            extra_tokens_pos = pos_embed[:, : num_pos_tokens + num_register_tokens, :]
            patch_pos_embed = pos_embed[:, num_pos_tokens + num_register_tokens :, :]
            # Subsample patch positions
            patch_pos_embed = patch_pos_embed[:, keep_idx, :]
            pos_embed = torch.cat(
                [
                    extra_tokens_pos.expand(B, -1, -1),
                    patch_pos_embed.expand(B, -1, -1),
                ],
                dim=1,
            )
        elif pos_embed.shape[1] == N + num_pos_tokens:
            # pos_embed doesn't include register tokens: [CLS, PATCHES]
            extra_tokens_pos = pos_embed[:, :num_pos_tokens, :]
            patch_pos_embed = pos_embed[:, num_pos_tokens:, :]
            # Subsample patch positions
            patch_pos_embed = patch_pos_embed[:, keep_idx, :]
            if num_pos_tokens > 0:
                pos_embed = torch.cat(
                    [
                        extra_tokens_pos.expand(B, -1, -1),
                        patch_pos_embed.expand(B, -1, -1),
                    ],
                    dim=1,
                )
            else:
                pos_embed = patch_pos_embed.expand(B, -1, -1)
        elif pos_embed.shape[1] == N:
            # No extra tokens at all
            patch_pos_embed = pos_embed[:, keep_idx, :]
            pos_embed = patch_pos_embed.expand(B, -1, -1)
        else:
            raise RuntimeError(
                f"Unexpected pos_embed shape after interpolation: {pos_embed.shape}. "
                f"Expected shape[1] to be {N + num_pos_tokens + num_register_tokens}, "
                f"{N + num_pos_tokens}, or {N}"
            )

        return pos_embed

    def _subsample_pos_embed_different_patterns(
        self, keep_indices: torch.Tensor, B: int, N: int, num_keep: int
    ) -> torch.Tensor:
        """Subsample positional embeddings when samples have different mask patterns."""
        pos_embed = self.vit.pos_embed
        num_pos_tokens = self._get_num_pos_tokens()

        # Check if model has register tokens
        num_register_tokens = 0
        if hasattr(self.vit, "reg_token") and self.vit.reg_token is not None:
            num_register_tokens = self.vit.reg_token.shape[1]
        elif hasattr(self.vit, "num_reg_tokens"):
            num_register_tokens = self.vit.num_reg_tokens

        # Interpolate if needed for dynamic image sizes
        pos_embed = self._interpolate_pos_embed(pos_embed, N)

        # Determine positional embedding structure.
        # With dynamic_img_size=True, pos_embed may include register tokens.
        if pos_embed.shape[1] == N + num_pos_tokens + num_register_tokens:
            # pos_embed includes register tokens: [CLS, REG, PATCHES]
            extra_tokens_pos = pos_embed[:, : num_pos_tokens + num_register_tokens, :]
            patch_pos_embed = pos_embed[:, num_pos_tokens + num_register_tokens :, :]
        elif pos_embed.shape[1] == N + num_pos_tokens:
            # pos_embed doesn't include register tokens: [CLS, PATCHES]
            extra_tokens_pos = pos_embed[:, :num_pos_tokens, :]
            patch_pos_embed = pos_embed[:, num_pos_tokens:, :]
        elif pos_embed.shape[1] == N:
            # No extra tokens
            extra_tokens_pos = None
            patch_pos_embed = pos_embed
        else:
            raise RuntimeError(
                f"Unexpected pos_embed shape after interpolation: {pos_embed.shape}. "
                f"Expected shape[1] to be {N + num_pos_tokens + num_register_tokens}, "
                f"{N + num_pos_tokens}, or {N}"
            )

        # Subsample patch positional embeddings per sample
        patch_pos_embed = patch_pos_embed.expand(B, -1, -1)
        batch_idx = self._get_batch_indices(B, num_keep, keep_indices.device)
        patch_pos_embed = patch_pos_embed[batch_idx, keep_indices, :]

        if extra_tokens_pos is not None:
            extra_tokens_pos = extra_tokens_pos.expand(B, -1, -1)
            pos_embed = torch.cat([extra_tokens_pos, patch_pos_embed], dim=1)
        else:
            pos_embed = patch_pos_embed

        return pos_embed

    def _apply_head(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the classification head to the transformer output.

        Args:
            x: Output from transformer blocks, shape (B, N, D)

        Returns:
            torch.Tensor: Classification logits or features

        Note:
            Handles multiple head types used by different timm models.
        """
        # Try different head application methods used by timm models
        if hasattr(self.vit, "forward_head"):
            # Newer timm models with forward_head method
            return self.vit.forward_head(x)
        elif hasattr(self.vit, "head"):
            # Standard ViT: use cls token (first token)
            if hasattr(self.vit, "fc_norm") and self.vit.fc_norm is not None:
                # Some models apply additional norm before head
                x = self.vit.fc_norm(x[:, 0])
                return self.vit.head(x)
            else:
                return self.vit.head(x[:, 0])
        elif hasattr(self.vit, "head_dist"):
            # DeiT with distillation - has two heads
            x_cls = self.vit.head(x[:, 0])
            x_dist = self.vit.head_dist(x[:, 1])
            if self.training:
                # Return both during training
                return x_cls, x_dist
            else:
                # Average predictions during inference
                return (x_cls + x_dist) / 2
        else:
            # No head - return raw features
            return x
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the masked ViT.

        This method implements an optimized forward pass with the following features:
        - Early exit for inputs without NaN patches (fast path)
        - Optimized indexing for same masking patterns across batch
        - Per-sample masking support with advanced indexing
        - Automatic NaN replacement for partial NaN patches
        - Support for register tokens (DINOv2 style)

        Args:
            x: Input tensor, either:
                - Raw images: shape (B, C, H, W)
                - Pre-patchified: shape (B, N, D) where N is number of patches

        Returns:
            torch.Tensor: Model output (logits if head exists, features otherwise)

        Raises:
            ValueError: If samples have different numbers of NaN patches
            ValueError: If all patches are NaN

        Performance Notes:
            - No NaN patches: Uses fast path with <2% overhead
            - Same pattern: Optimized indexing, ~0-5% overhead vs different patterns
            - Different patterns: Uses advanced indexing, ~10-35% slower at high masking
        """
        # Detect if this is a FakeTensor (used for shape inference/tracing)
        is_fake_tensor = (
            x.__class__.__name__ == "FakeTensor"
            or hasattr(x, "fake_mode")
            or "Fake" in type(x).__name__
        )
        if is_fake_tensor or not torch.isnan(x).any():
            return self.vit(x)

        # Patchify if needed
        if x.ndim == 4:  # (B, C, H, W) - raw image
            # Apply patch embedding
            x = self.vit.patch_embed(x)
            # Ensure 3D output (B, N, C)
            if x.ndim == 4:
                # Dynamic: (B, H, W, C) -> (B, H*W, C)
                x = x.flatten(1, 2)
            elif x.ndim != 3:
                raise ValueError(
                    f"Expected patch_embed output to be 3D or 4D, "
                    f"got {x.ndim}D with shape {x.shape}"
                )
        elif x.ndim == 3:  # (B, N, D) - already patchified
            pass
        else:
            raise ValueError(
                f"Input must be 4D (B, C, H, W) image or 3D (B, N, D) patches. "
                f"Got shape: {x.shape}"
            )

        B, N, D = x.shape
        device = x.device
        nan_mask = torch.isnan(x).any(dim=2)  # (B, N)

        # Verify same number of NaN patches across batch
        num_nans = nan_mask.sum(dim=1)
        if not (num_nans == num_nans[0]).all():
            raise ValueError(
                f"All samples must have the same number of NaN patches. "
                f"Got counts: {num_nans.tolist()}"
            )

        num_keep = N - num_nans[0].item()
        if num_keep == 0:
            raise ValueError("All patches are NaN - cannot process input")

        # Check if all samples have the same masking pattern
        same_pattern = (nan_mask == nan_mask[0]).all().item()

        if same_pattern:
            # OPTIMIZED PATH: Same pattern for all samples
            keep_idx = (~nan_mask[0]).nonzero(as_tuple=True)[0]  # (num_keep,)
            x = x[:, keep_idx, :]  # Simple indexing - faster
            # Subsample positional embeddings (optimized)
            pos_embed = self._subsample_pos_embed_same_pattern(keep_idx, B, N)
        else:
            # GENERAL PATH: Different patterns per sample
            keep_indices = self._get_keep_indices_vectorized(nan_mask, num_keep)
            # Gather non-NaN patches (advanced indexing)
            batch_idx = self._get_batch_indices(B, num_keep, device)
            x = x[batch_idx, keep_indices, :]
            # Subsample positional embeddings
            pos_embed = self._subsample_pos_embed_different_patterns(
                keep_indices, B, N, num_keep
            )

        # Replace any remaining NaNs with zeros (partial NaNs in patches)
        if torch.isnan(x).any():
            x = torch.nan_to_num(x, nan=0.0)

        # Add cls_token, dist_token, and/or register tokens
        x = self._add_extra_tokens(x)

        # Add positional embeddings.
        # The subsample methods ensure pos_embed matches x in length
        # (includes register tokens when using dynamic_img_size=True).
        x = x + pos_embed

        # Apply positional dropout if it exists
        if hasattr(self.vit, "pos_drop") and self.vit.pos_drop is not None:
            x = self.vit.pos_drop(x)

        # Apply patch dropout if it exists (some models have this)
        if hasattr(self.vit, "patch_drop") and self.vit.patch_drop is not None:
            x = self.vit.patch_drop(x)

        # Forward through transformer blocks
        for blk in self.vit.blocks:
            x = blk(x)

        # Apply final norm
        if hasattr(self.vit, "norm") and self.vit.norm is not None:
            x = self.vit.norm(x)

        # Apply head and return
        return self._apply_head(x)
    def clear_cache(self):
        """Clear the cached batch indices.

        Useful if you want to free memory after processing different batch sizes.
        The cache will be rebuilt as needed during forward passes.
        """
        self._batch_indices_cache.clear()
    def _get_keep_indices_vectorized(
        self, nan_mask: torch.Tensor, num_keep: int
    ) -> torch.Tensor:
        """Get keep indices for all samples without Python loops (faster).

        This vectorized approach is ~2-3x faster than iterating over the batch.

        Args:
            nan_mask: Boolean mask indicating NaN patches, shape (B, N)
            num_keep: Number of patches to keep per sample

        Returns:
            torch.Tensor: Keep indices per sample, shape (B, num_keep)

        Note:
            Uses topk instead of nonzero to avoid Python loops.
            The indices are sorted in ascending order.
        """
        B, N = nan_mask.shape
        device = nan_mask.device

        # Create index tensor for all samples
        indices = torch.arange(N, device=device).unsqueeze(0).expand(B, -1)  # (B, N)

        # Mask out NaN positions by setting them to a large value
        indices_masked = indices.float()
        indices_masked[nan_mask] = float(N + 1)  # Larger than any valid index

        # Use topk to get smallest indices (non-NaN positions)
        keep_indices, _ = torch.topk(
            indices_masked,
            k=num_keep,
            dim=1,
            largest=False,  # Get smallest values
            sorted=True,  # Keep sorted for cache friendliness
        )
        return keep_indices.long()

    def _interpolate_pos_embed(self, pos_embed: torch.Tensor, N: int) -> torch.Tensor:
        """Interpolate positional embeddings to match the number of patches.

        This is needed when dynamic_img_size=True and the input size differs from
        the default/training size.

        Args:
            pos_embed: Original positional embeddings, shape (1, N_orig, D)
            N: Target number of patches

        Returns:
            torch.Tensor: Interpolated positional embeddings

        Note:
            When using timm with dynamic_img_size=True and reg_tokens, the pos_embed
            INCLUDES register tokens: [CLS_pos, REG_pos, PATCH_pos]
        """
        num_pos_tokens = self._get_num_pos_tokens()

        # Check if model has register tokens
        num_register_tokens = 0
        if hasattr(self.vit, "reg_token") and self.vit.reg_token is not None:
            num_register_tokens = self.vit.reg_token.shape[1]
        elif hasattr(self.vit, "num_reg_tokens"):
            num_register_tokens = self.vit.num_reg_tokens

        N_orig = pos_embed.shape[1]

        # If already correct size, return as-is.
        # Check both possibilities: with and without register tokens.
        if (
            N_orig == N + num_pos_tokens + num_register_tokens
            or N_orig == N + num_pos_tokens
            or N_orig == N
        ):
            return pos_embed

        # Determine structure: timm may include register tokens in pos_embed when
        # dynamic_img_size=True. Structure can be: [CLS_pos, REG_pos, PATCH_pos]
        # or [CLS_pos, PATCH_pos].
        # Calculate expected position with register tokens
        expected_with_reg = num_pos_tokens + num_register_tokens
        if N_orig > expected_with_reg and num_register_tokens > 0:
            # pos_embed includes register tokens: [CLS, REG, PATCHES]
            extra_tokens_pos = pos_embed[:, :expected_with_reg, :]
            patch_pos_embed = pos_embed[:, expected_with_reg:, :]
        elif num_pos_tokens > 0 and N_orig > num_pos_tokens:
            # pos_embed doesn't include register tokens: [CLS, PATCHES]
            extra_tokens_pos = pos_embed[:, :num_pos_tokens, :]
            patch_pos_embed = pos_embed[:, num_pos_tokens:, :]
        else:
            # No extra tokens
            extra_tokens_pos = None
            patch_pos_embed = pos_embed

        # Calculate grid sizes
        N_orig_patches = patch_pos_embed.shape[1]
        gs_orig = int(N_orig_patches**0.5)
        gs_new = int(N**0.5)

        if gs_orig * gs_orig != N_orig_patches:
            raise RuntimeError(
                f"Original positional embeddings ({N_orig_patches}) don't form a "
                f"square grid. Non-square grids require custom interpolation."
            )
        if gs_new * gs_new != N:
            raise RuntimeError(
                f"Target number of patches ({N}) doesn't form a square grid. "
                f"Non-square grids require custom interpolation."
            )

        # Reshape to 2D grid: (1, N_orig, D) -> (1, D, H_orig, W_orig)
        D = patch_pos_embed.shape[2]
        patch_pos_embed = patch_pos_embed.reshape(1, gs_orig, gs_orig, D).permute(
            0, 3, 1, 2
        )

        # Interpolate using bicubic (same as timm)
        patch_pos_embed = torch.nn.functional.interpolate(
            patch_pos_embed,
            size=(gs_new, gs_new),
            mode="bicubic",
            align_corners=False,
            antialias=False,
        )

        # Reshape back: (1, D, H_new, W_new) -> (1, N, D)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, N, D)

        # Recombine with extra token positions
        if extra_tokens_pos is not None:
            pos_embed = torch.cat([extra_tokens_pos, patch_pos_embed], dim=1)
        else:
            pos_embed = patch_pos_embed

        return pos_embed
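
# Hedged usage sketch (not part of the original source): masking 49 of the 196
# patches of each image by writing NaNs into whole 16x16 patch regions, then
# running the wrapper (same pattern per sample, so the fast path is taken).
# Requires timm at runtime.
def _demo_masked_vit():
    import timm

    vit = timm.create_model("vit_tiny_patch16_224", pretrained=False)
    masked_vit = EfficientMaskedTimmViT(vit)

    x = torch.randn(2, 3, 224, 224)
    x[:, :, :112, :112] = float("nan")  # NaN-out the top-left quadrant of patches
    return masked_vit(x).shape  # (2, 1000)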