Source code for stable_ssl.callbacks.trainer_info
import lightning.pytorch as pl
import torch
from lightning.pytorch import Callback
from loguru import logger as logging
from prettytable import PrettyTable
from lightning.pytorch.utilities import rank_zero_only
from ..data.module import DataModule
class ModuleSummary(pl.Callback):
    """Callback for logging module summaries in a formatted table."""
    @rank_zero_only
    def setup(self, trainer, pl_module, stage):
        headers = [
            "Module",
            "Trainable parameters",
            "Non-trainable parameters",
            "Uninitialized parameters",
            "Buffers",
        ]
        table = PrettyTable()
        table.field_names = headers
        table.align["Module"] = "l"
        table.align["Trainable parameters"] = "r"
        table.align["Non-trainable parameters"] = "r"
        table.align["Uninitialized parameters"] = "r"
        table.align["Buffers"] = "r"
        logging.info("PyTorch Modules:")
        for name, module in pl_module.named_modules():
            num_trainable = 0
            num_nontrainable = 0
            num_buffer = 0
            num_uninitialized = 0
            for p in module.parameters():
                # Lazy parameters have no shape yet, so they are tallied by
                # tensor count rather than by number of elements.
                if isinstance(p, torch.nn.parameter.UninitializedParameter):
                    n = 0
                    num_uninitialized += 1
                else:
                    n = p.numel()
                if p.requires_grad:
                    num_trainable += n
                else:
                    num_nontrainable += n
            for p in module.buffers():
                # Same handling for lazily created buffers.
                if isinstance(p, torch.nn.parameter.UninitializedBuffer):
                    n = 0
                    num_uninitialized += 1
                else:
                    n = p.numel()
                num_buffer += n
            table.add_row(
                [name, num_trainable, num_nontrainable, num_uninitialized, num_buffer]
            )
        print(table)
        return super().setup(trainer, pl_module, stage)
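
# Usage sketch for ModuleSummary (a minimal example; the LightningModule
# subclass `MyModel` is hypothetical). Attaching the callback makes the
# per-module parameter table print once, on rank zero, during setup:
#
#     import lightning.pytorch as pl
#     trainer = pl.Trainer(callbacks=[ModuleSummary()], max_epochs=1)
#     trainer.fit(MyModel())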

class LoggingCallback(pl.Callback):
    """Callback for logging validation metrics in a formatted table."""
    @rank_zero_only
    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        metrics = trainer.callback_metrics
        table = PrettyTable()
        table.field_names = ["Metric", "Value"]
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                # ANSI escape codes render the metric name in blue and its
                # value in green on terminals that support them.
                table.add_row(
                    [
                        "\033[0;34;40m" + key + "\033[0m",
                        "\033[0;32;40m" + str(metrics[key].item()) + "\033[0m",
                    ]
                )
        print(table)
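
# Usage sketch for LoggingCallback: any metric recorded with `self.log(...)`
# during validation lands in `trainer.callback_metrics` and is printed here
# after each validation run. The metric name `val/loss` and the helper
# `compute_loss` below are hypothetical:
#
#     class MyModel(pl.LightningModule):
#         def validation_step(self, batch, batch_idx):
#             loss = self.compute_loss(batch)
#             self.log("val/loss", loss)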

class TrainerInfo(Callback):
    """Callback for linking the trainer to the DataModule and providing extra information."""
    def setup(self, trainer, pl_module, stage):
        logging.info("\t linking trainer to DataModule! 🔧")
        if not isinstance(trainer.datamodule, DataModule):
            logging.warning(
                "Using a custom DataModule; extra trainer info won't be available!"
            )
            return
        trainer.datamodule.set_pl_trainer(trainer)
        return super().setup(trainer, pl_module, stage)
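
# Usage sketch for TrainerInfo (constructor arguments of stable_ssl's
# DataModule are elided since they aren't shown in this file; `MyModel`
# is hypothetical):
#
#     datamodule = DataModule(...)
#     trainer = pl.Trainer(callbacks=[TrainerInfo()])
#     trainer.fit(MyModel(), datamodule=datamodule)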