SklearnCheckpoint

class stable_ssl.callbacks.SklearnCheckpoint

Bases: Callback

Callback for saving and loading sklearn models in PyTorch Lightning checkpoints.

on_load_checkpoint(trainer, pl_module, checkpoint)

Called when loading a model checkpoint; use it to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that was loaded by the Trainer.
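The following is a minimal sketch of what restoring estimators from a checkpoint could look like. It is not the stable_ssl source: it assumes the estimators were pickled into `checkpoint["sklearn_models"]` as a mapping from attribute name to bytes, and both `SKLEARN_KEY` and `load_sklearn_models` are hypothetical names chosen for illustration.

```python
import pickle
from typing import Any, Dict, List

# Hypothetical checkpoint key (an assumption); it must match the key used when saving.
SKLEARN_KEY = "sklearn_models"


def load_sklearn_models(pl_module: Any, checkpoint: Dict[str, Any]) -> List[str]:
    # Unpickle each stored estimator and reattach it to the module under its
    # original attribute name. pickle.loads can execute arbitrary code, so
    # only load checkpoints that come from a trusted source.
    restored = []
    for name, blob in checkpoint.get(SKLEARN_KEY, {}).items():
        setattr(pl_module, name, pickle.loads(blob))
        restored.append(name)
    return restored
```

A checkpoint without the key is treated as containing no estimators, so older checkpoints still load without error.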

on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint, giving you a chance to store any additional state you want to persist.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.
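A minimal sketch of the saving side, under stated assumptions and not the actual stable_ssl implementation: estimators are found by a duck-typed check (callable `get_params` and `fit`, which scikit-learn estimators provide), pickled to bytes, and stored under a hypothetical `"sklearn_models"` key. `SKLEARN_KEY`, `looks_like_sklearn_estimator`, and `save_sklearn_models` are illustrative names only.

```python
import pickle
from typing import Any, Dict

# Hypothetical checkpoint key (an assumption); the real callback may use another layout.
SKLEARN_KEY = "sklearn_models"


def looks_like_sklearn_estimator(obj: Any) -> bool:
    # Duck-typed check (an assumption): scikit-learn estimators expose
    # callable get_params() and fit() methods.
    return callable(getattr(obj, "get_params", None)) and callable(
        getattr(obj, "fit", None)
    )


def save_sklearn_models(pl_module: Any, checkpoint: Dict[str, Any]) -> Dict[str, bytes]:
    # Serialize every estimator-like instance attribute of the module and
    # store the name -> bytes mapping inside the checkpoint dictionary,
    # leaving the other checkpoint entries (e.g. "state_dict") untouched.
    blobs = {
        name: pickle.dumps(value)
        for name, value in vars(pl_module).items()
        if looks_like_sklearn_estimator(value)
    }
    checkpoint[SKLEARN_KEY] = blobs
    return blobs
```

Storing bytes rather than the estimator objects keeps the checkpoint entry a plain name-to-bytes mapping; either way, restoring it means unpickling, so the usual pickle security caveats apply.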

setup(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None

Called when fit, validate, test, predict, or tune begins.
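One plausible use of `setup` in a callback like this is to discover which estimators the module holds before any checkpoint is written. The sketch below is an assumption about that behavior, not the stable_ssl source. `SklearnModelTracker` is a hypothetical class that mirrors the hook signature without subclassing Lightning's `Callback`, and the duck-typed detection rule is likewise an assumption.

```python
from typing import Any, List, Optional


class SklearnModelTracker:
    """Hypothetical stand-in for the callback's setup logic (no Lightning dependency)."""

    def __init__(self) -> None:
        self.model_names: List[str] = []

    def setup(self, trainer: Any, pl_module: Any, stage: Optional[str] = None) -> None:
        # setup runs at the start of each stage, so recompute the list every
        # time: keep instance attributes that look like scikit-learn
        # estimators (callable get_params and fit), sorted by name.
        self.model_names = sorted(
            name
            for name, value in vars(pl_module).items()
            if callable(getattr(value, "get_params", None))
            and callable(getattr(value, "fit", None))
        )
```

Recomputing the list on every call keeps it accurate if estimators are attached or replaced between stages.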