SklearnCheckpoint
- class stable_ssl.callbacks.SklearnCheckpoint
Bases: Callback
Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.
- on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint; use it to reload state.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the full checkpoint dictionary that was loaded by the Trainer.
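The load hook can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not stable_ssl's actual implementation: it assumes the save side stored each estimator as pickled bytes under a hypothetical `"sklearn_modules"` key, keyed by attribute name, and it uses a plain class instead of subclassing Lightning's `Callback` so it runs without Lightning or scikit-learn installed. `SklearnCheckpointSketch`, `StandInEstimator`, and `SKLEARN_KEY` are hypothetical names.

```python
import pickle
from types import SimpleNamespace

# Hypothetical checkpoint key; the key stable_ssl actually uses may differ.
SKLEARN_KEY = "sklearn_modules"


class SklearnCheckpointSketch:
    """Hypothetical stand-in; a real callback would subclass Lightning's Callback."""

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Assumed layout: {attribute name: pickled estimator bytes}.
        for name, blob in checkpoint.get(SKLEARN_KEY, {}).items():
            # Re-attach each unpickled estimator to the module under its saved name.
            setattr(pl_module, name, pickle.loads(blob))


class StandInEstimator:
    """Minimal stand-in for a fitted scikit-learn estimator."""

    def __init__(self, coef):
        self.coef_ = coef


# Usage: a hand-built checkpoint restored onto a fresh, empty module.
# The trainer argument is unused by this sketch, so None is passed.
checkpoint = {SKLEARN_KEY: {"probe": pickle.dumps(StandInEstimator([0.5, -1.0]))}}
module = SimpleNamespace()
SklearnCheckpointSketch().on_load_checkpoint(None, module, checkpoint)
print(module.probe.coef_)  # [0.5, -1.0]
```

Because the sketch falls back to an empty dict when the key is missing, checkpoints written without the callback load without error.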
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint, giving the callback a chance to store any additional state in the checkpoint dictionary.
- Parameters:
trainer – the current Trainer instance.
pl_module – the current LightningModule instance.
checkpoint – the checkpoint dictionary that will be saved.
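A possible shape for the save hook, as a minimal sketch that is not stable_ssl's actual implementation. It assumes a hypothetical `"sklearn_modules"` checkpoint key and treats any attribute exposing callable `fit` and `get_params` methods as an estimator; that duck-typing rule is an assumption, chosen because scikit-learn estimators inherit `get_params` from `BaseEstimator`. On a real `LightningModule`, plain Python attributes such as an estimator live in the instance `__dict__`, which is what `vars()` reads. `SklearnCheckpointSketch`, `StandInEstimator`, `looks_like_estimator`, and `SKLEARN_KEY` are hypothetical names, and the class does not subclass Lightning's `Callback` so the example runs without Lightning or scikit-learn.

```python
import pickle
from types import SimpleNamespace

# Hypothetical checkpoint key; the key stable_ssl actually uses may differ.
SKLEARN_KEY = "sklearn_modules"


def looks_like_estimator(obj):
    # Assumed duck-typing rule: scikit-learn estimators expose fit() and get_params().
    return callable(getattr(obj, "fit", None)) and callable(getattr(obj, "get_params", None))


class SklearnCheckpointSketch:
    """Hypothetical stand-in; a real callback would subclass Lightning's Callback."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Pickle every estimator-like attribute, keyed by its attribute name.
        stored = {
            name: pickle.dumps(value)
            for name, value in vars(pl_module).items()
            if looks_like_estimator(value)
        }
        if stored:
            # Mutate the checkpoint dict in place; the Trainer saves it afterwards.
            checkpoint[SKLEARN_KEY] = stored


class StandInEstimator:
    """Minimal stand-in for a scikit-learn estimator."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self, deep=True):
        return {"alpha": self.alpha}

    def fit(self, X, y=None):
        return self


# Usage: only the estimator-like attribute ends up in the checkpoint.
module = SimpleNamespace(probe=StandInEstimator(alpha=0.1), lr=1e-3)
checkpoint = {"state_dict": {}}
SklearnCheckpointSketch().on_save_checkpoint(None, module, checkpoint)
print(sorted(checkpoint[SKLEARN_KEY]))  # ['probe']
print(pickle.loads(checkpoint[SKLEARN_KEY]["probe"]).get_params())  # {'alpha': 0.1}
```

One trade-off of pickling estimators is that loading generally needs a compatible scikit-learn version installed.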