Source code for stable_ssl.manager

import copy
import json
import signal
from datetime import timedelta
from pathlib import Path
from typing import Union

import hydra
import lightning
import lightning as pl
import pandas as pd
import submitit
import wandb
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from loguru import logger as logging
from omegaconf import DictConfig, OmegaConf, open_dict

from .utils import get_required_fn_parameters


class Manager(submitit.helpers.Checkpointable):
    """Manages training with logging, scheduling, and checkpointing support.

    Args:
        trainer (Union[dict, DictConfig, pl.Trainer]): PyTorch Lightning trainer
            configuration or instance.
        module (Union[dict, DictConfig, pl.LightningModule]): Lightning module
            configuration or instance.
        data (Union[dict, DictConfig, pl.LightningDataModule]): Data module
            configuration or instance.
        seed (int, optional): Random seed for reproducibility. Defaults to None.
        ckpt_path (str, optional): Path to checkpoint for resuming training.
            Defaults to "last".
    """

    def __init__(
        self,
        trainer: Union[dict, DictConfig, pl.Trainer],
        module: Union[dict, DictConfig, pl.LightningModule],
        data: Union[dict, DictConfig, pl.LightningDataModule],
        seed: int = None,
        ckpt_path: str = "last",
    ):
        # This is the state that will be saved by `checkpoint`.
        # We deepcopy in case the user changes things after
        # passing the dicts (since they would otherwise be held by reference).
        if seed is None:
            logging.warning(
                "User didn't specify seed, runs won't be exactly reproducible!"
            )
        if type(trainer) is dict:
            trainer = OmegaConf.create(trainer)
        if type(trainer) is DictConfig:
            self.trainer: DictConfig = copy.deepcopy(trainer)
            logging.debug("\t● trainer config saved ✅")
        elif isinstance(trainer, pl.Trainer):
            self.trainer = trainer
            logging.debug("\t● trainer already instantiated ✅")
        else:
            raise ValueError(
                f"`trainer` must be a dict, DictConfig or pl.Trainer, not {type(trainer)}"
            )
        if type(module) is dict:
            module = OmegaConf.create(module)
        if type(module) is DictConfig:
            self.module: DictConfig = copy.deepcopy(module)
            logging.debug("\t● module config saved ✅")
        elif isinstance(module, pl.LightningModule):
            self.module = module
            logging.debug("\t● module already instantiated ✅")
        else:
            raise ValueError(
                f"`module` must be a dict, DictConfig or pl.LightningModule, not {type(module)}"
            )
        if type(data) is dict:
            data = OmegaConf.create(data)
        if type(data) is DictConfig:
            self.data: DictConfig = copy.deepcopy(data)
            logging.debug("\t● data config saved ✅")
        elif isinstance(data, pl.LightningDataModule):
            self.data = data
            logging.debug("\t● data already instantiated ✅")
        else:
            raise ValueError(
                f"`data` must be a dict, DictConfig or pl.LightningDataModule, not {type(data)}"
            )
        self.seed = seed
        self.ckpt_path = ckpt_path
        # self.slurm_requeue_signal = slurm_requeue_signal

    @rank_zero_only
    def init_and_sync_wandb(self):
        # only useful with wandb
        # to set any non-given variable to the DictConfig
        # so that we can resume on requeue
        # this override is only useful if receiving parameters
        # from wandb e.g. using wandb sweep. We retrieve them and
        # requeue with those instead of user args
        self.override = []
        if isinstance(
            self.trainer.logger, lightning.pytorch.loggers.logger.DummyLogger
        ):
            logging.info("📈📈📈 DummyLogger already setup, skipping init 📈📈📈")
            return
        elif isinstance(
            self.trainer.logger,
            lightning.pytorch.loggers.tensorboard.TensorBoardLogger,
        ):
            logging.info("📈📈📈 TensorBoardLogger already setup, skipping init 📈📈📈")
            logging.info(f"📈📈📈 root_dir={self.trainer.logger.root_dir} 📈📈📈")
            logging.info(f"📈📈📈 save_dir={self.trainer.logger.save_dir} 📈📈📈")
            logging.info(f"📈📈📈 log_dir={self.trainer.logger.log_dir} 📈📈📈")
            return
        elif isinstance(
            self.trainer.logger, lightning.pytorch.loggers.csv_logs.CSVLogger
        ):
            logging.info("📈📈📈 CSVLogger already setup, skipping init 📈📈📈")
            logging.info(f"📈📈📈 root_dir={self.trainer.logger.root_dir} 📈📈📈")
            logging.info(f"📈📈📈 save_dir={self.trainer.logger.save_dir} 📈📈📈")
            logging.info(f"📈📈📈 log_dir={self.trainer.logger.log_dir} 📈📈📈")
            return
        elif isinstance(
            self.trainer.logger, lightning.pytorch.loggers.wandb.WandbLogger
        ):
            logging.info("📈📈📈 WandbLogger already setup, skipping init 📈📈📈")
            logging.info(f"📈📈📈 init={self.trainer.logger._wandb_init} 📈📈📈")
            return
        elif self.trainer.logger is None:
            logging.warning("📈📈📈 No logger used! 📈📈📈")
            return
        if "wandb" not in self.trainer.logger._target_.lower():
            return
        logging.info("📈📈📈 Using Wandb 📈📈📈")
        exp = self._trainer.logger.experiment
        with open_dict(self.trainer):
            prefix = "\t\t● config's "
            for name in ["entity", "project", "name", "id"]:
                cfg_value = self.trainer.logger.get(name, None)
                w_value = getattr(exp, name)
                if cfg_value == w_value:
                    logging.info(f"{prefix}{name} ({cfg_value}) left as-is ✅")
                    continue
                self.override.append(f"++manager.trainer.logger.{name}={w_value}")
                logging.info(f"{prefix}{name} ({cfg_value}) updated to {w_value} ✅")
                # setattr(self.trainer.logger, name, w_value)
                self.trainer.logger[name] = w_value
            if self.trainer.logger.get("resume", None):
                logging.info(f"{prefix}`resume` already set to `allow`! ✅")
✅") else: self.override.append("++manager.trainer.logger.resume=allow") self.trainer.logger.resume = "allow" logging.info(f"{prefix}`resume` set to `allow` for subsequent runs ✅") # we defer adding the config to later # to make sure we use the possibly given # sweep config # if self.logger.get("config", None) is not None: # # this will be a nested dict # config = OmegaConf.to_container(self.logger.config, resolve=True) # # now we flatten # config = pd.json_normalize(config, sep="_") # config = config.to_dict(orient="records")[0] # logging.info(f"\tflattening Hydra's config for Wandb ✅") # self.logger.config = None # else: # config = None if exp.offline: previous_run = self._wandb_previous_dir() logging.info(f"\t\tFound a previous run ({previous_run}), reusing config") with open(previous_run / "files/wandb-config.json", "r") as f: last_config = json.load(f) # at most last_config has an extra `ckpt_path` exp.config.update(last_config) logging.info("\t\treloaded!") elif len(wandb.config.keys()): logging.info("\t\ta Wandb™ config is provided, not uploading Hydra's:") # TODO: make Wandb parameters the trainer one # for key, value in wandb.config.items(): # # need to handle the fact that our base configs have a _ # # and users wouldn't provide that # accessor = key.split(".") # if accessor[0] == "trainer": # accessor = accessor[1:] # if accessor[0] in [ # "data", # "module", # "hardware", # "loss", # "metric", # "optim", # ]: # if "_" != accessor[0][0]: # accessor[0] = "_" + accessor[0] # key = ".".join(accessor) # try: # original = rgetattr(self, key) # rsetattr(self, key, value) # assert rgetattr(self, key) == value # logging.info( # f"\t\t\toverriding: {key} from {original} to {value} ✅" # ) # except Exception as e: # logging.error(f"❌ Error while trying to override {key} ❌") # raise e else: logging.info("\tWandb's config is empty, using Hydra's 📤") config = dict( trainer=OmegaConf.to_container(self.trainer, resolve=True), module=OmegaConf.to_container(self.module, resolve=True), data=OmegaConf.to_container(self.data, resolve=True), ) config = pd.json_normalize(config, sep=".") config = config.to_dict(orient="records")[0] while True: logging.info("\t\tflattening one level of Hydra's config) 📤") valid = True for k in list(config.keys()): if type(config[k]) is list: valid = False for i, j in enumerate(config[k]): config[f"{k}.{i}"] = j del config[k] config = pd.json_normalize(config, sep=".") config = config.to_dict(orient="records")[0] if valid: break logging.info(f"\tFinal Hydra's config has {len(config)} items) 📤") wandb.config.update(config) # TODO: should we updated the config to the DictConfig too for next run to check? 
        # with open_dict(self.logger):
        #     self.trainer.logger.config = config

    @property
    def instantiated_module(self):
        if not isinstance(self.module, pl.LightningModule):
            logging.info("\t● instantiating pl_module...")
            # with self._trainer.init_module():
            self._instantiated_module = hydra.utils.instantiate(
                self.module, _convert_="object"
            )
            logging.info("\t● module instantiated ✅")
        else:
            self._instantiated_module = self.module
        return self._instantiated_module

    @property
    def instantiated_data(self):
        if not isinstance(self.data, pl.LightningDataModule):
            self._instantiated_data = hydra.utils.instantiate(
                self.data, _convert_="object", _recursive_=False
            )
            logging.info("\t● data instantiated ✅")
        else:
            self._instantiated_data = self.data
        return self._instantiated_data

    def __call__(self):
        # self._setup_logging()
        logging.info(f"📁📁📁 CURRENT WORKING DIR: {Path().resolve()} 📁📁📁")
        # if "SLURM_JOB_ID" in os.environ:
        #     # single-node and multi-node distributed training on SLURM cluster
        #     # requeue job on SLURM preemption
        #     self.submitit_signal = signal.getsignal(
        #         signal.__dict__[self.slurm_requeue_signal]
        #     )
        #     logging.info(f"\t● saved signal {self.submitit_signal} ✅")
        #     logging.info(
        #         f"\t● setting up checkpoint and requeue on {self.slurm_requeue_signal} ✅"
        #     )
        #     signal.signal(
        #         signal.__dict__[self.slurm_requeue_signal], self.checkpoint_and_requeue
        #     )
        logging.info(f"🌱🌱🌱 SEEDING EVERYTHING with {self.seed=} 🌱🌱🌱")
        pl.seed_everything(self.seed, workers=True, verbose=False)
        if isinstance(self.trainer, pl.Trainer):
            self._trainer = self.trainer
        else:
            if "callbacks" in self.trainer:
                logging.info("\t● instantiating callbacks...")
                callbacks = hydra.utils.instantiate(
                    self.trainer.callbacks, _convert_="object"
                )
                for i, callback in enumerate(callbacks):
                    if not callable(callback):
                        continue
                    assert ["pl_module"] == get_required_fn_parameters(callback)
                    callbacks[i] = callback(pl_module=self.instantiated_module)
                logging.info("\t● callbacks instantiated ✅")
                del self.trainer.callbacks
            else:
                callbacks = []
            # we use the following partial to give our init callbacks manually since otherwise
            # hydra instantiate throws an error
            self._trainer = hydra.utils.instantiate(
                self.trainer, _convert_="object", _partial_=True
            )
            self._trainer = self._trainer(callbacks=callbacks)
            if not isinstance(self._trainer, pl.Trainer):
                raise ValueError("`trainer` should be a Trainer")
            logging.info("\t● trainer instantiated ✅")
        self.init_and_sync_wandb()
        logging.info("\t● logger updated accordingly ✅")
        logging.info("\t● 👂👂👂 SIGNALS HANDLERS 👂👂👂")
        logging.info(f"\t\t- SIGUSR1: `{signal.getsignal(signal.SIGUSR1)}`")
        logging.info(f"\t\t- SIGUSR2: `{signal.getsignal(signal.SIGUSR2)}`")
        logging.info(f"\t\t- SIGCONT: `{signal.getsignal(signal.SIGCONT)}`")
        logging.info(f"\t\t- SIGTERM: `{signal.getsignal(signal.SIGTERM)}`")
        # when using submitit launcher, Hydra uses its own checkpoint method
        # https://github.com/facebookresearch/hydra/blob/main/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py#L78
        # we replace it with ours
        # https://github.com/facebookresearch/hydra/issues/2042
        # https://github.com/facebookincubator/submitit/blob/main/submitit/core/job_environment.py#L212
        fn = signal.getsignal(signal.SIGUSR2)
        if hasattr(fn, "__self__"):
            self._hydra_self = fn.__self__._delayed.function.checkpoint.__self__
            fn.__self__._delayed.function.checkpoint = self.checkpoint
        else:
            self._hydra_self = self
        # logging.info(f"\t● Searching for checkpoint to warm restart...")
        # ckpt_path = None
        # if wandb.run and not wandb.run.offline:
        #     logging.info(
        #         f"\t\t● Wandb is online... searching for `requeue_checkpoint` in Artifacts..."
        #     )
        #     r = wandb.Api().run(wandb.run.path)
        #     artifacts = r.logged_artifacts()
        #     logging.info(f"\t\t● wandb run artifacts:")
        #     for artifact in artifacts:
        #         logging.info(
        #             f"\t\t\t● {artifact.name}, {artifact.type}, {artifact.created_at}"
        #         )
        #         if artifact.name.split(":")[0] == "requeue_checkpoint":
        #             logging.info(f"\t\t● Checkpoint found! 🔥")
        #             ckpt_path = artifact
        #     # we wait the end to download to make sure we get the
        #     # latest version
        #     if ckpt_path:
        #         datadir = Path(ckpt_path.download()) / "checkpoint.ckpt"
        #         logging.info(f"\t● Checkpoint downloaded ({datadir})! 🔥")
        #         ckpt_path = datadir
        #     else:
        #         logging.info(
        #             f"\t\t● No Checkpoint artifact found in Wandb artifacts... searching in config ❌"
        #         )
        #         if "ckpt_path" in wandb.run.config:
        #             logging.info(
        #                 f"\t\t● `ckpt_path` found in Wandb config: {ckpt_path} 🔥"
        #             )
        #             ckpt_path = wandb.run.config["ckpt_path"]
        #             if ckpt_path is not None and Path(ckpt_path).is_file():
        #                 logging.info(
        #                     f"\t\t● `ckpt_path` found in Wandb config: {ckpt_path} 🔥"
        #                 )
        #             else:
        #                 logging.info(
        #                     f"\t\t● `{ckpt_path=}` is not a valid file... not using it ❌"
        #                 )
        #                 ckpt_path = None
        #         else:
        #             logging.info(
        #                 f"\t\t● `ckpt_path` not found in online Wandb config ...!"
        #             )
        # else:
        #     logging.info(
        #         f"\t\t● Wandb is offline... searching in local logger's config..."
        #     )
        #     cfg = self.trainer.logger.get("config", {})
        #     if cfg and cfg.get("ckpt_path", None):
        #         ckpt_path = cfg["ckpt_path"]
        #         if ckpt_path is not None and Path(ckpt_path).is_file():
        #             logging.info(
        #                 f"\t\t● `ckpt_path` found in local config: {ckpt_path} 🔥"
        #             )
        #         else:
        #             logging.info(
        #                 f"\t\t● `{ckpt_path=}` is not a valid file... not using it ❌"
        #             )
        #             ckpt_path = None
        # if ckpt_path is None:
        #     logging.error(f"\t\t● No checkpoint found! ❌")
❌") logging.info("\t● 📞📞📞 CALLBACKS 📞📞📞") for c in self._trainer.checkpoint_callbacks: logging.info(c) for c in self._trainer.early_stopping_callbacks: logging.info(c) logging.info(f"📣📣📣 CALLING trainer.fit with {self.ckpt_path=} 📣📣📣") self._trainer.fit( self.instantiated_module, datamodule=self.instantiated_data, ckpt_path=self.ckpt_path, ) self._dump_wandb_data() def validate(self): logging.info("📣📣📣 CALLING trainer.validate 📣📣📣") self._trainer.validate( self.instantiated_module, datamodule=self.instantiated_data ) self._dump_wandb_data() def predict(self): logging.info("📣📣📣 CALLING trainer.predict 📣📣📣") self._trainer.predict( self.instantiated_module, datamodule=self.instantiated_data ) self._dump_wandb_data() def test(self): logging.info("📣📣📣 CALLING trainer.test 📣📣📣") self._trainer.test(self.instantiated_module, datamodule=self.instantiated_data) self._dump_wandb_data() # wandb.finish() # logging.info(f"closing wandb 🗑️") # cfg = wandb.run.config.as_dict() # return cfg, module.info @rank_zero_only def _dump_wandb_data(self): if wandb.run is None or not wandb.run.offline: return # Print the summary logging.info("Summary:") summary_dict = wandb.run.summary._as_dict() logging.info(json.dumps(summary_dict, indent=2)) fname = Path(wandb.run.dir) / "wandb-summary.json" if fname.is_file(): raise RuntimeError(f"Summary file already exists {fname}") with open(fname, "w") as f: json.dump(summary_dict, f) logging.info(f"\t● Saved summary at {fname} ✅") fname = Path(wandb.run.dir) / "wandb-config.json" if fname.is_file(): raise RuntimeError(f"Config file already exists {fname}") with open(fname, "w") as f: json.dump(wandb.run.config.as_dict(), f) logging.info(f"\t● Saved config at {fname} ✅") def _wandb_previous_dir(self): # to remove the /files path = Path(wandb.run.dir).parent logging.info(f"\t\t● fetching previous Wandb runs from {path.parent} ✅") # this will be of the form # offline-run-20250413_025716-p8117tgi runs = list(path.parent.glob(f"offline-run-*-{wandb.run.id}")) logging.info(f"\t\t● found {len(runs)} run(s):") runs = sorted(runs) for run in runs: logging.info(f"\t\t\t{run.name}") assert runs[-1] == path if len(runs) == 1: return None return runs[-2] def save_checkpoint(self, path=None): # TODO: figure out how to flush logging in subprocess print("Entering checkpoint method", flush=True) if path is None: path = (Path() / "checkpoint.ckpt").resolve() print(f"\t● saving checkpoint to local path {path} ⏳", flush=True) else: path = Path(path) if not path.parent.is_dir(): path.parent.mkdir(parents=True) print(f"\t● saving checkpoint to user's path {path} ⏳", flush=True) self._trainer.save_checkpoint(str(path)) print("\t● checkpoint saved ✅", flush=True) self._upload_checkpoint_for_requeue(path) @rank_zero_only def _upload_checkpoint_for_requeue(self, ckpt_path): # if "ckpt_path" in wandb.run.config: # ckpt_path = Path(wandb.run.config["ckpt_path"]) # print(f"\t● `ckpt_path` already in config, updating it!", flush=True) # else: # ckpt_path = Path(wandb.run.dir) / "checkpoint.ckpt" # print(f"\t● `ckpt_path` set to {ckpt_path}!", flush=True) if not wandb.run.offline: print("\t● Wandb used and online:", flush=True) artifact = wandb.Artifact("requeue_checkpoint", "model") artifact.add_file(str(ckpt_path)) artifact.ttl = timedelta(days=30) print("\t\t● artifact created ✅", flush=True) wandb.run.log_artifact(artifact) print("\t\t● artifact logged ✅", flush=True) ckpt_path.unlink() print("\t\t● local checkpoint deleted ✅", flush=True) else: print("\t● Wandb used and offline:", flush=True) 
            wandb.run.config.update({"ckpt_path": str(ckpt_path.resolve())})
            print("\t● `ckpt_path` added to Wandb config ✅", flush=True)
        # for offline case
        self._dump_wandb_data()

    # @rank_zero_only
    # def requeue(self, *args, **kwargs):
    #     print(f"\t● requeing! 🔥", flush=True)
    #     print(args, kwargs)
    #     self.submitit_signal(*args, **kwargs)
    @rank_zero_only
    def checkpoint(self, *args, **kwargs):
        print(
            "⚠️⚠️⚠️ only SLURM should use this function, "
            "users should use `save_checkpoint` ⚠️⚠️⚠️",
            flush=True,
        )
        assert len(kwargs) == 0
        assert type(args[0]) is list
        print("Original Hydra's overrides: ", args[0], flush=True)
        # TODO add a check to make sure we don't double override anything
        print("Adding our overrides: ", self.override, flush=True)
        args[0].extend(self.override)
        inst = self._hydra_self.__class__(**self._hydra_self.params)
        inst.config = self._hydra_self.config
        inst.hydra_context = self._hydra_self.hydra_context
        inst.task_function = self._hydra_self.task_function
        # print(self._hydra_checkpoint, flush=True)
        # return self._hydra_checkpoint(*args, **kwargs)
        # child = self.__class__(
        #     self.trainer, self.module, self.data, self.seed, self.ckpt_path
        # )
        print("Requeuing with: ", inst, args, kwargs, flush=True)
        return submitit.helpers.DelayedSubmission(inst, *args, **kwargs)
    # def checkpoint_and_requeue(self, *args, **kwargs):
    #     self.save_checkpoint()
    #     self.requeue(*args, **kwargs)
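
A minimal usage sketch (not part of the module): it drives Manager with plain dicts and lets Hydra instantiate the trainer, module, and data. The my_project targets are hypothetical placeholders for your own LightningModule and LightningDataModule; only Manager's signature is taken from the source above.

    from stable_ssl.manager import Manager

    # Hydra-style configs; every `_target_` except lightning.Trainer is a
    # hypothetical placeholder you would replace with your own classes.
    trainer_cfg = {
        "_target_": "lightning.Trainer",
        "max_epochs": 1,
        "logger": None,  # Manager only syncs Wandb settings when a Wandb logger is configured
    }
    module_cfg = {"_target_": "my_project.MyLightningModule"}  # hypothetical
    data_cfg = {"_target_": "my_project.MyDataModule"}  # hypothetical

    manager = Manager(
        trainer=trainer_cfg,
        module=module_cfg,
        data=data_cfg,
        seed=0,          # omitting the seed only triggers a reproducibility warning
        ckpt_path=None,  # or "last" to resume from the most recent checkpoint
    )
    manager()        # instantiates everything and runs trainer.fit(...)
    manager.test()   # optionally reuse the same trainer/module/data for evaluation

Under a Hydra submitit launch, the same object also serves as the submitit checkpoint callable: `checkpoint` re-submits the job with any overrides collected in `self.override`.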