Source code for stable_pretraining.utils.inspection_utils

"""Function inspection and general helper utilities."""

import inspect
from typing import Any, List


[docs] def get_required_fn_parameters(fn): """Get the list of required parameters for a function. Args: fn: The function to inspect Returns: List of parameter names that don't have default values """ sig = inspect.signature(fn) required = [] for name, param in sig.parameters.items(): if param.default is inspect.Parameter.empty: required.append(name) return required
[docs] def dict_values(**kwargs): """Convert keyword arguments to a list of values. Returns: List of values from the provided keyword arguments """ return list(kwargs.values())
[docs] def broadcast_param_to_list( param: Any, target_length: int, param_name: str ) -> List[Any]: """Broadcast a parameter value to create a list of specified length. This function handles the common pattern of accepting either: - None: creates a list of None values - A single value: broadcasts to all positions - A single-element list/tuple: broadcasts the element to all positions - A list/tuple of correct length: returns as-is Args: param: The parameter to broadcast (can be None, single value, or list/tuple) target_length: The desired length of the output list param_name: Name of the parameter for error messages Returns: List of values with length matching target_length Raises: ValueError: If param is a list/tuple with length > 1 that doesn't match target_length Examples: >>> broadcast_param_to_list(None, 3, "dims") [None, None, None] >>> broadcast_param_to_list(5, 3, "dims") [5, 5, 5] >>> broadcast_param_to_list([5], 3, "dims") [5, 5, 5] >>> broadcast_param_to_list([1, 2, 3], 3, "dims") [1, 2, 3] """ if param is None: return [None] * target_length if not isinstance(param, (list, tuple)): # Single value provided for all elements return [param] * target_length if len(param) == 1 and target_length > 1: # Single value in list, broadcast to all elements return list(param) * target_length if len(param) != target_length: raise ValueError( f"Length of {param_name} ({len(param)}) must match target length ({target_length})" ) return list(param)