"""Configuration classes specifying default parameters for stable-SSL."""
from typing import Any, Union
import hydra
import omegaconf
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
[docs]
def collapse_nested_dict(
    cfg: Union[dict, object],
    level_separator: str = ".",
    _base_name: str = None,
    _flat_cfg: dict = None,
) -> dict:
    """Flatten a nested (Hydra) config into a single-level dict, e.g. for wandb.

    Lists and tuples contribute their positional index to the flattened key,
    dicts and ``DictConfig`` objects contribute their own keys; the pieces are
    joined with ``level_separator``. Every other value is treated as a leaf and
    stored under the key accumulated so far. Empty containers therefore add no
    entry to the result.

    Args:
        cfg (Union[dict, object]): The original (Hydra) nested config, or any
            sub-value of it during recursion.
        level_separator (str, optional): String inserted between level names.
            Defaults to ".".
        _base_name (str, optional): Key prefix accumulated by the recursion;
            callers should leave it unset. Defaults to None.
        _flat_cfg (dict, optional): Accumulator filled in place by the
            recursion; callers should leave it unset. Defaults to None.

    Returns:
        dict: The flat config (the very same object as ``_flat_cfg`` when one
        was passed in).
    """
    flat = {} if _flat_cfg is None else _flat_cfg
    prefix = "" if _base_name is None else _base_name

    if isinstance(cfg, (list, tuple)):
        # Sequences are keyed by position.
        children = enumerate(cfg)
    elif isinstance(cfg, dict) or isinstance(cfg, omegaconf.dictconfig.DictConfig):
        # Mappings are keyed by their own keys; values are fetched lazily,
        # one per iteration step.
        children = ((name, cfg[name]) for name in cfg)
    else:
        # Leaf value. Each recursion level prepends the separator, so a single
        # leading separator is dropped before storing the entry.
        leaf_name = prefix
        if leaf_name.startswith(level_separator):
            leaf_name = leaf_name[len(level_separator) :]
        flat[leaf_name] = cfg
        return flat

    for label, child in children:
        collapse_nested_dict(
            child,
            level_separator=level_separator,
            _base_name=prefix + f"{level_separator}{label}",
            _flat_cfg=flat,
        )
    return flat
[docs]
def recursive_instantiate(
    cfg: Union[dict, omegaconf.DictConfig], parent_objects: dict = None
) -> dict:
    """Instantiate every top-level component of a config that declares a ``_target_``.

    Entries named in the priority list (``data``, ``module``, ``loss``,
    ``callbacks``, ``logger``, ``trainer``) are processed first, in that order;
    all remaining entries follow in the config's own order. Each entry is
    instantiated independently with ``hydra.utils.instantiate``. Entries that
    are not mappings with a ``_target_`` field are copied through unchanged.

    Instantiation is best-effort: if an entry raises, a rank-zero warning is
    emitted and the raw config value is stored instead.

    Args:
        cfg: Configuration dictionary or DictConfig with ``_target_`` fields.
            ``None`` yields an empty result.
        parent_objects: Accepted for API compatibility; this implementation
            does not read it.

    Returns:
        Dictionary mapping each top-level key to its instantiated component
        (or to the original value when nothing was instantiated).
    """
    if cfg is None:
        return {}

    instantiated = {}
    priority_order = ("data", "module", "loss", "callbacks", "logger", "trainer")

    # First pass: well-known components, in a fixed order.
    for key in priority_order:
        if key not in cfg:
            continue
        try:
            instantiated[key] = _instantiate_component(key, cfg[key])
        except Exception as e:
            # Deliberate best-effort: keep the raw config rather than abort.
            rank_zero_warn(f"Could not instantiate {key}: {e}")
            instantiated[key] = cfg[key]

    # Second pass: everything that was not handled above.
    for key, value in cfg.items():
        if key in instantiated:
            continue
        try:
            instantiated[key] = _instantiate_component(key, value)
        except Exception as e:
            rank_zero_warn(f"Could not instantiate {key}: {e}")
            instantiated[key] = value

    return instantiated


def _instantiate_component(key: str, section: Any) -> Any:
    """Instantiate one top-level config entry, or return it unchanged if it has no ``_target_``."""
    if not (
        isinstance(section, (dict, omegaconf.DictConfig)) and "_target_" in section
    ):
        return section

    if key == "module" and "forward" in section:
        if omegaconf.OmegaConf.is_config(section):
            # Resolve interpolations before converting to a dict so that
            # root-level references are still reachable.
            module_cfg = omegaconf.OmegaConf.to_container(section, resolve=True)
        else:
            # OmegaConf.to_container only accepts OmegaConf containers; for a
            # plain dict take a shallow copy so the caller's config is not
            # mutated when ``forward`` is replaced below.
            module_cfg = dict(section)

        forward = module_cfg["forward"]
        if isinstance(forward, str):
            # "pkg.mod.func" -> import pkg.mod and fetch func. A name without
            # any dot is left as the original string.
            module_path, dot, attr_name = forward.rpartition(".")
            if dot:
                import importlib

                module_cfg["forward"] = getattr(
                    importlib.import_module(module_path), attr_name
                )
        return hydra.utils.instantiate(module_cfg, _recursive_=True)

    # The DataModule handles its own instantiation, so Hydra must not recurse
    # into the "data" entry; every other component is instantiated recursively.
    return hydra.utils.instantiate(section, _recursive_=(key != "data"))
[docs]
def instantiate_from_config(cfg: Union[dict, omegaconf.DictConfig]) -> Any:
    """Main entry point for config-based training.

    Builds a training setup from ``cfg`` in four steps:

    1. A plain ``dict`` is converted to a ``DictConfig``.
    2. If ``matmul_precision`` is present and not None, it is passed to
       ``torch.set_float32_matmul_precision`` (done before any component,
       including the Trainer, is instantiated).
    3. All components are instantiated through ``recursive_instantiate``.
    4. If the result contains ``trainer``, ``module`` and ``data``, they are
       wrapped in a ``Manager`` together with the optional ``seed`` and
       ``ckpt_path`` entries (``None`` when absent).

    Args:
        cfg: Complete configuration dictionary or DictConfig.

    Returns:
        Manager instance if the config contains trainer/module/data,
        otherwise the dict of instantiated components.
    """
    from stable_pretraining.manager import Manager
    import torch

    if isinstance(cfg, dict):
        cfg = omegaconf.OmegaConf.create(cfg)

    # Must happen before the Trainer is instantiated.
    if "matmul_precision" in cfg:
        precision = cfg.matmul_precision
        if precision is not None:
            torch.set_float32_matmul_precision(precision)
            rank_zero_warn(f"Set float32 matmul precision to: {precision}")

    components = recursive_instantiate(cfg)

    # Without the full trainer/module/data triple there is nothing to manage:
    # hand back the instantiated components as they are.
    required = ("trainer", "module", "data")
    if any(name not in components for name in required):
        return components

    return Manager(
        trainer=components["trainer"],
        module=components["module"],
        data=components["data"],
        seed=components.get("seed"),
        ckpt_path=components.get("ckpt_path"),
    )