Source code for stable_pretraining.utils.lightning_patch

"""Monkey-patch for PyTorch Lightning to support manual optimization with Trainer parameters.

This patch modifies Lightning's validation to transfer gradient_clip_val and
accumulate_grad_batches to alternative attributes instead of raising errors.
"""

import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import lightning.pytorch as pl


[docs] def apply_manual_optimization_patch(): """Apply the monkey-patch to Lightning's manual optimization validation. This patch modifies the __verify_manual_optimization_support function to: 1. Transfer gradient_clip_val to gradient_clip_val_ 2. Transfer accumulate_grad_batches to accumulate_grad_batches_ 3. Clear the original values to avoid Lightning's error This allows users to use standard Trainer parameters even with manual optimization. """ try: # Try new import path first (lightning.pytorch) try: import lightning.pytorch.trainer.configuration_validator as validator_module except ImportError: # Fall back to old import path (pytorch_lightning) import pytorch_lightning.trainer.configuration_validator as validator_module # Store the original function for potential restoration original_verify = validator_module.__verify_manual_optimization_support def patched_verify_manual_optimization_support( trainer: "pl.Trainer", model: "pl.LightningModule" ) -> None: """Patched version that transfers parameters instead of raising errors.""" # Only process if manual optimization is enabled if not model.automatic_optimization: # Transfer gradient clipping parameters if ( trainer.gradient_clip_val is not None and trainer.gradient_clip_val > 0 ): # Save to alternative attributes trainer.gradient_clip_val_ = trainer.gradient_clip_val trainer.gradient_clip_algorithm_ = trainer.gradient_clip_algorithm # Clear originals to avoid Lightning's validation error trainer.gradient_clip_val = None logging.debug( f"Manual optimization patch: Transferred gradient_clip_val={trainer.gradient_clip_val_} " f"to trainer.gradient_clip_val_ (algorithm={trainer.gradient_clip_algorithm_})" ) # Transfer gradient accumulation parameters if trainer.accumulate_grad_batches != 1: # Save to alternative attribute trainer.accumulate_grad_batches_ = trainer.accumulate_grad_batches # Reset to 1 to avoid Lightning's validation error trainer.accumulate_grad_batches = 1 logging.debug( f"Manual optimization patch: Transferred accumulate_grad_batches={trainer.accumulate_grad_batches_} " f"to trainer.accumulate_grad_batches_" ) # No need to call the original since we handled the problematic cases # The original would only raise errors that we're avoiding return None # Apply the monkey-patch validator_module.__verify_manual_optimization_support = ( patched_verify_manual_optimization_support ) # Store reference to original for potential restoration validator_module.__original_verify_manual_optimization_support = original_verify logging.debug( "Successfully applied manual optimization parameter patch for PyTorch Lightning" ) except ImportError as e: logging.warning(f"Could not apply Lightning patch: {e}") except Exception as e: logging.warning(f"Error applying Lightning patch: {e}")
[docs] def restore_original_validation(): """Restore the original Lightning validation function (for testing/debugging).""" try: # Try new import path first try: import lightning.pytorch.trainer.configuration_validator as validator_module except ImportError: import pytorch_lightning.trainer.configuration_validator as validator_module if hasattr(validator_module, "__original_verify_manual_optimization_support"): validator_module.__verify_manual_optimization_support = ( validator_module.__original_verify_manual_optimization_support ) delattr(validator_module, "__original_verify_manual_optimization_support") logging.debug("Restored original Lightning validation function") else: logging.warning("No original validation function found to restore") except ImportError: pass