Source code for stable_worldmodel.wrappers

import re
import time
from collections.abc import Callable, Iterable

import gymnasium as gym
import numpy as np
from gymnasium.spaces.utils import is_space_dtype_shape_equiv
from gymnasium.vector import VectorWrapper
from gymnasium.vector.utils import (
    batch_differing_spaces,
    batch_space,
)

from stable_worldmodel.utils import get_in


class EnsureInfoKeysWrapper(gym.Wrapper):
    """Validates that required keys are present in the info dict after reset and step.

    Supports regex patterns for flexible key matching. Raises RuntimeError if any
    required pattern has no matching key.

    Args:
        env: The Gymnasium environment to wrap.
        required_keys: Iterable of regex patterns as strings. Each pattern must
            match at least one key in the info dict.

    Raises:
        RuntimeError: If any required pattern has no matching key in info dict.
    """

    def __init__(self, env, required_keys: Iterable[str]):
        super().__init__(env)
        # Compile once at construction; matching uses fullmatch, so each pattern
        # must cover an entire key (anchoring with ^...$ is not required).
        self._patterns: list[re.Pattern] = [re.compile(k) for k in required_keys]

    def _check(self, info: dict, where: str):
        """Raise RuntimeError if any required pattern matches no key in ``info``.

        Args:
            info: The info dict to validate.
            where: Label ("reset()" / "step()") used in the error message.
        """
        keys = list(info.keys())
        missing = [p.pattern for p in self._patterns if not any(p.fullmatch(k) for k in keys)]
        if missing:
            raise RuntimeError(
                f"{where}: required info keys missing (patterns with no match): {missing}. Present keys: {keys}"
            )

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self._check(info, "step()")
        return obs, reward, terminated, truncated, info

    def reset(self, *args, **kwargs):
        obs, info = self.env.reset(*args, **kwargs)
        self._check(info, "reset()")
        return obs, info
class EnsureImageShape(gym.Wrapper):
    """Validates that an image in the info dict has the expected spatial dimensions.

    Args:
        env: The Gymnasium environment to wrap.
        image_key: Key in info dict containing the image to validate.
        image_shape: Expected (height, width) tuple for the image.

    Raises:
        RuntimeError: If the image shape doesn't match the expected dimensions.
    """

    def __init__(self, env, image_key, image_shape):
        super().__init__(env)
        self.image_key = image_key
        self.image_shape = image_shape  # (height, width)

    def _check(self, info: dict):
        """Raise RuntimeError unless info[image_key] has the expected (H, W).

        Only the spatial dimensions are compared; the trailing axis is taken to
        be the channel dimension and is ignored.
        """
        if info[self.image_key].shape[:-1] != self.image_shape:
            raise RuntimeError(f"Image shape {info[self.image_key].shape} should be {self.image_shape}")

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self._check(info)
        return obs, reward, terminated, truncated, info

    def reset(self, *args, **kwargs):
        obs, info = self.env.reset(*args, **kwargs)
        self._check(info)
        return obs, info
class EnsureGoalInfoWrapper(gym.Wrapper):
    """Validates that 'goal' key is present in info dict during reset and/or step.

    Useful for goal-conditioned environments to ensure goal information is provided.

    Args:
        env: The Gymnasium environment to wrap.
        check_reset: If True, validates 'goal' key is in info after reset().
        check_step: If True, validates 'goal' key is in info after step().

    Raises:
        RuntimeError: If 'goal' key is missing when validation is enabled.
    """

    def __init__(self, env, check_reset, check_step: bool = False):
        super().__init__(env)
        self.check_reset = check_reset
        self.check_step = check_step

    def reset(self, *args, **kwargs):
        observation, info = self.env.reset(*args, **kwargs)
        # Only enforce the invariant when reset-checking is enabled.
        if self.check_reset and "goal" not in info:
            raise RuntimeError("The info dict returned by reset() must contain the key 'goal'.")
        return observation, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        # Only enforce the invariant when step-checking is enabled.
        if self.check_step and "goal" not in info:
            raise RuntimeError("The info dict returned by step() must contain the key 'goal'.")
        return observation, reward, terminated, truncated, info
class EverythingToInfoWrapper(gym.Wrapper):
    """Moves all transition information into the info dict for unified data access.

    Adds observation, reward, terminated, truncated, action, and step_idx to info.
    Optionally tracks environment variations when specified in reset options.

    Args:
        env: The Gymnasium environment to wrap.

    Info Keys Added:
        - observation (or dict keys if obs is dict): Current observation.
        - reward: Reward value (NaN after reset).
        - terminated: Episode termination flag.
        - truncated: Episode truncation flag.
        - action: Action taken (NaN sample after reset).
        - step_idx: Current step counter.
        - variation.{key}: Variation values if requested via reset options.

    Note:
        Pass options={"variation": ["key1", "key2"]} or ["all"] to reset()
        to track variations.
    """

    def __init__(self, env):
        super().__init__(env)
        # Variation names to mirror into info on every transition; populated by
        # reset() when options={"variation": [...]} is passed.
        self._variations_watch = []

    def reset(self, *args, **kwargs):
        self._step_counter = 0
        obs, info = self.env.reset(*args, **kwargs)
        # Normalize the observation into a dict so it can be merged into info.
        if type(obs) is not dict:
            _obs = {"observation": obs}
        else:
            _obs = obs
        for key in _obs:
            # Asserts guard against the underlying env already using these keys.
            assert key not in info
            info[key] = _obs[key]
        # There is no reward/termination for the reset transition; use sentinels.
        assert "reward" not in info
        info["reward"] = np.nan
        assert "terminated" not in info
        info["terminated"] = False
        assert "truncated" not in info
        info["truncated"] = False
        # Sample a correctly-shaped action, then NaN it out below so the dtype
        # and structure match step() entries while marking it as "no action".
        assert "action" not in info
        info["action"] = self.env.action_space.sample()
        assert "step_idx" not in info
        info["step_idx"] = self._step_counter
        # add all variations to info if needed
        options = kwargs.get("options") or {}
        if "variation" in options:
            var_opt = options["variation"]
            assert isinstance(options["variation"], list | tuple), (
                "variation option must be a list or tuple containing variation names to sample"
            )
            if len(var_opt) == 1 and var_opt[0] == "all":
                # "all" expands to every name in the env's variation space.
                self._variations_watch = self.env.unwrapped.variation_space.names()
            else:
                self._variations_watch = var_opt
        # NOTE(review): loop placed outside the `if` so reset() mirrors step(),
        # which iterates _variations_watch unconditionally — confirm intended.
        for key in self._variations_watch:
            var_key = f"variation.{key}"
            assert var_key not in info
            # Dotted names index into nested variation sub-spaces.
            subvar_space = get_in(self.env.unwrapped.variation_space, key.split("."))
            info[var_key] = subvar_space.value
        if type(info["action"]) is dict:
            # Dict action spaces have no single NaN sentinel representation yet.
            raise NotImplementedError
        else:
            # In-place multiply turns the sampled action into NaN placeholders.
            info["action"] *= np.nan
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self._step_counter += 1
        # Normalize the observation into a dict so it can be merged into info.
        if type(obs) is not dict:
            _obs = {"observation": obs}
        else:
            _obs = obs
        for key in _obs:
            # Asserts guard against the underlying env already using these keys.
            assert key not in info
            info[key] = _obs[key]
        assert "reward" not in info
        info["reward"] = reward
        assert "terminated" not in info
        info["terminated"] = bool(terminated)
        assert "truncated" not in info
        info["truncated"] = bool(truncated)
        assert "action" not in info
        info["action"] = action
        assert "step_idx" not in info
        info["step_idx"] = self._step_counter
        # Mirror the watched variation values set up by reset().
        for key in self._variations_watch:
            var_key = f"variation.{key}"
            assert var_key not in info
            subvar_space = get_in(self.env.unwrapped.variation_space, key.split("."))
            info[var_key] = subvar_space.value
        return obs, reward, terminated, truncated, info
class AddPixelsWrapper(gym.Wrapper):
    """Adds rendered environment pixels to info dict with optional resizing and transforms.

    Supports single images, dictionaries of images (multiview), or lists of
    images. Uses PIL for resizing and optional torchvision transforms.

    Args:
        env: The Gymnasium environment to wrap.
        pixels_shape: Target (height, width) for resized images. Defaults to (84, 84).
        torchvision_transform: Optional transform to apply to PIL images.

    Info Keys Added:
        - pixels: Rendered image (single view).
        - pixels.{key}: Individual images (multiview dict).
        - pixels.{idx}: Individual images (multiview list).
        - render_time: Time taken to render in seconds.
    """

    def __init__(
        self,
        env,
        pixels_shape: tuple[int, int] = (84, 84),  # (height, width)
        torchvision_transform: Callable | None = None,
    ):
        super().__init__(env)
        self.pixels_shape = pixels_shape
        self.torchvision_transform = torchvision_transform
        # PIL is imported lazily here; it is required for torchvision transforms.
        from PIL import Image

        self.Image = Image

    def _get_pixels(self):
        """Render the env and return ({info_key: image}, render_seconds)."""
        # Prefer a multiview renderer when the base env exposes one.
        renderer = getattr(self.env.unwrapped, "render_multiview", None)
        if not callable(renderer):
            renderer = self.env.render

        start = time.time()
        frame = renderer()
        elapsed = time.time() - start

        def _to_pixels(array):
            # Resize via PIL; note PIL expects (width, height) ordering.
            image = self.Image.fromarray(array)
            height, width = self.pixels_shape
            image = image.resize((width, height), self.Image.BILINEAR)
            # Either hand the PIL image to the transform or convert back to numpy.
            if self.torchvision_transform is None:
                return np.array(image)
            return self.torchvision_transform(image)

        # Fan out over the three supported render shapes: dict, sequence, single.
        if isinstance(frame, dict):
            processed = {f"pixels.{name}": _to_pixels(view) for name, view in frame.items()}
        elif isinstance(frame, (list | tuple)):
            processed = {f"pixels.{index}": _to_pixels(view) for index, view in enumerate(frame)}
        else:
            processed = {"pixels": _to_pixels(frame)}

        return processed, elapsed

    def reset(self, *args, **kwargs):
        observation, info = self.env.reset(*args, **kwargs)
        rendered, info["render_time"] = self._get_pixels()
        info.update(rendered)
        return observation, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        rendered, info["render_time"] = self._get_pixels()
        info.update(rendered)
        return observation, reward, terminated, truncated, info
class ResizeGoalWrapper(gym.Wrapper):
    """Resizes goal images in info dict with optional transforms.

    Applies PIL-based resizing and optional torchvision transforms to the
    'goal' image in info dict during both reset and step.

    Args:
        env: The Gymnasium environment to wrap.
        pixels_shape: Target (height, width) for resized goal images. Defaults to (84, 84).
        torchvision_transform: Optional transform to apply to PIL goal images.
    """

    def __init__(
        self,
        env,
        pixels_shape: tuple[int, int] = (84, 84),  # (height, width)
        torchvision_transform: Callable | None = None,
    ):
        super().__init__(env)
        self.pixels_shape = pixels_shape
        self.torchvision_transform = torchvision_transform
        # PIL is imported lazily here; it is required for torchvision transforms.
        from PIL import Image

        self.Image = Image

    def _format(self, img):
        """Resize a goal image array and optionally apply the transform."""
        image = self.Image.fromarray(img)
        # PIL's resize takes (width, height), so swap our (height, width) pair.
        height, width = self.pixels_shape
        image = image.resize((width, height), self.Image.BILINEAR)
        if self.torchvision_transform is None:
            return np.array(image)
        return self.torchvision_transform(image)

    def reset(self, *args, **kwargs):
        observation, info = self.env.reset(*args, **kwargs)
        info["goal"] = self._format(info["goal"])
        return observation, info

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        info["goal"] = self._format(info["goal"])
        return observation, reward, terminated, truncated, info
class MegaWrapper(gym.Wrapper):
    """Combines multiple wrappers for comprehensive environment preprocessing.

    Applies in sequence: AddPixelsWrapper → EverythingToInfoWrapper →
    EnsureInfoKeysWrapper → EnsureGoalInfoWrapper → ResizeGoalWrapper.
    This provides a complete preprocessing pipeline with rendered pixels,
    unified info dict, key validation, goal checking, and goal resizing.

    Args:
        env: The Gymnasium environment to wrap.
        image_shape: Target (height, width) for pixels and goal. Defaults to (84, 84).
        pixels_transform: Optional torchvision transform for rendered pixels.
        goal_transform: Optional torchvision transform for goal images.
        required_keys: Additional regex patterns for keys that must be in info.
            Pattern ``^pixels(?:\\..*)?$`` is always added. The caller's list is
            not mutated.
        separate_goal: If True, validates 'goal' is present in info. Defaults to True.
    """

    def __init__(
        self,
        env,
        image_shape: tuple[int, int] = (84, 84),
        pixels_transform: Callable | None = None,
        goal_transform: Callable | None = None,
        required_keys: Iterable | None = None,
        separate_goal: bool = True,
    ):
        super().__init__(env)

        # Copy into a fresh list so the caller's `required_keys` is never
        # mutated, and so any iterable (not just a list) is accepted.
        patterns = list(required_keys) if required_keys is not None else []
        patterns.append(r"^pixels(?:\..*)?$")

        # this adds `pixels` key to info with optional transform
        env = AddPixelsWrapper(env, image_shape, pixels_transform)
        # this removes the info output, everything is in observation!
        env = EverythingToInfoWrapper(env)
        # check that necessary keys are in the observation
        env = EnsureInfoKeysWrapper(env, patterns)
        # check goal is provided
        env = EnsureGoalInfoWrapper(env, check_reset=separate_goal, check_step=separate_goal)
        self.env = ResizeGoalWrapper(env, image_shape, goal_transform)

    def reset(self, *args, **kwargs):
        return self.env.reset(*args, **kwargs)

    def step(self, action):
        return self.env.step(action)
class VariationWrapper(VectorWrapper):
    """Manages variation spaces for vectorized environments.

    Handles batching of variation spaces across multiple environments,
    supporting either shared variations (same) or independent variations
    (different).

    Args:
        env: The vectorized Gymnasium environment to wrap.
        variation_mode: Mode for handling variations across environments:
            - "same": All environments share the same variation space (batched).
            - "different": Each environment has independent variation spaces.

    Raises:
        ValueError: If variation_mode is invalid or sub-environment spaces
            don't match.

    Note:
        Base environment must have a ``variation_space`` attribute. If
        missing, variation spaces are set to None.
    """

    def __init__(
        self,
        env,
        variation_mode: str | gym.Space = "same",
    ):
        super().__init__(env)
        # The first sub-env is used as the template for the single-env space.
        base_env = env.envs[0].unwrapped
        if not hasattr(base_env, "variation_space"):
            # No variation support on the base env: expose None and skip checks.
            self.single_variation_space = None
            self.variation_space = None
            return
        if variation_mode == "same":
            self.single_variation_space = base_env.variation_space
            # Replicate one space across all envs.
            self.variation_space = batch_space(self.single_variation_space, self.num_envs)
        elif variation_mode == "different":
            self.single_variation_space = base_env.variation_space
            # Batch each sub-env's own (possibly differing) variation space.
            self.variation_space = batch_differing_spaces([sub_env.unwrapped.variation_space for sub_env in env.envs])
        else:
            raise ValueError(
                f"Invalid `variation_mode`, expected: 'same' or 'different' or tuple of single and batch variation space, actual got {variation_mode}"
            )
        # check sub-environment obs and action spaces
        # NOTE(review): both branches below run the same shape/dtype check and
        # differ only in the error message — confirm whether the 'different'
        # branch was meant to compare against something else.
        for sub_env in env.envs:
            if variation_mode == "same":
                if not is_space_dtype_shape_equiv(sub_env.unwrapped.observation_space, self.single_observation_space):
                    raise ValueError(
                        f"VariationWrapper(..., variation_mode='same') however the sub-environments observation spaces do not share a common shape and dtype, single_observation_space={self.single_observation_space}, sub-environment observation_space={sub_env.observation_space}"
                    )
            else:
                if not is_space_dtype_shape_equiv(sub_env.unwrapped.observation_space, self.single_observation_space):
                    raise ValueError(
                        f"VariationWrapper(..., variation_mode='different' or custom space) however the sub-environments observation spaces do not share a common shape and dtype, single_observation_space={self.single_observation_space}, sub-environment observation_space={sub_env.observation_space}"
                    )

    @property
    def envs(self):
        # Pass through the underlying vector env's sub-env list (None if absent).
        return getattr(self.env, "envs", None)