# Source code for stable_ssl.config

"""Configuration classes specifying default parameters for stable-SSL."""

# # Author: Hugues Van Assel <>
# #         Randall Balestriero <>
# #
# # This source code is licensed under the license found in the
# # LICENSE file in the root directory of this source tree.

import logging
import lzma
import pickle
from collections.abc import Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

import hydra
import omegaconf
from hydra.core.hydra_config import HydraConfig

def collapse_nested_dict(
    cfg: Union[dict, object],
    level_separator: str = ".",
    _base_name: Optional[str] = None,
    _flat_cfg: Optional[dict] = None,
) -> dict:
    """Parse a Hydra config and make it readable for wandb (flatten).

    Nested mappings and lists/tuples are walked recursively. Every leaf value
    is stored under a key built from the names of its parents joined by
    ``level_separator`` (list/tuple items use their index). Empty containers
    contribute no keys.

    Parameters
    ----------
    cfg : Union[dict, object]
        The original (Hydra) nested dict. Any mapping (``dict``,
        ``omegaconf.DictConfig``, ...) or ``list``/``tuple`` is expanded; any
        other value is treated as a leaf.
    level_separator : str, optional
        The string to separate level names. Defaults to ".".
    _base_name : str, optional
        The parent string, used for recursion only, users should ignore.
        Defaults to None.
    _flat_cfg : dict, optional
        The flattened config, used for recursion only, users should ignore.
        Defaults to None.

    Returns
    -------
    dict
        Flat config. When ``_flat_cfg`` is given it is filled in place and
        returned.

    Examples
    --------
    >>> collapse_nested_dict({"a": 1, "b": {"c": [2, 3]}})
    {'a': 1, 'b.c.0': 2, 'b.c.1': 3}
    """
    # INIT: fresh accumulators on the top-level call (never share a default).
    if _flat_cfg is None:
        _flat_cfg = {}
    if _base_name is None:
        _base_name = ""

    # NOTE(review): omegaconf ListConfig is not a list/tuple, so Hydra list
    # nodes are stored as leaves here, as before — confirm this is intended.
    if isinstance(cfg, (list, tuple)):
        for index, item in enumerate(cfg):
            collapse_nested_dict(
                item,
                level_separator=level_separator,
                _base_name=f"{_base_name}{level_separator}{index}",
                _flat_cfg=_flat_cfg,
            )
    elif isinstance(cfg, Mapping):
        # ``omegaconf.DictConfig`` subclasses ``MutableMapping``, so this
        # covers Hydra configs as well as plain dicts.
        for key in cfg:
            collapse_nested_dict(
                cfg[key],
                level_separator=level_separator,
                _base_name=f"{_base_name}{level_separator}{key}",
                _flat_cfg=_flat_cfg,
            )
    else:
        # Leaf: every recursive step prepends a separator, so strip the one
        # in front of the top-level name.
        if _base_name.startswith(level_separator):
            _base_name = _base_name[len(level_separator) :]
        _flat_cfg[_base_name] = cfg
    return _flat_cfg

def instanciate_config(cfg=None, debug_hash=None) -> object:
    """Instantiate the trainer from a config, or from a debugging hash.

    Without ``debug_hash``, ``cfg`` is used and its "debugging hash" (the
    lzma-compressed pickle of ``cfg``) is printed so the run can be replayed.
    With ``debug_hash``, the config is restored from it and ``cfg`` is ignored.

    The trainer is built with ``hydra.utils.instantiate`` from ``cfg.trainer``
    (non-recursively). Every other top-level entry of the config is then
    attached to the trainer as an attribute.

    Parameters
    ----------
    cfg : config object, optional
        Config exposing ``.trainer`` and ``.items()`` (presumably a Hydra
        ``DictConfig`` — verify against callers). Required when
        ``debug_hash`` is None.
    debug_hash : bytes, optional
        A hash previously printed by this function. Only use hashes from a
        trusted source: decoding one unpickles data, which can execute code.

    Returns
    -------
    object
        The instantiated trainer with the extra user args set on it.

    Raises
    ------
    ValueError
        If neither ``cfg`` nor ``debug_hash`` is provided, or if a user arg
        name clashes with an existing attribute of the trainer.
    """
    if debug_hash is None:
        # Explicit check instead of ``assert`` (asserts are stripped by -O).
        if cfg is None:
            raise ValueError("Either `cfg` or `debug_hash` must be provided.")
        print("Your debugging hash:", lzma.compress(pickle.dumps(cfg)))
    else:
        print("Using debugging hash")
        # SECURITY: pickle.loads can execute arbitrary code; never feed it a
        # hash obtained from an untrusted source.
        cfg = pickle.loads(lzma.decompress(debug_hash))

    trainer = hydra.utils.instantiate(
        cfg.trainer, _convert_="object", _recursive_=False
    )
    for key, value in cfg.items():
        if key == "trainer":
            continue
        logging.info("\t=> Adding user arg %s to Trainer", key)
        # Refuse to silently overwrite something the trainer already defines.
        if hasattr(trainer, key):
            raise ValueError(f"User arg {key} already exists in the Trainer {trainer}")
        setattr(trainer, key, value)
    return trainer

@dataclass
class HardwareConfig:
    """Configuration for the hardware parameters.

    Parameters
    ----------
    seed : int, optional
        Random seed for reproducibility. Default is None.
    float16 : bool, optional
        Whether to use mixed precision (float16) for training. Default is False.
    world_size : int, optional
        Number of processes participating in distributed training. Default is 1.
    device : str, optional
        The device to use for training. Default is "cuda". This class does not
        check whether CUDA is available.
    """

    seed: Optional[int] = None
    float16: bool = False
    world_size: int = 1
    # NOTE(review): earlier docs promised a fallback to "cpu" when CUDA is
    # unavailable; no such fallback happens here — confirm callers handle it.
    device: str = "cuda"


@dataclass
class LoggerConfig:
    """Configuration for logging and checkpointing during training or evaluation.

    Parameters
    ----------
    level : int, optional
        The logging level. Determines the threshold for what gets logged.
        Default is 20 (the value of ``logging.INFO``).
    metric : dict, optional
        A dictionary to store and log various metrics. Default is an empty dict.
    monitor : dict, optional
        A dictionary to store and log various monitoring statistics.
        Default is an empty dict.
    save_final_model : str or bool, optional
        Specifies whether to save the final trained model. If a name is
        provided, the final model will be saved with that name.
        Default is False.
    eval_every_epoch : int, optional
        The frequency (in epochs) at which the model will be evaluated.
        For example, if set to 1, evaluation occurs every epoch. Default is 1.
    log_every_step : int, optional
        The frequency (in training steps) at which to log intermediate metrics.
        For example, if set to 1, logs occur every step. Default is 1.
    checkpoint_frequency : int, optional
        The frequency (in epochs) at which model checkpoints are saved.
        For example, if set to 10, a checkpoint is saved every 10 epochs.
        Default is None.
    checkpoint_model_only : bool, optional
        Whether to save only the model weights (True) or save additional
        training state (False) during checkpointing. Default is True.
    dump_path : pathlib.Path, optional
        The path where output is dumped. Defaults to Hydra's runtime output
        directory, looked up when the instance is created (so an active Hydra
        run is presumably required unless a value is passed explicitly).
    wandb : bool or dict or None, optional
        Configuration for Weights & Biases logging. If `True`, it will be
        converted to an empty dictionary and default keys will be filled in
        if ``rank == 0``. That conversion is not performed by this dataclass,
        which stores the value as given. Default is None.
        See :class:`stable_ssl.config.WandbConfig` for the full list of
        parameters and their defaults.
    """

    level: int = 20
    # default_factory gives each instance its own dict (no shared mutable default).
    metric: dict = field(default_factory=dict)
    monitor: dict = field(default_factory=dict)
    save_final_model: Union[str, bool] = False
    eval_every_epoch: int = 1
    log_every_step: int = 1
    checkpoint_frequency: Optional[int] = None
    checkpoint_model_only: bool = True
    # Resolved lazily so importing this module does not require Hydra to be running.
    dump_path: Path = field(
        default_factory=lambda: Path(HydraConfig.get().runtime.output_dir)
    )
    wandb: Union[bool, dict, None] = None


@dataclass
class WandbConfig:
    """Configuration for the Weights & Biases logging.

    Parameters
    ----------
    dir : str, optional
        The path (as a string) where output is dumped. Defaults to Hydra's
        runtime output directory, looked up when the instance is created.
    entity : str, optional
        Name of the (Weights & Biases) entity. Default is None.
    project : str, optional
        Name of the (Weights & Biases) project. Default is None.
    name : str, optional
        Name of the Weights & Biases run. Default is None.
    id : str, optional
        ID of the Weights & Biases run. Default is None.
    tags : list, optional
        List of tags for the Weights & Biases run. Default is None.
    group : str, optional
        Group for the Weights & Biases run. Default is None.
    """

    # Stored as ``str`` (not ``Path``); resolved lazily from the Hydra run.
    dir: str = field(
        default_factory=lambda: str(Path(HydraConfig.get().runtime.output_dir))
    )
    entity: Optional[str] = None
    project: Optional[str] = None
    name: Optional[str] = None
    id: Optional[str] = None
    tags: Optional[list] = None
    group: Optional[str] = None


@dataclass
class OptimConfig:
    """Configuration for the optimization parameters.

    Parameters
    ----------
    optimizer : dict
        Configuration for the optimizer.
    scheduler : dict
        Configuration for the learning rate scheduler.
    epochs : int, optional
        Number of epochs to train the model. Default is 1000.
    max_steps : int, optional
        Maximum number of steps to train the model. Default is -1.
        If negative, the model trains on the full dataset.
        If it is between 0 and 1, it represents the fraction of the dataset
        to train on.
    accumulation_steps : int, optional
        Number of steps to accumulate gradients before updating the model.
        Default is 1.
    grad_max_norm : float, optional
        Maximum norm of the gradients. If None, no clipping is applied.
        Default is None.
    """

    # Required (no defaults), so they must come first.
    optimizer: dict
    scheduler: dict
    epochs: int = 1000
    # NOTE(review): annotated ``int`` although the docstring allows a fraction
    # in (0, 1) — confirm the intended type before tightening validation.
    max_steps: int = -1
    accumulation_steps: int = 1
    grad_max_norm: Optional[float] = None