Source code for stable_ssl.callbacks.checkpoint_sklearn

from typing import Optional

import numpy as np
from lightning.pytorch import Callback, LightningModule, Trainer
from loguru import logger as logging
from sklearn.base import ClassifierMixin, RegressorMixin
from tabulate import tabulate


class SklearnCheckpoint(Callback):
    """Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.

    On save, every public attribute of the LightningModule that is a sklearn
    regressor/classifier (or a list/dict holding one) is copied into the
    checkpoint dictionary under the attribute's name. On load, every checkpoint
    entry matching that same criterion is set back on the module.
    """

    def setup(
        self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None
    ) -> None:
        """Log a table of the sklearn modules discovered on ``pl_module``.

        Args:
            trainer: Trainer running the module (unused here).
            pl_module: Module whose public attributes are scanned.
            stage: Stage passed by the caller (unused here).
        """
        sklearn_modules = _get_sklearn_modules(pl_module)
        stats = [
            (name, str(module), type(module))
            for name, module in sklearn_modules.items()
        ]
        headers = ["Module", "Name", "Type"]
        logging.info("Sklearn Modules:")
        logging.info(f"\n{tabulate(stats, headers, tablefmt='heavy_outline')}")

    def on_save_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict
    ) -> None:
        """Add the sklearn modules found on ``pl_module`` to ``checkpoint``.

        Args:
            trainer: Trainer running the module (unused here).
            pl_module: Module whose public attributes are scanned.
            checkpoint: Checkpoint dictionary, modified in place before saving.

        Raises:
            RuntimeError: If an attribute name is already a key of ``checkpoint``;
                overwriting that entry would corrupt the checkpoint.
        """
        print("\tChecking for non PyTorch modules to save... 🔧", flush=True)
        modules = _get_sklearn_modules(pl_module)
        for name, module in modules.items():
            if name in checkpoint:
                raise RuntimeError(
                    f"Can't pickle {name}, already present in checkpoint"
                )
            checkpoint[name] = module
            print(f"\t\tsaving non PyTorch system: {name} 🔧", flush=True)

    def on_load_checkpoint(
        self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict
    ) -> None:
        """Restore sklearn modules stored in ``checkpoint`` onto ``pl_module``.

        Args:
            trainer: Trainer running the module (unused here).
            pl_module: Module that receives the restored attributes.
            checkpoint: Loaded checkpoint dictionary.
        """
        print("\tChecking for non PyTorch modules to load... 🔧", flush=True)
        for name, item in checkpoint.items():
            # Use the same predicate as the save path: attributes that are lists
            # or dicts of sklearn models are written to the checkpoint as well,
            # and a bare isinstance check on the mixins would silently drop them
            # when resuming.
            if _contains_sklearn_module(item):
                setattr(pl_module, name, item)
                print(f"\t\tloading non PyTorch system: {name} 🔧", flush=True)
def _contains_sklearn_module(item):
    """Return whether ``item`` is, or holds, a sklearn regressor or classifier.

    Lists and dict values are searched recursively; any other container type is
    not inspected.

    Args:
        item: Object to inspect.

    Returns:
        bool: True if ``item`` is an instance of ``RegressorMixin`` or
        ``ClassifierMixin``, or a list/dict that (recursively) contains one.
    """
    if isinstance(item, (RegressorMixin, ClassifierMixin)):
        return True
    # The built-in any() stops at the first match and yields a plain bool,
    # unlike np.any over a fully materialized list.
    if isinstance(item, list):
        return any(_contains_sklearn_module(m) for m in item)
    if isinstance(item, dict):
        return any(_contains_sklearn_module(m) for m in item.values())
    return False


def _get_sklearn_modules(module):
    """Collect the public attributes of ``module`` that hold sklearn models.

    Args:
        module: Object whose attributes are scanned via ``dir``.

    Returns:
        dict: Maps attribute name to attribute value for every attribute whose
        name does not start with an underscore and whose value satisfies
        ``_contains_sklearn_module``.
    """
    modules = {}
    for name in dir(module):
        # Private and dunder attributes are never considered.
        if name.startswith("_"):
            continue
        item = getattr(module, name)
        if _contains_sklearn_module(item):
            modules[name] = item
    return modules