Source code for stable_pretraining.utils.config

"""Configuration utilities and model manipulation helpers."""

import functools
from typing import Any

import hydra
import torch
import torch.distributed as dist
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf, open_dict


def is_dist() -> bool:
    """Returns True if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()

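A minimal usage sketch (illustrative, not part of the module): guard collective calls so that single-process runs do not hang.

    if is_dist():
        dist.barrier()  # only reached when a process group actually exists
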
def load_hparams_from_ckpt(path: str) -> Any:
    """Loads checkpoint hyper-parameters safely in both distributed and non-distributed settings.

    - If torch.distributed is initialized, only rank 0 loads from disk, then
      broadcasts the result to all other ranks.
    - If not, loads directly from disk.

    Args:
        path: Path to checkpoint file.

    Returns:
        The loaded hparams.
    """
    if is_dist():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # If only one process, just load directly.
        if world_size == 1:
            return torch.load(
                path, map_location=torch.device("meta"), weights_only=False
            )["hyper_parameters"]
        obj = None
        if rank == 0:
            obj = torch.load(
                path, map_location=torch.device("meta"), weights_only=False
            )["hyper_parameters"]
        obj_list = [obj]
        dist.broadcast_object_list(obj_list, src=0)
        dist.barrier()
        return obj_list[0]
    else:
        return torch.load(
            path, map_location=torch.device("meta"), weights_only=False
        )["hyper_parameters"]

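A hypothetical usage sketch; the checkpoint path is illustrative, and the returned object is whatever was stored under "hyper_parameters" (for Lightning-style checkpoints, typically a dict-like mapping).

    # Every rank calls this; only rank 0 touches the disk when distributed.
    hparams = load_hparams_from_ckpt("checkpoints/last.ckpt")
    print(hparams)
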
def execute_from_config(manager, cfg):
    """Execute a function with support for submitit job submission.

    If a submitit configuration is present, submits the job to a cluster.
    Otherwise executes locally.

    Args:
        manager: Function or callable to execute.
        cfg: Configuration dictionary.

    Returns:
        Result of the executed function.
    """
    if "submitit" in cfg:
        assert "hydra" not in cfg
        hydra_conf = HydraConfig.get()
        # force_add ignores nodes in struct mode or Structured Configs nodes
        # and updates anyway, inserting keys as needed.
        OmegaConf.set_struct(cfg, True)
        with open_dict(cfg):
            OmegaConf.update(
                cfg, "hydra.sweep.dir", hydra_conf.sweep.dir, force_add=True
            )
        with open_dict(cfg):
            cfg.hydra = {}
            cfg.hydra.job = OmegaConf.create(hydra_conf.job)
            cfg.hydra.sweep = OmegaConf.create(hydra_conf.sweep)
            cfg.hydra.run = OmegaConf.create(hydra_conf.run)
        executor = hydra.utils.instantiate(cfg.submitit.executor, _convert_="object")
        executor.update_parameters(**cfg.submitit.get("update_parameters", {}))
        job = executor.submit(manager)
        return job.result()
    else:
        return manager()

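A sketch of a Hydra entry point delegating to execute_from_config; the config path, config name, and `run` body are assumptions for illustration, not part of this module.

    @hydra.main(config_path="configs", config_name="pretrain", version_base=None)
    def main(cfg):
        def run():
            # Training logic would go here; `run` takes no arguments because
            # both the local branch and executor.submit(manager) call it bare.
            print(OmegaConf.to_yaml(cfg))

        return execute_from_config(run, cfg)
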
def adapt_resnet_for_lowres(model):
    """Adapt a ResNet model for low resolution images.

    Modifies the first convolution layer to use 3x3 kernels with stride 1
    and removes the max pooling layer.

    Args:
        model: ResNet model to adapt.

    Returns:
        Modified model.
    """
    model.conv1 = torch.nn.Conv2d(
        3,
        64,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        bias=False,
    )
    model.maxpool = torch.nn.Identity()
    return model

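A sketch assuming torchvision is available; resnet18 follows the `conv1`/`maxpool` attribute names this helper relies on, and the 32x32 input mimics CIFAR.

    from torchvision.models import resnet18

    model = adapt_resnet_for_lowres(resnet18())
    logits = model(torch.randn(2, 3, 32, 32))  # forward pass works at CIFAR resolution
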
def rsetattr(obj, attr, val):
    """Recursively set an attribute using dot notation.

    Args:
        obj: Object to set attribute on.
        attr: Attribute path (e.g., "module.layer.weight").
        val: Value to set.
    """
    pre, _, post = attr.rpartition(".")
    parent = rgetattr(obj, pre) if pre else obj
    if isinstance(parent, dict):
        parent[post] = val
    else:
        setattr(parent, post, val)

def _adaptive_getattr(obj, attr):
    """Get attribute that works with both objects and dictionaries."""
    if isinstance(obj, dict):
        return obj[attr]
    else:
        return getattr(obj, attr)

def rgetattr(obj, attr):
    """Recursively get an attribute using dot notation.

    Args:
        obj: Object to get attribute from.
        attr: Attribute path (e.g., "module.layer.weight").

    Returns:
        The requested attribute value.
    """
    return functools.reduce(_adaptive_getattr, [obj] + attr.split("."))

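A minimal sketch showing dotted access on both a module tree and a plain dict; the toy objects are illustrative.

    net = torch.nn.Sequential(torch.nn.Linear(4, 2))
    weight = rgetattr(net, "0.weight")                 # same as net[0].weight
    rsetattr(net, "0.bias", torch.nn.Parameter(torch.zeros(2)))

    cfg = {"optim": {"lr": 0.1}}
    rsetattr(cfg, "optim.lr", 0.01)                    # dict leaves are supported
    assert rgetattr(cfg, "optim.lr") == 0.01
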
def find_module(model: torch.nn.Module, module: type[torch.nn.Module]):
    """Find all instances of a module type in a model.

    Args:
        model: Model to search in.
        module: Module class to search for.

    Returns:
        Tuple of (names, modules) where names are the module paths and
        modules are the actual module instances.
    """
    names = []
    values = []
    for child_name, child in model.named_modules():
        if isinstance(child, module):
            names.append(child_name)
            values.append(child)
    return names, values

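A sketch that collects every BatchNorm2d in a small model; any nn.Module works, the Sequential here is just illustrative.

    net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    )
    names, bns = find_module(net, torch.nn.BatchNorm2d)
    print(names)  # ['1']
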
def replace_module(model, replacement_mapping):
    """Replace modules in a model based on a mapping function.

    Args:
        model: PyTorch model to modify.
        replacement_mapping: Function that takes (name, module) and returns
            the replacement module.

    Returns:
        Modified model.
    """
    if not isinstance(model, torch.nn.Module):
        raise ValueError("torch.nn.Module expected as input.")
    for name, module in model.named_modules():
        if name == "":
            continue
        replacement = replacement_mapping(name, module)
        module_names = name.split(".")
        # Walk down the tree to the parent of the module being replaced.
        parent = model
        for parent_name in module_names[:-1]:
            parent = getattr(parent, parent_name)
        setattr(parent, module_names[-1], replacement)
    return model

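A sketch that swaps every ReLU for GELU; the mapping function and toy model are illustrative.

    def relu_to_gelu(name, module):
        # Return a replacement for ReLUs, and every other module unchanged.
        if isinstance(module, torch.nn.ReLU):
            return torch.nn.GELU()
        return module

    net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    net = replace_module(net, relu_to_gelu)
    assert isinstance(net[1], torch.nn.GELU)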