"""Source code for stable_ssl.callbacks.checkpoint_sklearn."""
from typing import Optional
import numpy as np
from lightning.pytorch import Callback, LightningModule, Trainer
from loguru import logger as logging
from sklearn.base import ClassifierMixin, RegressorMixin
from tabulate import tabulate
[docs]
class SklearnCheckpoint(Callback):
    """Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.

    Public attributes of the LightningModule that ``_get_sklearn_modules``
    reports (sklearn regressors/classifiers, or lists/dicts holding them) are
    written into the checkpoint dictionary under their attribute name, and are
    set back on the module when a checkpoint is loaded.
    """

    def setup(
        self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None
    ) -> None:
        """Log a table listing the sklearn modules found on ``pl_module``.

        Args:
            trainer: The running trainer (unused here).
            pl_module: Module whose public attributes are scanned.
            stage: Lightning stage name (unused here).
        """
        sklearn_modules = _get_sklearn_modules(pl_module)
        stats = [
            (name, str(module), type(module))
            for name, module in sklearn_modules.items()
        ]
        headers = ["Module", "Name", "Type"]
        logging.info("Sklearn Modules:")
        logging.info(f"\n{tabulate(stats, headers, tablefmt='heavy_outline')}")

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Add every sklearn-bearing attribute of ``pl_module`` to ``checkpoint``.

        Args:
            trainer: The running trainer (unused here).
            pl_module: Module whose public attributes are scanned.
            checkpoint: Checkpoint dictionary, modified in place.

        Raises:
            RuntimeError: If an attribute name is already a key of
                ``checkpoint``; the existing entry is never overwritten.
        """
        print("\tChecking for non PyTorch modules to save... 🔧", flush=True)
        modules = _get_sklearn_modules(pl_module)
        for name, module in modules.items():
            if name in checkpoint:
                raise RuntimeError(
                    f"Can't pickle {name}, already present in checkpoint"
                )
            checkpoint[name] = module
            print(f"\t\tsaving non PyTorch system: {name} 🔧", flush=True)

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        """Set sklearn-bearing checkpoint entries back onto ``pl_module``.

        Uses the same predicate as the save path, so lists and dicts that hold
        estimators are restored as well as bare estimators.

        Args:
            trainer: The running trainer (unused here).
            pl_module: Module receiving the restored attributes.
            checkpoint: Loaded checkpoint dictionary (read only).
        """
        print("\tChecking for non PyTorch modules to load... 🔧", flush=True)
        for name, item in checkpoint.items():
            if not _contains_sklearn_module(item):
                continue
            try:
                setattr(pl_module, name, item)
            except AttributeError:
                # The name cannot be assigned on this module (e.g. a property
                # without a setter); skip it instead of aborting the load.
                logging.warning(
                    f"Could not restore non PyTorch system: {name} "
                    "(attribute is not settable)"
                )
                continue
            print(f"\t\tloading non PyTorch system: {name} 🔧", flush=True)
def _contains_sklearn_module(item):
    """Return whether ``item`` is, or holds, a sklearn regressor/classifier.

    Args:
        item: Any object. Lists are searched element by element and dicts
            value by value, recursively; other containers are not searched.

    Returns:
        bool: ``True`` if ``item`` is an instance of ``RegressorMixin`` or
        ``ClassifierMixin``, or if any element reachable through nested
        lists/dict values is; ``False`` otherwise.
    """
    if isinstance(item, (RegressorMixin, ClassifierMixin)):
        return True
    # Built-in ``any`` stops at the first hit and always yields a plain bool.
    if isinstance(item, list):
        return any(_contains_sklearn_module(m) for m in item)
    if isinstance(item, dict):
        return any(_contains_sklearn_module(m) for m in item.values())
    return False
def _get_sklearn_modules(module):
    """Collect the public attributes of ``module`` that hold sklearn models.

    Args:
        module: Object whose attributes (as listed by ``dir``) are inspected.

    Returns:
        dict: Maps attribute name to attribute value for every name that does
        not start with an underscore and for which
        ``_contains_sklearn_module`` is true.
    """
    modules = {}
    for name in dir(module):
        if name.startswith("_"):
            continue
        try:
            item = getattr(module, name)
        except (AttributeError, RuntimeError):
            # Some attributes are properties whose getter raises in the
            # current state; they cannot be sklearn holders we can save, so
            # skip them rather than abort the whole scan.
            continue
        if _contains_sklearn_module(item):
            modules[name] = item
    return modules