SklearnCheckpoint
- class stable_pretraining.callbacks.SklearnCheckpoint
Bases: Callback
Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.
This callback automatically detects sklearn models (Regressors and Classifiers) attached to the Lightning module and handles their serialization/deserialization during checkpoint save/load operations. This is necessary because sklearn models are not natively supported by PyTorch’s checkpoint system.
The callback will:
1. Automatically discover sklearn models attached to the Lightning module
2. Save them to the checkpoint dictionary during checkpoint saving
3. Restore them from the checkpoint during checkpoint loading
4. Log information about discovered sklearn modules during setup
Note
- Only attributes that are sklearn RegressorMixin or ClassifierMixin instances are saved.
- Private attributes (names starting with '_') are ignored.
- The callback raises an error if a sklearn model name conflicts with an existing checkpoint key.
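For example, a minimal usage sketch (assuming SklearnCheckpoint takes no constructor arguments and that Lightning is imported as pytorch_lightning; MyModule and linear_probe are hypothetical names):

```python
import pytorch_lightning as pl
from sklearn.linear_model import LogisticRegression
from stable_pretraining.callbacks import SklearnCheckpoint


class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Public sklearn attribute: discovered and checkpointed by the callback.
        self.linear_probe = LogisticRegression(max_iter=1000)
        # Private attribute (leading underscore): ignored by the callback.
        self._scratch_probe = LogisticRegression()


trainer = pl.Trainer(callbacks=[SklearnCheckpoint()])
```

Because linear_probe is a public ClassifierMixin attribute, it is saved into and restored from the same checkpoint as the PyTorch weights.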
- on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint; use this hook to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that got loaded by the Trainer.
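For illustration only, a sketch of the general technique this hook performs, not the library's actual code (it assumes the estimators were stored under their attribute names, as described above):

```python
from sklearn.base import ClassifierMixin, RegressorMixin


def on_load_checkpoint(trainer, pl_module, checkpoint):
    # Reattach any sklearn estimators found in the checkpoint to the module.
    for name, value in checkpoint.items():
        if isinstance(value, (RegressorMixin, ClassifierMixin)):
            setattr(pl_module, name, value)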
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
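For illustration only, a sketch of the general discovery-and-save technique, not the library's exact implementation:

```python
from sklearn.base import ClassifierMixin, RegressorMixin


def on_save_checkpoint(trainer, pl_module, checkpoint):
    # Scan public attributes of the module for sklearn estimators.
    for name, value in vars(pl_module).items():
        if name.startswith("_"):
            continue  # private attributes are ignored
        if isinstance(value, (RegressorMixin, ClassifierMixin)):
            if name in checkpoint:
                raise ValueError(f"Checkpoint already contains key '{name}'")
            # Stored as-is; the estimator is pickled along with the checkpoint.
            checkpoint[name] = value
```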