SklearnCheckpoint#

class stable_pretraining.callbacks.SklearnCheckpoint[source]#

Bases: Callback

Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.

This callback automatically detects sklearn models (Regressors and Classifiers) attached to the Lightning module and handles their serialization/deserialization during checkpoint save/load operations. This is necessary because sklearn models are not natively supported by PyTorch’s checkpoint system.

The callback will:

  1. Automatically discover sklearn models attached to the Lightning module

  2. Save them to the checkpoint dictionary during checkpoint saving

  3. Restore them from the checkpoint during checkpoint loading

  4. Log information about discovered sklearn modules during setup

Note

  • Only attributes that are sklearn RegressorMixin or ClassifierMixin instances are saved

  • Private attributes (starting with ‘_’) are ignored

  • The callback will raise an error if a sklearn model name conflicts with existing checkpoint keys
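The discovery rules above can be sketched as follows. This is an illustrative assumption about the behavior, not the actual stable_pretraining implementation; `DummyModule` and `discover_sklearn_models` are hypothetical names:

```python
# Hypothetical sketch of the discovery rules -- not the actual
# stable_pretraining implementation.
from sklearn.base import ClassifierMixin, RegressorMixin
from sklearn.linear_model import LogisticRegression

class DummyModule:
    """Stand-in for a LightningModule with sklearn models attached."""
    def __init__(self):
        self.probe = LogisticRegression()     # sklearn classifier -> saved
        self._scratch = LogisticRegression()  # private attribute -> ignored
        self.lr = 0.01                        # not a sklearn estimator -> ignored

def discover_sklearn_models(module):
    """Return public attributes that are sklearn regressors or classifiers."""
    return {
        name: attr
        for name, attr in vars(module).items()
        if not name.startswith("_")
        and isinstance(attr, (RegressorMixin, ClassifierMixin))
    }

module = DummyModule()
checkpoint = {"state_dict": {}}
for name, model in discover_sklearn_models(module).items():
    if name in checkpoint:  # conflict with an existing checkpoint key
        raise ValueError(f"sklearn model name {name!r} conflicts with checkpoint keys")
    checkpoint[name] = model

print(sorted(checkpoint))  # 'probe' is saved; '_scratch' and 'lr' are not
```

The `isinstance` check against the two sklearn mixins is what makes the discovery automatic: any public estimator attribute is picked up without registration.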

on_load_checkpoint(trainer, pl_module, checkpoint)[source]#

Called when loading a model checkpoint; use this hook to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.

on_save_checkpoint(trainer, pl_module, checkpoint)[source]#

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.
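Taken together, the two hooks implement a simple round trip: on save, each discovered estimator is placed into the checkpoint dictionary, and on load it is reattached to the module. A minimal sketch under the assumption that the model is stored under its attribute name (`Module` and all names here are illustrative, not the library's API):

```python
# Hypothetical round-trip sketch of the save/load hooks -- not the actual
# stable_pretraining implementation.
import numpy as np
from sklearn.linear_model import LinearRegression

class Module:
    """Stand-in for a LightningModule (assumption)."""

# Fit a small regressor attached to the module.
X = np.array([[0.0], [1.0], [2.0]])
y = np.array([0.0, 1.0, 2.0])
module = Module()
module.head = LinearRegression().fit(X, y)

# on_save_checkpoint: store the estimator under its attribute name.
checkpoint = {"state_dict": {}}
checkpoint["head"] = module.head

# on_load_checkpoint: restore it onto a fresh module instance.
restored = Module()
restored.head = checkpoint["head"]

print(restored.head.predict(np.array([[3.0]])))  # same fitted model as before
```

Because Lightning checkpoints are serialized with pickle-compatible machinery, storing the fitted estimator object directly in the checkpoint dictionary is sufficient to persist its learned parameters.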

setup(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None[source]#

Called when fit, validate, test, predict, or tune begins.