Source code for stable_pretraining.optim.utils
"""Shared utilities for optimizer and scheduler configuration."""
import inspect
from functools import partial
from typing import Union
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from .. import optim as ssl_optim
def create_optimizer(
    params,
    optimizer_config: Union[str, dict, partial, type],
) -> torch.optim.Optimizer:
"""Create an optimizer from flexible configuration.
This function provides a unified way to create optimizers from various configuration formats,
used by both Module and OnlineProbe for consistency.
Args:
params: Parameters to optimize (e.g., model.parameters())
optimizer_config: Can be:
- str: optimizer name from torch.optim or stable_pretraining.optim (e.g., "AdamW", "LARS")
- dict: {"type": "AdamW", "lr": 1e-3, ...}
- partial: pre-configured optimizer factory
- class: optimizer class (e.g., torch.optim.AdamW)
Returns:
Configured optimizer instance
    Examples:
        >>> # String name (uses default parameters)
        >>> opt = create_optimizer(model.parameters(), "AdamW")

        >>> # Dict with parameters
        >>> opt = create_optimizer(
        ...     model.parameters(), {"type": "SGD", "lr": 0.1, "momentum": 0.9}
        ... )

        >>> # Using partial
        >>> from functools import partial
        >>> opt = create_optimizer(
        ...     model.parameters(), partial(torch.optim.Adam, lr=1e-3)
        ... )

        >>> # Direct class
        >>> opt = create_optimizer(model.parameters(), torch.optim.RMSprop)
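
        >>> # Hydra-style config carrying _target_ (handled by the instantiate
        >>> # branch below); a minimal sketch, the lr value is illustrative
        >>> from omegaconf import OmegaConf
        >>> cfg = OmegaConf.create({"_target_": "torch.optim.AdamW", "lr": 1e-3})
        >>> opt = create_optimizer(model.parameters(), cfg)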
"""
# Handle Hydra config objects
if hasattr(optimizer_config, "_target_"):
return instantiate(optimizer_config, params=params, _convert_="object")
# partial -> call with params
if isinstance(optimizer_config, partial):
return optimizer_config(params)
# callable (including optimizer factories, but not classes)
if callable(optimizer_config) and not isinstance(optimizer_config, type):
return optimizer_config(params)
# dict -> extract type and kwargs
if isinstance(optimizer_config, (dict, DictConfig)):
# Convert DictConfig to dict if needed
if isinstance(optimizer_config, DictConfig):
config_copy = OmegaConf.to_container(optimizer_config, resolve=True)
else:
config_copy = optimizer_config.copy()
opt_type = config_copy.pop("type", "AdamW")
kwargs = config_copy
else:
opt_type = optimizer_config
kwargs = {}

    # resolve class: torch.optim first, then stable_pretraining.optim
    if isinstance(opt_type, str):
        if hasattr(torch.optim, opt_type):
            opt_class = getattr(torch.optim, opt_type)
        elif hasattr(ssl_optim, opt_type):
            opt_class = getattr(ssl_optim, opt_type)
        else:
            torch_opts = [n for n in dir(torch.optim) if n[0].isupper()]
            ssl_opts = [n for n in dir(ssl_optim) if n[0].isupper()]
            raise ValueError(
                f"Optimizer '{opt_type}' not found. Available in torch.optim: "
                + ", ".join(torch_opts)
                + ". Available in stable_pretraining.optim: "
                + ", ".join(ssl_opts)
            )
    else:
        opt_class = opt_type

    try:
        return opt_class(params, **kwargs)
    except TypeError as e:
        # Surface the optimizer's required constructor arguments so the
        # configuration error is actionable.
        sig = inspect.signature(opt_class.__init__)
        required = [
            p.name
            for p in sig.parameters.values()
            if p.default == inspect.Parameter.empty and p.name not in ["self", "params"]
        ]
        raise TypeError(
            f"Failed to create {opt_class.__name__}. Required parameters: {required}. "
            f"Provided: {list(kwargs.keys())}. Original error: {e}"
        )
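

if __name__ == "__main__":
    # Illustrative smoke test (a sketch, not part of the public API): a toy
    # Linear model and arbitrary hyperparameters, used only to demonstrate the
    # string, dict, and unknown-name paths of create_optimizer.
    model = torch.nn.Linear(8, 2)

    # String name -> torch.optim.AdamW with default hyperparameters.
    opt = create_optimizer(model.parameters(), "AdamW")
    print(type(opt).__name__)  # AdamW

    # Dict config -> "type" selects the class, the remaining keys are
    # forwarded as constructor kwargs.
    opt = create_optimizer(
        model.parameters(), {"type": "SGD", "lr": 0.1, "momentum": 0.9}
    )
    print(type(opt).__name__)  # SGD

    # Unknown names raise a ValueError listing what is available in
    # torch.optim and stable_pretraining.optim.
    try:
        create_optimizer(model.parameters(), "NotAnOptimizer")
    except ValueError as err:
        print(err)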