Source code for stable_pretraining.callbacks.embedding_cache

import lightning as pl
from loguru import logger as logging


class EmbeddingCache(pl.pytorch.Callback):
    """Cache embeddings from modules given their names.

    Args:
        module_names (list of str): List of module names to hook
            (e.g., ['layer1', 'encoder.block1']).
        add_to_forward_output (bool): If True, merge the cached outputs into
            the dict returned by forward.
    """

    def __init__(self, module_names: list, add_to_forward_output: bool = True):
        super().__init__()
        logging.info("Init of EmbeddingCache callback with")
        logging.info(f"\t - {len(module_names)} module names")
        logging.info(f"\t - add_to_forward_output={add_to_forward_output}")
        self.module_names = module_names
        self.add_to_forward_output = add_to_forward_output
        self.hooks = []

    def setup(self, trainer, pl_module, stage=None):
        logging.info("Setup of EmbeddingCache")
        if hasattr(pl_module, "embedding_cache"):
            raise RuntimeError("An embedding_cache is already present")
        pl_module.embedding_cache = {}
        for name in self.module_names:
            module = self._get_module_by_name(pl_module, name)
            if module is None:
                raise ValueError(f"Module '{name}' not found in LightningModule.")
            hook = module.register_forward_hook(self._make_hook(name, pl_module))
            self.hooks.append(hook)
        logging.info("\t - adding forward hook")
        # Track this hook as well so that teardown removes it.
        hook = pl_module.register_forward_hook(self.forward_hook_fn)
        self.hooks.append(hook)

    def teardown(self, trainer, pl_module, stage=None):
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        if hasattr(pl_module, "embedding_cache"):
            del pl_module.embedding_cache
        if hasattr(pl_module, "_addembedding_cache_to_forward"):
            del pl_module._addembedding_cache_to_forward

    def on_train_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx=0
    ):
        pl_module.embedding_cache.clear()

    def on_validation_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx=0
    ):
        pl_module.embedding_cache.clear()

    def on_test_batch_start(
        self, trainer, pl_module, batch, batch_idx, dataloader_idx=0
    ):
        pl_module.embedding_cache.clear()

    def _make_hook(self, name, pl_module):
        def hook(module, input, output):
            pl_module.embedding_cache[name] = output

        return hook

    def _get_module_by_name(self, pl_module, name):
        module = pl_module
        for attr in name.split("."):
            if not hasattr(module, attr):
                return None
            module = getattr(module, attr)
        return module

    def forward_hook_fn(self, pl_module, args, outputs):
        """Merge the cached embeddings into the output of the forward pass."""
        if not self.add_to_forward_output:
            return outputs
        outputs.update(pl_module.embedding_cache)
        return outputs
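
The following is a minimal usage sketch, not part of the module source: it assumes a
LightningModule whose forward returns a dict (required when add_to_forward_output=True)
and a submodule reachable as "encoder.0"; the ToyModule class and both names are
illustrative only.

import torch
import lightning as pl
from stable_pretraining.callbacks.embedding_cache import EmbeddingCache


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # "encoder.0" below refers to the first Linear layer of this Sequential.
        self.encoder = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
        self.head = torch.nn.Linear(16, 2)

    def forward(self, x):
        # Returning a dict lets the callback merge cached embeddings into it.
        return {"logits": self.head(self.encoder(x))}


callback = EmbeddingCache(module_names=["encoder.0"], add_to_forward_output=True)
trainer = pl.Trainer(callbacks=[callback], max_epochs=1)
# Once the callback's setup() has run (e.g., at the start of trainer.fit),
# calling the module's forward returns {"logits": ..., "encoder.0": ...}.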