Source code for stable_worldmodel.spaces

"""Extended Gymnasium spaces with state tracking and constraint support."""

import time

from gymnasium import spaces
from loguru import logger as logging

import stable_worldmodel as swm


[docs] class Discrete(spaces.Discrete): """Extended discrete space with state tracking and constraint support. This class extends ``gymnasium.spaces.Discrete`` to add state management and optional constraint validation. Unlike the standard discrete space, this version maintains a current value and supports rejection sampling via a custom constraint function. Attributes: init_value (int): The initial value for the space. value (int): The current value of the space. constrain_fn (callable): Optional function that returns True if a value satisfies custom constraints. Example: Create a discrete space that only accepts even numbers:: space = Discrete(n=10, init_value=0, constrain_fn=lambda x: x % 2 == 0) value = space.sample() # Samples even number and updates space.value space.reset() # Resets space.value back to 0 (init_value) Note: The ``sample()`` method uses rejection sampling when a constraint function is provided, which may impact performance for difficult constraints. """
[docs] def __init__(self, *args, init_value=None, constrain_fn=None, **kwargs): """Initialize a Discrete space with state tracking. Args: *args: Positional arguments passed to gymnasium.spaces.Discrete. init_value (int, optional): Initial value for the space. Defaults to None. constrain_fn (callable, optional): Function that takes an int and returns True if the value satisfies custom constraints. Defaults to None. **kwargs: Keyword arguments passed to gymnasium.spaces.Discrete. """ super().__init__(*args, **kwargs) self._init_value = init_value self.constrain_fn = constrain_fn or (lambda x: True) self._value = init_value
@property def init_value(self): """int: The initial value of the space, returned by reset().""" return self._init_value @property def value(self): """int: The current value of the space.""" return self._value
[docs] def reset(self): """Reset the space value to its initial value. Sets the current value back to the init_value specified during initialization. """ self._value = self.init_value
[docs] def contains(self, x): """Check if value is valid and satisfies constraints. Args: x (int): The value to check. Returns: bool: True if x is within bounds and satisfies the constraint function, False otherwise. """ return super().contains(x) and self.constrain_fn(x)
[docs] def check(self): """Validate the current space value. Checks if the current value is within the space bounds and satisfies the constraint function. Logs a warning if the constraint fails. Returns: bool: True if the current value is valid, False otherwise. """ if not self.constrain_fn(self.value): logging.warning(f"Discrete: value {self.value} does not satisfy constrain_fn") return False return super().contains(self.value)
[docs] def sample(self, *args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs): """Sample a random value using rejection sampling for constraints. Repeatedly samples values until one satisfies the constraint function or max_tries is reached. Optionally updates the space's current value. Args: *args: Positional arguments passed to gymnasium.spaces.Discrete.sample(). max_tries (int, optional): Maximum number of sampling attempts before raising an error. Defaults to 1000. warn_after_s (float, optional): Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0. set_value (bool, optional): Whether to update the space's current value with the sampled value. Defaults to True. **kwargs: Keyword arguments passed to gymnasium.spaces.Discrete.sample(). Returns: int: A sampled value that satisfies the constraint function. Raises: RuntimeError: If no valid sample is found after max_tries attempts. """ start = time.time() for i in range(max_tries): sample = super().sample(*args, **kwargs) if self.contains(sample): if set_value: self._value = sample return sample if warn_after_s is not None and (time.time() - start) > warn_after_s: logging.warning("rejection sampling: rejection sampling is taking a while...") raise RuntimeError(f"rejection sampling: predicate not satisfied after {max_tries} draws")
[docs] class MultiDiscrete(spaces.MultiDiscrete): """Extended multi-discrete space with state tracking and constraint support. This class extends ``gymnasium.spaces.MultiDiscrete`` to add state management and optional constraint validation. It represents multiple discrete variables with potentially different ranges (nvec), where each variable maintains its own value and can be constrained. Attributes: init_value (np.ndarray): The initial values for all discrete variables. value (np.ndarray): The current values of all discrete variables. constrain_fn (callable): Optional function that returns True if the entire value array satisfies custom constraints. Example: Create a multi-discrete space for game difficulty settings:: import numpy as np space = MultiDiscrete( nvec=[5, 3, 10], # [enemy_count, speed_level, spawn_rate] init_value=np.array([2, 1, 5]), ) settings = space.sample() # Random difficulty configuration space.reset() # Resets to [2, 1, 5] (medium difficulty) Note: Constraints are applied to the entire array, not individual elements. Use a constraint function that validates the complete state. """
[docs] def __init__(self, *args, init_value=None, constrain_fn=None, **kwargs): """Initialize a MultiDiscrete space with state tracking. Args: *args: Positional arguments passed to gymnasium.spaces.MultiDiscrete. init_value (np.ndarray, optional): Initial values for the space. Must match the shape defined by nvec. Defaults to None. constrain_fn (callable, optional): Function that takes a numpy array and returns True if the values satisfy custom constraints. Defaults to None. **kwargs: Keyword arguments passed to gymnasium.spaces.MultiDiscrete. """ super().__init__(*args, **kwargs) self._init_value = init_value self.constrain_fn = constrain_fn or (lambda x: True) self._value = init_value
@property def init_value(self): """np.ndarray: The initial values of the space, returned by reset().""" return self._init_value @property def value(self): """np.ndarray: The current values of the space.""" return self._value
[docs] def reset(self): """Reset the space values to their initial values. Sets the current values back to the init_value specified during initialization. """ self._value = self.init_value
[docs] def contains(self, x): """Check if values are valid and satisfy constraints. Args: x (np.ndarray): The array of values to check. Returns: bool: True if x is within bounds for all elements and satisfies the constraint function, False otherwise. """ return super().contains(x) and self.constrain_fn(x)
[docs] def check(self): """Validate the current space values. Checks if the current values are within the space bounds and satisfy the constraint function. Logs a warning if the constraint fails. Returns: bool: True if the current values are valid, False otherwise. """ if not self.constrain_fn(self.value): logging.warning(f"MultiDiscrete: value {self.value} does not satisfy constrain_fn") return False return super().contains(self.value)
[docs] def sample(self, *args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs): """Sample random values using rejection sampling for constraints. Repeatedly samples value arrays until one satisfies the constraint function or max_tries is reached. Optionally updates the space's current values. Args: *args: Positional arguments passed to gymnasium.spaces.MultiDiscrete.sample(). max_tries (int, optional): Maximum number of sampling attempts before raising an error. Defaults to 1000. warn_after_s (float, optional): Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0. set_value (bool, optional): Whether to update the space's current values with the sampled values. Defaults to True. **kwargs: Keyword arguments passed to gymnasium.spaces.MultiDiscrete.sample(). Returns: np.ndarray: A sampled array that satisfies the constraint function. Raises: RuntimeError: If no valid sample is found after max_tries attempts. """ start = time.time() for i in range(max_tries): sample = super().sample(*args, **kwargs) if self.contains(sample): if set_value: self._value = sample return sample if warn_after_s is not None and (time.time() - start) > warn_after_s: logging.warning("rejection sampling: rejection sampling is taking a while...") raise RuntimeError(f"rejection sampling: predicate not satisfied after {max_tries} draws")
[docs] class Box(spaces.Box): """Extended continuous box space with state tracking and constraint support. This class extends ``gymnasium.spaces.Box`` to add state management and optional constraint validation. It represents bounded continuous values with configurable shape, dtype, and custom constraints. Attributes: init_value (np.ndarray): The initial value for the space. value (np.ndarray): The current value of the space. constrain_fn (callable): Optional function that returns True if a value satisfies custom constraints beyond the box boundaries. Example: Create a 2D position space constrained to a circle:: import numpy as np def in_circle(pos): return np.linalg.norm(pos) <= 1.0 space = Box( low=np.array([-1.0, -1.0]), high=np.array([1.0, 1.0]), init_value=np.array([0.0, 0.0]), constrain_fn=in_circle, ) position = space.sample() # Only samples within unit circle Note: The constraint function enables complex geometric or relational constraints beyond simple box boundaries. """
[docs] def __init__(self, *args, init_value=None, constrain_fn=None, **kwargs): """Initialize a Box space with state tracking. Args: *args: Positional arguments passed to gymnasium.spaces.Box. init_value (np.ndarray, optional): Initial value for the space. Must match the shape and dtype of the box. Defaults to None. constrain_fn (callable, optional): Function that takes a numpy array and returns True if the value satisfies custom constraints beyond the box boundaries. Defaults to None. **kwargs: Keyword arguments passed to gymnasium.spaces.Box. """ super().__init__(*args, **kwargs) self.constrain_fn = constrain_fn or (lambda x: True) self._init_value = init_value self._value = init_value
@property def init_value(self): """np.ndarray: The initial value of the space, returned by reset().""" return self._init_value @property def value(self): """np.ndarray: The current value of the space.""" return self._value
[docs] def reset(self): """Reset the space value to its initial value. Sets the current value back to the init_value specified during initialization. """ self._value = self.init_value
[docs] def contains(self, x): """Check if value is valid and satisfies constraints. Args: x (np.ndarray): The value to check. Returns: bool: True if x is within box bounds and satisfies the constraint function, False otherwise. """ return super().contains(x) and self.constrain_fn(x)
[docs] def check(self): """Validate the current space value. Checks if the current value is within the box bounds and satisfies the constraint function. Logs a warning if the constraint fails. Returns: bool: True if the current value is valid, False otherwise. """ if not self.constrain_fn(self.value): logging.warning(f"Box: value {self.value} does not satisfy constrain_fn") return False return self.contains(self.value)
[docs] def sample(self, *args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs): """Sample a random value using rejection sampling for constraints. Repeatedly samples values until one satisfies the constraint function or max_tries is reached. Optionally updates the space's current value. Args: *args: Positional arguments passed to gymnasium.spaces.Box.sample(). max_tries (int, optional): Maximum number of sampling attempts before raising an error. Defaults to 1000. warn_after_s (float, optional): Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0. set_value (bool, optional): Whether to update the space's current value with the sampled value. Defaults to True. **kwargs: Keyword arguments passed to gymnasium.spaces.Box.sample(). Returns: np.ndarray: A sampled array that satisfies the constraint function. Raises: RuntimeError: If no valid sample is found after max_tries attempts. """ start = time.time() for i in range(max_tries): sample = super().sample(*args, **kwargs) if self.contains(sample): if set_value: self._value = sample return sample if warn_after_s is not None and (time.time() - start) > warn_after_s: logging.warning("rejection sampling: rejection sampling is taking a while...") raise RuntimeError(f"rejection sampling: predicate not satisfied after {max_tries} draws")
[docs] class RGBBox(Box): """Specialized box space for RGB image data with automatic constraints. This class extends ``Box`` to provide a convenient space for RGB images, automatically enforcing uint8 dtype and [0, 255] value ranges. It validates that the shape includes exactly 3 channels for RGB data. Args: shape (tuple): Shape of the image. Must include a dimension of size 3 for the RGB channels. Common formats: (H, W, 3) or (3, H, W). init_value (np.ndarray, optional): Initial RGB image. Must match shape and be uint8 dtype. *args: Additional positional arguments passed to Box. **kwargs: Additional keyword arguments passed to Box. Attributes: init_value (np.ndarray): The initial RGB image. value (np.ndarray): The current RGB image. Example: Create a space for 64x64 RGB images:: import numpy as np space = RGBBox( shape=(64, 64, 3), init_value=np.zeros((64, 64, 3), dtype=np.uint8) ) image = space.sample() # Random RGB image space.reset() # Returns to black image Raises: AssertionError: If shape does not contain a dimension of size 3. Note: This space is useful for vision-based environments where images need to be sampled or tracked as part of environment configuration. The low, high, and dtype parameters are automatically set and cannot be overridden. """ def __init__(self, shape=(3,), *args, init_value=None, **kwargs): if not any(dim == 3 for dim in shape): raise ValueError("shape must have a channel of size 3") super().__init__( low=0, high=255, shape=shape, dtype="uint8", init_value=init_value, *args, **kwargs, )
[docs] class Dict(spaces.Dict): """Extended dictionary space with ordered sampling and nested support. This class extends ``gymnasium.spaces.Dict`` to add state management, constraint validation, and explicit sampling order control. It composes multiple spaces into a hierarchical structure where dependencies between variables can be handled through ordered sampling. Args: *args: Positional arguments passed to gymnasium.spaces.Dict. init_value (dict, optional): Initial values for the space. If None, derived from init_value of contained spaces. constrain_fn (callable, optional): Function that returns True if the complete dictionary satisfies custom constraints. sampling_order (list, optional): Explicit order for sampling keys. If None, uses insertion order. Missing keys are appended. **kwargs: Additional keyword arguments passed to Dict. Attributes: init_value (dict): Initial values for all contained spaces. value (dict): Current values of all contained spaces. constrain_fn (callable): Constraint validation function. sampling_order (set): Set of dotted paths for all variables in order. Example: Create a nested space with sampling order dependencies:: from stable_worldmodel import spaces import numpy as np config = spaces.Dict( { "difficulty": spaces.Discrete(n=3, init_value=0), "world": spaces.Dict( { "width": spaces.Discrete(n=100, init_value=50), "height": spaces.Discrete(n=100, init_value=50), } ), "player_pos": spaces.Box( low=np.array([0, 0]), high=np.array([99, 99]), init_value=np.array([25, 25]), ), }, sampling_order=["difficulty", "world", "player_pos"], ) # Sample respects order state = config.sample() Note: Sampling order is crucial when variables have dependencies. For example, sample world size before sampling positions within it. Nested Dict spaces recursively apply their own sampling orders. **Accessing values in constraint functions**: When implementing ``constrain_fn`` for Dict spaces, always use ``self.value['key']['key2']`` instead of ``self['key']['key2'].value``. The ``.value`` property recursively builds the complete value dictionary from the top level down, ensuring all nested values are up-to-date and correctly structured. Direct subspace access with ``.value`` only retrieves that specific subspace's value without the full context. Note that direct subspace access (e.g., ``self['key'].value``) is perfectly fine for regular operations outside of constraint functions, such as reading individual subspace values or debugging. The recommendation to use top-level ``.value`` applies specifically to constraint functions where you need the complete, consistent state of all nested spaces. Example of proper constraint function usage:: # Example: In a class with Dict space attribute class Environment: def __init__(self): self.config_space = spaces.Dict({...}) def validate_config(self): # ✓ CORRECT: Access via .value at top level values = self.config_space.value return values["player_pos"][0] < values["world"]["width"] def validate_wrong(self): # ✗ AVOID: Direct subspace access return ( self.config_space["player_pos"].value[0] < self.config_space["world"]["width"].value ) """
[docs] def __init__(self, *args, init_value=None, constrain_fn=None, sampling_order=None, **kwargs): """Initialize a Dict space with state tracking and sampling order. Args: *args: Positional arguments passed to gymnasium.spaces.Dict. init_value (dict, optional): Initial values for the space. If None, derived from init_value of contained spaces. Defaults to None. constrain_fn (callable, optional): Function that takes a dict and returns True if the complete dictionary satisfies custom constraints. Defaults to None. sampling_order (list, optional): Explicit order for sampling keys. If None, uses insertion order. Missing keys are appended with warning. Defaults to None. **kwargs: Keyword arguments passed to gymnasium.spaces.Dict. Raises: ValueError: If sampling_order contains keys not present in spaces. """ super().__init__(*args, **kwargs) self.constrain_fn = constrain_fn or (lambda x: True) self._init_value = init_value self._value = self.init_value # add missing keys if sampling_order is None: self._sampling_order = list(self.spaces.keys()) elif len(sampling_order) != len(self.spaces): missing_keys = set(self.spaces.keys()).difference(set(sampling_order)) logging.warning( f"Dict sampling_order is missing keys {missing_keys}, adding them at the end of the sampling order" ) self._sampling_order = list(sampling_order) + list(missing_keys) else: self._sampling_order = sampling_order if not all(key in self.spaces for key in self._sampling_order): missing = set(self._sampling_order) - set(self.spaces.keys()) raise ValueError(f"sampling_order contains keys not in spaces: {missing}")
@property def init_value(self): """dict: Initial values for all contained spaces. Constructs initial value dictionary from contained spaces' init_value properties. Falls back to sampling if a space lacks init_value. Returns: dict: Dictionary mapping space keys to their initial values. """ init_val = {} for k, v in self.spaces.items(): if hasattr(v, "init_value"): init_val[k] = v.init_value else: logging.warning( f"Space {k} of type {type(v)} does not have init_value property, using default sample instead" ) init_val[k] = v.sample() return init_val @property def value(self): """dict: Current values of all contained spaces. Constructs value dictionary from contained spaces' value properties. Returns: dict: Dictionary mapping space keys to their current values. Raises: ValueError: If a contained space does not have a value property. """ val = {} for k, v in self.spaces.items(): if hasattr(v, "value"): val[k] = v.value else: raise ValueError(f"Space {k} of type {type(v)} does not have value property") return val def _get_sampling_order(self, parts=None): """Yield dotted paths for nested Dict space respecting sampling order. Recursively generates dotted-path strings for all variables in this Dict space and any nested Dict spaces, honoring the explicit sampling order when available. Args: parts (tuple, optional): Parent path components for recursion. Defaults to empty tuple. Yields: str: Dotted path strings like 'parent.child.key' for each variable. """ if parts is None: parts = () # Prefer an explicit sampling order; otherwise preserve insertion order. keys = getattr(self, "_sampling_order", None) or self.spaces.keys() for key in keys: # Skip if the key isn't in the mapping (defensive against stale order lists). if key not in self.spaces: continue key_str = str(key) # ensure joinable path = parts + (key_str,) yield ".".join(path) subspace = self.spaces[key] if isinstance(subspace, spaces.Dict): # Recurse into nested Dict spaces yield from subspace._get_sampling_order(path) @property def sampling_order(self): """set: Set of dotted paths for all variables in sampling order. Returns: set: Set of strings representing dotted paths (e.g., 'parent.child.key') for all variables including nested Dict spaces. """ return list(self._get_sampling_order())
[docs] def reset(self): """Reset all contained spaces to their initial values. Calls reset() on all contained spaces that have a reset method, then sets this space's value to init_value. """ for v in self.spaces.values(): if hasattr(v, "reset"): v.reset() self._value = self.init_value
[docs] def contains(self, x) -> bool: """Check if value is a valid member of this space. Validates that x is a dictionary containing all required keys with values that satisfy each subspace's constraints and the overall constraint function. Args: x: The value to check. Returns: bool: True if x is a valid dict with all keys present, all values within subspace bounds, and satisfies the constraint function. False otherwise. """ if not isinstance(x, dict): return False for key in self.spaces.keys(): if key not in x: return False if not self.spaces[key].contains(x[key]): return False if not self.constrain_fn(x): return False return True
[docs] def check(self, debug=False): """Validate all contained spaces' current values. Checks each contained space using its check() method if available, or falls back to contains(value). Optionally logs warnings for failed checks. Args: debug (bool, optional): If True, logs warnings for spaces that fail validation. Defaults to False. Returns: bool: True if all contained spaces have valid values, False otherwise. """ for k, v in self.spaces.items(): if hasattr(v, "check"): if not v.check(): if debug: logging.warning(f"Dict: space {k} failed check()") return False return True
[docs] def names(self): """Return all space keys including nested ones. Returns: list: A list of all keys in the Dict space, with nested keys using dot notation. For example, a nested dict with key "a" containing subspace "b" would produce "a.b". """ def _key_generator(d, parent_key=""): for k, v in d.items(): new_key = f"{parent_key}.{k}" if parent_key else k if isinstance(v, spaces.Dict): yield from _key_generator(v.spaces, new_key) else: yield new_key return list(_key_generator(self.spaces))
[docs] def sample(self, *args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs): """Sample a random element from the Dict space. Samples each subspace in the sampling order and ensures the result satisfies any constraint functions. Uses rejection sampling if constraints are present. Args: *args: Positional arguments passed to each subspace's sample method. max_tries (int, optional): Maximum number of rejection sampling attempts. Defaults to 1000. warn_after_s (float, optional): Issue a warning if sampling takes longer than this many seconds. Set to None to disable warnings. Defaults to 5.0. set_value (bool, optional): Whether to set the internal value to the sampled value. Defaults to True. **kwargs: Additional keyword arguments passed to each subspace's sample method. Returns: dict: A dictionary with keys matching the space definition and values sampled from their respective subspaces. Raises: RuntimeError: If a valid sample is not found within max_tries attempts. """ start = time.time() for i in range(max_tries): sample = {} for k in self._sampling_order: sample[k] = self.spaces[k].sample(*args, **kwargs, set_value=set_value) if self.contains(sample): if set_value: self._value = sample return sample if warn_after_s is not None and (time.time() - start) > warn_after_s: logging.warning("rejection sampling is taking a while...") raise RuntimeError(f"constrain_fn not satisfied after {max_tries} draws")
[docs] def update(self, keys): """Update specific keys in the Dict space by resampling them. Samples new values for the specified keys while maintaining the sampling order. Uses dot notation for nested keys (e.g., "a.b" for nested dict). Args: keys (container): A container (list, set, etc.) of key names to resample. Keys should use dot notation for nested spaces. Raises: ValueError: If a specified key is not found in the Dict space. AssertionError: If the updated values violate the space constraints. """ keys = set(keys) order = self.sampling_order if len(keys) == 1 and "all" in keys: self.sample() else: for v in filter(keys.__contains__, order): try: var_path = v.split(".") swm.utils.get_in(self, var_path).sample() except (KeyError, TypeError): raise ValueError(f"Key {v} not found in Dict space") assert self.check(debug=True), "Values must be within space!"