stable_worldmodel package

Subpackages

Submodules

stable_worldmodel.cli module

Command-line interface for stable-worldmodel.

This module provides CLI commands for managing and inspecting stable-worldmodel resources, including datasets, models, and world environments. The CLI is built with the Typer framework and uses Rich for formatted terminal output.

Available commands:
  • list: List cached models or datasets, or registered worlds

  • show: Display detailed information about datasets or worlds

  • delete: Remove models or datasets from cache

The CLI can be invoked via:
  • stable-worldmodel <command>

  • swm <command>

  • python -m stable_worldmodel.cli <command>

Typical usage examples:

List all cached datasets:

$ swm list dataset

Show information about a specific world:

$ swm show world swm/SimplePointMaze-v0

Delete a cached dataset:

$ swm delete dataset my-dataset

Show version:

$ swm --version

common(version: Annotated[bool | None, typer.Option(...)] = None)

Common options for all stable-worldmodel commands.

This callback provides global options that apply to all CLI commands. Currently, the only global option is the --version flag.

Parameters:

version (Optional[bool], optional) – If True, display version and exit. Defaults to None.

delete(kind: Annotated[str, typer.Argument(...)], names: Annotated[list[str], typer.Argument(...)])

Delete models or datasets from cache directory.

Permanently removes specified models or datasets from the local cache directory. Requires user confirmation before deletion. This operation cannot be undone.

Parameters:
  • kind (str) – Type of resource to delete. Must be either ‘model’ (delete cached world models) or ‘dataset’ (delete cached datasets).

  • names (List[str]) – One or more names of resources to delete. All specified names must exist in cache.

Raises:

typer.Abort – If kind is invalid, if any specified name is not found in the cache, or if the user cancels the confirmation prompt.

Example

Delete a single dataset:

$ swm delete dataset my-dataset

Delete multiple models:

$ swm delete model model1 model2

Warning

This operation permanently deletes data and cannot be undone. The command will prompt for confirmation before proceeding.
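The confirm-then-delete flow described above can be sketched with the standard library. This is an illustration under stated assumptions rather than the command's source: delete_cached is a hypothetical helper, the confirm callable stands in for the interactive prompt, and each resource is assumed to be a file or directory named after it directly under the cache directory.

```python
import shutil
import tempfile
from pathlib import Path


def delete_cached(names, cache_dir, confirm):
    """Delete the named cache entries after one confirmation (hypothetical helper)."""
    targets = [Path(cache_dir) / name for name in names]
    missing = [t.name for t in targets if not t.exists()]
    if missing:
        # Refuse to delete anything if any requested name is absent
        raise FileNotFoundError(f"not found in cache: {', '.join(missing)}")
    if not confirm(f"Permanently delete {', '.join(names)}?"):
        return []  # user declined, nothing is removed
    for t in targets:
        if t.is_dir():
            shutil.rmtree(t)
        else:
            t.unlink()
    return [t.name for t in targets]


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "my-dataset").mkdir()
    (Path(tmp) / "other-dataset").mkdir()
    declined = delete_cached(["my-dataset"], tmp, confirm=lambda msg: False)
    kept_after_decline = (Path(tmp) / "my-dataset").exists()
    deleted = delete_cached(["my-dataset"], tmp, confirm=lambda msg: True)
    remaining = sorted(p.name for p in Path(tmp).iterdir())
```

Checking that every name exists before asking for confirmation means a typo aborts the whole operation instead of deleting only part of the requested set.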

display_dataset_info(info: dict[str, Any]) → None

Display dataset information in a formatted panel with tables.

Prints a Rich panel showing dataset metadata including columns, episode count, shapes, and variation information in a two-table layout.

Parameters:

info (Dict[str, Any]) – Dataset metadata dictionary containing the following keys:

  • ‘name’: Dataset name

  • ‘columns’: List of data columns

  • ‘num_episodes’: Total number of episodes

  • ‘num_steps’: Total number of steps

  • ‘obs_shape’: Observation tensor shape

  • ‘action_shape’: Action tensor shape

  • ‘goal_shape’: Goal tensor shape

  • ‘variation’: Variation space structure

Example

>>> from stable_worldmodel import data
>>> from stable_worldmodel.cli import display_dataset_info
>>> info = data.dataset_info("simple-pointmaze")
>>> display_dataset_info(info)

display_world_info(info: dict[str, Any]) → None

Display world environment information in a formatted panel.

Prints a Rich panel showing the world’s observation space, action space, and variation space in a hierarchical, color-coded format.

Parameters:

info (Dict[str, Any]) – World metadata dictionary containing the following keys:

  • ‘name’: World environment name

  • ‘observation_space’: Observation space structure

  • ‘action_space’: Action space structure

  • ‘variation’: Variation space structure

Example

>>> from stable_worldmodel import data
>>> from stable_worldmodel.cli import display_world_info
>>> info = data.world_info("swm/SimplePointMaze-v0")
>>> display_world_info(info)

list_cmd(kind: Annotated[str, typer.Argument(...)])

List cached stable-worldmodel resources.

Displays a table of the models or datasets stored in the stable-worldmodel cache directory, or of the registered world environments. Useful for checking which resources are available locally.

Parameters:

kind (str) – Type of resource to list. Must be one of ‘model’ (list cached world models), ‘dataset’ (list cached datasets), or ‘world’ (list registered world environments).

Raises:

typer.Abort – If kind is not ‘model’, ‘dataset’, or ‘world’.

Example

List all cached datasets:

$ swm list dataset

List all available worlds:

$ swm list world

show(kind: Annotated[str, typer.Argument(...)], names: Annotated[list[str] | None, typer.Argument(...)] = None, all: Annotated[bool, typer.Option(...)] = False)

Show detailed information about datasets or worlds.

Displays comprehensive information about specified datasets or world environments including their structure, spaces, and variations. Information is presented in formatted panels with hierarchical trees and tables.

Parameters:
  • kind (str) – Type of resource to show. Must be either ‘dataset’ (show dataset information) or ‘world’ (show world environment information).

  • names (Optional[List[str]], optional) – One or more resource names to display. Defaults to None.

  • all (bool, optional) – If True, show information for all cached resources of the specified kind. Defaults to False.

Raises:

typer.Abort – If kind is invalid, if no names are given and --all is not set, or if a specified name is not found.

Example

Show a specific dataset:

$ swm show dataset simple-pointmaze

Show multiple datasets:

$ swm show dataset dataset1 dataset2

Show all cached datasets:

$ swm show dataset --all

Show a specific world:

$ swm show world swm/SimplePointMaze-v0
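A minimal sketch of how the kind, names, and --all arguments could be resolved into the list of resources to display, following the Raises rules for show. resolve_show_targets and the Abort class (a stand-in for typer.Abort) are hypothetical, and letting --all take precedence over explicit names is an assumption.

```python
VALID_KINDS = ("dataset", "world")


class Abort(Exception):
    """Stand-in for typer.Abort so the sketch runs without Typer."""


def resolve_show_targets(kind, names, show_all, available):
    """Return the resource names that `swm show` would display (hypothetical)."""
    if kind not in VALID_KINDS:
        raise Abort(f"kind must be one of {VALID_KINDS}, got {kind!r}")
    if show_all:
        return sorted(available)  # --all ignores explicit names in this sketch
    if not names:
        raise Abort("pass one or more names, or use --all")
    missing = [n for n in names if n not in available]
    if missing:
        raise Abort(f"not found: {', '.join(missing)}")
    return list(names)


cached = {"simple-pointmaze", "dataset1", "dataset2"}
selected = resolve_show_targets("dataset", ["dataset1", "dataset2"], False, cached)
everything = resolve_show_targets("dataset", None, True, cached)
```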

stable_worldmodel.data module

Data management and dataset utilities for Stable World Model.

This module provides utilities for managing datasets, models, and world information. It includes functionality for loading multi-step trajectories, querying cached data, and retrieving metadata about environments.

The module supports:
  • Multi-step trajectory datasets with frame skipping

  • Dataset and model cache management

  • World environment introspection

  • Gymnasium space metadata extraction

class SpaceInfo

Bases: TypedDict

Type specification for Gymnasium space metadata.

Variables:
  • shape (tuple[int, ...]) – Dimensions of the space.

  • type (str) – Class name of the space (e.g., ‘Box’, ‘Discrete’).

  • dtype (str) – Data type of the space elements.

  • low (Any) – Lower bounds for Box spaces.

  • high (Any) – Upper bounds for Box spaces.

  • n (int) – Number of discrete values for Discrete spaces.

dtype: str
high: Any
low: Any
n: int
shape: tuple[int, ...]
type: str

class StepsDataset(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)

Bases: HFDataset

Dataset for loading multi-step trajectory sequences.

This dataset loads sequences of consecutive steps from episodic data, supporting frame skipping and automatic image loading from file paths. Inherits from stable_pretraining’s HFDataset.

Variables:
  • data_dir (Path) – Directory containing the dataset.

  • num_steps (int) – Number of steps per sequence.

  • frameskip (int) – Number of steps between sampled frames.

  • episodes (np.ndarray) – Array of unique episode indices.

  • episode_slices (dict) – Mapping from episode index to dataset indices.

  • cum_slices (np.ndarray) – Cumulative sum of valid samples per episode.

  • idx_to_ep (np.ndarray) – Mapping from sample index to episode.

  • img_cols (set) – Set of column names containing image paths.

__init__(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)

Initialize the StepsDataset.

Parameters:
  • path (str) – Name or path of the dataset within the cache directory.

  • *args – Additional arguments passed to parent class.

  • num_steps (int, optional) – Number of consecutive steps per sample. Defaults to 2.

  • frameskip (int, optional) – Number of steps between sampled frames. Defaults to 1.

  • cache_dir (str, optional) – Cache directory path. Defaults to None, which uses the default cache directory.

  • **kwargs – Additional keyword arguments passed to parent class.

Raises:
  • AssertionError – If required columns are missing from dataset.

  • ValueError – If episodes are too short for the requested num_steps and frameskip.
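The sequence indexing suggested by the cum_slices and idx_to_ep attributes can be sketched in plain Python. This assumes that frameskip acts as a stride between sampled rows, that each episode's rows are stored contiguously, and that an episode of length L therefore yields L - (num_steps - 1) * frameskip valid windows; build_index and window are hypothetical helpers, not methods of StepsDataset.

```python
from bisect import bisect_right
from itertools import accumulate


def build_index(episode_lengths, num_steps=2, frameskip=1):
    """Return cumulative counts of valid windows per episode (cf. cum_slices)."""
    span = (num_steps - 1) * frameskip + 1  # rows covered by one window
    valid = []
    for ep, length in enumerate(episode_lengths):
        n_windows = length - span + 1
        if n_windows <= 0:
            raise ValueError(
                f"episode {ep} has {length} steps but num_steps={num_steps}, "
                f"frameskip={frameskip} need at least {span}"
            )
        valid.append(n_windows)
    return list(accumulate(valid))


def window(idx, cum, episode_starts, num_steps=2, frameskip=1):
    """Map a global sample index to the dataset rows of its window."""
    if not 0 <= idx < cum[-1]:
        raise IndexError(idx)
    ep = bisect_right(cum, idx)  # first episode whose cumulative count exceeds idx
    offset = idx - (cum[ep - 1] if ep > 0 else 0)
    start = episode_starts[ep] + offset
    return [start + k * frameskip for k in range(num_steps)]


lengths = [5, 4]  # two episodes stored back to back
starts = [0] + list(accumulate(lengths))[:-1]  # first row of each episode
cum = build_index(lengths, num_steps=2, frameskip=2)
rows = [window(i, cum, starts, num_steps=2, frameskip=2) for i in range(cum[-1])]
```

The binary search over cumulative counts keeps index lookup logarithmic in the number of episodes, and the length check reproduces the documented ValueError for episodes that are too short.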

get_episode_slice(episode_idx, episode_indices)

Get dataset indices for a specific episode.

Parameters:
  • episode_idx (int) – Episode index to retrieve.

  • episode_indices (array-like) – Array of episode indices for all steps.

Returns:

Indices of steps belonging to the specified episode.

Return type:

np.ndarray

Raises:

ValueError – If episode is too short for num_steps and frameskip.

infer_img_path_columns()

Infer which dataset columns contain image file paths.

Checks the first dataset element to identify string columns with common image file extensions.

Returns:

Set of column names containing image file paths.

Return type:

set
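The inference rule above (inspect the first element, keep string columns that end in an image extension) can be sketched as follows. The extension list and the sample row are assumptions for illustration; the actual method may recognise a different set of extensions.

```python
# Assumed extension list; the real method may recognise a different set
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp")


def infer_img_path_columns(first_row):
    """Return names of columns whose value in first_row looks like an image path."""
    return {
        col
        for col, value in first_row.items()
        if isinstance(value, str) and value.lower().endswith(IMAGE_EXTENSIONS)
    }


row = {
    "pixels": "episode_0/step_0.png",
    "goal": "episode_0/goal.JPG",
    "action": [0.1, -0.2],
    "episode_idx": 0,
}
img_cols = infer_img_path_columns(row)
```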

class VariationInfo

Bases: TypedDict

Type specification for environment variation metadata.

Variables:
  • has_variation (bool) – Whether the environment supports variations.

  • type (str | None) – Class name of the variation space if it exists.

  • names (list[str] | None) – List of variation parameter names.

has_variation: bool
names: list[str] | None
type: str | None

class WorldInfo

Bases: TypedDict

Type specification for world environment information.

Variables:
action_space: SpaceInfo
config: dict[str, Any]
name: str
observation_space: SpaceInfo
variation: VariationInfo

dataset_info(name)

Get metadata about a cached dataset.

Parameters:

name (str) – Name of the dataset.

Returns:

Dictionary containing dataset metadata including:
  • name: Dataset name

  • num_episodes: Number of unique episodes

  • num_steps: Total number of steps

  • columns: List of column names

  • obs_shape: Shape of observation images

  • action_shape: Shape of action vectors

  • goal_shape: Shape of goal images

  • variation: Dict with variation information

Return type:

dict

Raises:
  • ValueError – If dataset is not found in cache.

  • AssertionError – If required columns are missing.

delete_dataset(name)

Delete a cached dataset and its associated files.

Parameters:

name (str) – Name of the dataset to delete.

Note

Prints success or error messages to console.

delete_model(name)

Delete cached model checkpoint files.

Removes all checkpoint files (weights and object files) matching the given model name.

Parameters:

name (str) – Name of the model to delete.

Note

Prints success or error messages to console for each file deleted.

get_cache_dir() → Path

Get the cache directory for stable_worldmodel data.

The cache directory can be customized via the STABLEWM_HOME environment variable. If not set, defaults to ~/.stable_worldmodel.

Returns:

Path to the cache directory. The directory is created if it does not exist.

Return type:

Path
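The resolution rule above (STABLEWM_HOME if set, otherwise ~/.stable_worldmodel, created on demand) can be written in a few lines of standard-library Python. This is a sketch of the documented behavior rather than the package's source; whether the real function expands ~ inside STABLEWM_HOME is an assumption.

```python
import os
import tempfile
from pathlib import Path


def get_cache_dir() -> Path:
    """Resolve STABLEWM_HOME, falling back to ~/.stable_worldmodel, and create it."""
    root = os.environ.get("STABLEWM_HOME")
    path = Path(root).expanduser() if root else Path.home() / ".stable_worldmodel"
    path.mkdir(parents=True, exist_ok=True)  # created on first use
    return path


with tempfile.TemporaryDirectory() as tmp:
    previous = os.environ.get("STABLEWM_HOME")
    os.environ["STABLEWM_HOME"] = os.path.join(tmp, "swm-cache")
    try:
        cache = get_cache_dir()
        created = cache.is_dir()
    finally:
        # Restore the caller's environment
        if previous is None:
            del os.environ["STABLEWM_HOME"]
        else:
            os.environ["STABLEWM_HOME"] = previous
```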

is_image(x)

Check if input is a valid image array.

Parameters:

x – Input to check.

Returns:

True if x is a uint8 numpy array with shape (H, W, C), where C is 1 (grayscale), 3 (RGB), or 4 (RGBA).

Return type:

bool

list_datasets()

List all cached datasets.

Returns:

Names of all dataset directories in the cache.

Return type:

list[str]

list_models()

List all cached model checkpoints.

Searches for files matching the pattern <name>_weights*.ckpt or <name>_object.ckpt and returns unique model names.

Returns:

Sorted list of model names found in cache.

Return type:

list[str]
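The filename convention above can be turned into model names with a regular expression. This sketch assumes checkpoints sit directly in the cache directory and uses a greedy match so that the longest prefix satisfying the pattern becomes the name; the file names in the usage are made up.

```python
import re
import tempfile
from pathlib import Path

# Greedy <name>: the longest prefix that still fits the pattern wins
CKPT_PATTERN = re.compile(r"^(?P<name>.+)_(?:weights.*|object)\.ckpt$")


def list_models(cache_dir):
    """Return sorted unique model names from checkpoint files in cache_dir."""
    names = set()
    for f in Path(cache_dir).iterdir():
        match = CKPT_PATTERN.match(f.name)
        if match and f.is_file():
            names.add(match.group("name"))
    return sorted(names)


with tempfile.TemporaryDirectory() as tmp:
    for fname in ["dino_wm_weights.ckpt", "dino_wm_object.ckpt",
                  "pusht_weights_epoch5.ckpt", "notes.txt"]:
        (Path(tmp) / fname).touch()
    models = list_models(tmp)
```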

list_worlds()

List all registered world environments.

Returns:

Sorted list of world environment IDs.

Return type:

list[str]

world_info(name: str, *, image_shape: tuple[int, int] = (224, 224), render_mode: str = 'rgb_array') → WorldInfo

Get metadata about a world environment.

Creates a temporary world instance to extract observation space, action space, and variation information. Results are cached for efficiency.

Parameters:
  • name – ID of the world environment.

  • image_shape – Desired image shape for rendering. Defaults to (224, 224).

  • render_mode – Rendering mode for the environment. Defaults to “rgb_array”.

Returns:

Dictionary containing world metadata including spaces and variations.

Return type:

WorldInfo

Raises:

ValueError – If world name is not registered.
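The docstring says results are cached but not how. One plausible approach is functools.lru_cache; the sketch below uses it with a hypothetical REGISTRY dict standing in for world creation, so repeated calls with the same arguments skip the expensive step. REGISTRY and instantiations are illustrative names only.

```python
from functools import lru_cache

# Hypothetical stand-in for the environment registry
REGISTRY = {"swm/SimplePointMaze-v0": {"action_shape": (2,)}}
instantiations = []  # records each time the expensive path runs


@lru_cache(maxsize=None)
def world_info(name, *, image_shape=(224, 224), render_mode="rgb_array"):
    """Return metadata for name, computing it once per distinct argument set."""
    if name not in REGISTRY:
        raise ValueError(f"world {name!r} is not registered")
    instantiations.append(name)  # the real function would create a world here
    return {
        "name": name,
        "image_shape": image_shape,
        "render_mode": render_mode,
        **REGISTRY[name],
    }


first = world_info("swm/SimplePointMaze-v0")
second = world_info("swm/SimplePointMaze-v0")
```

Because lru_cache returns the same object on every hit, callers should treat the returned dictionary as read-only. Exceptions are not cached, so an unregistered name raises ValueError on every call.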

stable_worldmodel.policy module

AutoCostModel(model_name, cache_dir=None)

class BasePolicy(**kwargs)

Bases: object

Base class for agent policies.

get_action(obs, **kwargs)

Get action from the policy given the observation.

set_env(env)

class ExpertPolicy(**kwargs)

Bases: BasePolicy

Expert Policy.

get_action(obs, goal_obs, **kwargs)

Get an action from the policy given the current observation and the goal observation.

class PlanConfig(horizon: int, receding_horizon: int, history_len: int = 1, action_block: int = 1, warm_start: bool = True)

Bases: object

Configuration for the planning process.

action_block: int = 1
history_len: int = 1
horizon: int
property plan_len
receding_horizon: int
warm_start: bool = True
class RandomPolicy(seed=None, **kwargs)

Bases: BasePolicy

Random Policy.

get_action(obs, **kwargs)

Get action from the policy given the observation.

set_seed(seed)

class Transformable(*args, **kwargs)

Bases: Protocol

Protocol for input transformation.

inverse_transform() → Tensor

Revert the pre-processing transform.

transform() → Tensor

Apply the pre-processing transform.

class WorldModelPolicy(solver: Solver, config: PlanConfig, process: dict[str, Transformable] | None = None, transform: dict[str, callable] | None = None, **kwargs)

Bases: BasePolicy

World Model Policy using a planning solver.

property flatten_receding_horizon
get_action(info_dict, **kwargs)

Get action from the policy given the observation.

set_env(env)

stable_worldmodel.spaces module

Extended Gymnasium spaces with state tracking and constraint support.

class Box(*args, init_value=None, constrain_fn=None, **kwargs)

Bases: Box

Extended continuous box space with state tracking and constraint support.

This class extends gymnasium.spaces.Box to add state management and optional constraint validation. It represents bounded continuous values with configurable shape, dtype, and custom constraints.

Variables:
  • init_value (np.ndarray) – The initial value for the space.

  • value (np.ndarray) – The current value of the space.

  • constrain_fn (callable) – Optional function that returns True if a value satisfies custom constraints beyond the box boundaries.

Example

Create a 2D position space constrained to a circle:

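import numpy as np

from stable_worldmodel.spaces import Box


def in_circle(pos):
    return np.linalg.norm(pos) <= 1.0


space = Box(
    low=np.array([-1.0, -1.0], dtype=np.float32),
    high=np.array([1.0, 1.0], dtype=np.float32),
    init_value=np.array([0.0, 0.0], dtype=np.float32),  # match Box's default float32 dtype
    constrain_fn=in_circle,
)
position = space.sample()  # Rejection-samples until the point lies in the unit circle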

Note

The constraint function enables complex geometric or relational constraints beyond simple box boundaries.

__init__(*args, init_value=None, constrain_fn=None, **kwargs)

Initialize a Box space with state tracking.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.Box.

  • init_value (np.ndarray, optional) – Initial value for the space. Must match the shape and dtype of the box. Defaults to None.

  • constrain_fn (callable, optional) – Function that takes a numpy array and returns True if the value satisfies custom constraints beyond the box boundaries. Defaults to None.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.Box.

check()

Validate the current space value.

Checks if the current value is within the box bounds and satisfies the constraint function. Logs a warning if the constraint fails.

Returns:

True if the current value is valid, False otherwise.

Return type:

bool

contains(x)

Check if value is valid and satisfies constraints.

Parameters:

x (np.ndarray) – The value to check.

Returns:

True if x is within the box bounds and satisfies the constraint function, False otherwise.

Return type:

bool

property init_value

The initial value of the space, restored by reset().

Type:

np.ndarray

reset()

Reset the space value to its initial value.

Sets the current value back to the init_value specified during initialization.

sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)

Sample a random value using rejection sampling for constraints.

Repeatedly samples values until one satisfies the constraint function or max_tries is reached. Optionally updates the space’s current value.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.Box.sample().

  • max_tries (int, optional) – Maximum number of sampling attempts before raising an error. Defaults to 1000.

  • warn_after_s (float, optional) – Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0.

  • set_value (bool, optional) – Whether to update the space’s current value with the sampled value. Defaults to True.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.Box.sample().

Returns:

A sampled array that satisfies the constraint function.

Return type:

np.ndarray

Raises:

RuntimeError – If no valid sample is found after max_tries attempts.
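The rejection-sampling loop described above can be sketched without Gymnasium. This is an illustration of the technique, not the class's code: rejection_sample is a hypothetical helper, draw() stands in for gymnasium.spaces.Box.sample(), and the slow-sampling notice uses the warnings module where the real method logs a warning.

```python
import random
import time
import warnings


def rejection_sample(draw, constrain_fn=None, max_tries=1000, warn_after_s=5.0):
    """Call draw() until constrain_fn accepts a candidate or max_tries is hit."""
    start = time.monotonic()
    warned = False
    for _ in range(max_tries):
        candidate = draw()
        if constrain_fn is None or constrain_fn(candidate):
            return candidate
        slow = warn_after_s is not None and time.monotonic() - start > warn_after_s
        if slow and not warned:
            warnings.warn(f"sampling has taken more than {warn_after_s}s")
            warned = True  # warn only once per call
    raise RuntimeError(f"no valid sample found after {max_tries} tries")


rng = random.Random(0)


def draw_in_square():
    # Uniform draw from the bounding box [-1, 1] x [-1, 1]
    return (rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0))


def in_circle(p):
    return p[0] ** 2 + p[1] ** 2 <= 1.0


point = rejection_sample(draw_in_square, in_circle)
```

Each attempt succeeds with probability equal to the fraction of the box the constraint accepts (π/4 ≈ 0.785 for the unit circle in its bounding square), so the expected number of attempts is the reciprocal of that fraction, about 1.27 here. Constraints that accept a tiny fraction of the box are what trigger the slow-sampling warning and max_tries.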

property value

The current value of the space.

Type:

np.ndarray

class Dict(*args, init_value=None, constrain_fn=None, sampling_order=None, **kwargs)

Bases: Dict

Extended dictionary space with ordered sampling and nested support.

This class extends gymnasium.spaces.Dict to add state management, constraint validation, and explicit sampling order control. It composes multiple spaces into a hierarchical structure where dependencies between variables can be handled through ordered sampling.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.Dict.

  • init_value (dict, optional) – Initial values for the space. If None, derived from init_value of contained spaces.

  • constrain_fn (callable, optional) – Function that returns True if the complete dictionary satisfies custom constraints.

  • sampling_order (list, optional) – Explicit order for sampling keys. If None, uses insertion order. Missing keys are appended.

  • **kwargs – Additional keyword arguments passed to gymnasium.spaces.Dict.

Variables:
  • init_value (dict) – Initial values for all contained spaces.

  • value (dict) – Current values of all contained spaces.

  • constrain_fn (callable) – Constraint validation function.

  • sampling_order (set) – Set of dotted paths for all variables in order.

Example

Create a nested space with sampling order dependencies:

from stable_worldmodel import spaces
import numpy as np

config = spaces.Dict(
    {
        "difficulty": spaces.Discrete(n=3, init_value=0),
        "world": spaces.Dict(
            {
                "width": spaces.Discrete(n=100, init_value=50),
                "height": spaces.Discrete(n=100, init_value=50),
            }
        ),
        "player_pos": spaces.Box(
            low=np.array([0, 0]),
            high=np.array([99, 99]),
            init_value=np.array([25, 25]),
            dtype=np.int64,  # integer grid positions; matches init_value's dtype
        ),
    },
    sampling_order=["difficulty", "world", "player_pos"],
)

# Sample respects order
state = config.sample()

Note

Sampling order is crucial when variables have dependencies. For example, sample world size before sampling positions within it. Nested Dict spaces recursively apply their own sampling orders.

Accessing values in constraint functions: When implementing constrain_fn for Dict spaces, always use self.value['key']['key2'] instead of self['key']['key2'].value. The .value property recursively builds the complete value dictionary from the top level down, ensuring all nested values are up-to-date and correctly structured. Direct subspace access with .value only retrieves that specific subspace’s value without the full context.

Note that direct subspace access (e.g., self['key'].value) is perfectly fine for regular operations outside of constraint functions, such as reading individual subspace values or debugging. The recommendation to use top-level .value applies specifically to constraint functions where you need the complete, consistent state of all nested spaces.

Example of proper constraint function usage:

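import numpy as np

from stable_worldmodel import spaces


class Environment:
    def __init__(self):
        # A reduced version of the space from the previous example
        self.config_space = spaces.Dict(
            {
                "world": spaces.Dict({"width": spaces.Discrete(n=100, init_value=50)}),
                "player_pos": spaces.Box(
                    low=np.array([0, 0]),
                    high=np.array([99, 99]),
                    init_value=np.array([25, 25]),
                    dtype=np.int64,
                ),
            }
        )

    def validate_config(self):
        # ✓ CORRECT: Access via .value at top level
        values = self.config_space.value
        return values["player_pos"][0] < values["world"]["width"]

    def validate_wrong(self):
        # ✗ AVOID: Direct subspace access
        return (
            self.config_space["player_pos"].value[0]
            < self.config_space["world"]["width"].value
        )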

__init__(*args, init_value=None, constrain_fn=None, sampling_order=None, **kwargs)

Initialize a Dict space with state tracking and sampling order.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.Dict.

  • init_value (dict, optional) – Initial values for the space. If None, derived from init_value of contained spaces. Defaults to None.

  • constrain_fn (callable, optional) – Function that takes a dict and returns True if the complete dictionary satisfies custom constraints. Defaults to None.

  • sampling_order (list, optional) – Explicit order for sampling keys. If None, uses insertion order. Missing keys are appended with a warning. Defaults to None.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.Dict.

Raises:

ValueError – If sampling_order contains keys not present in spaces.
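The sampling_order rules above (None means insertion order, unknown keys raise ValueError, omitted keys are appended with a warning) can be sketched on plain dictionaries. resolve_sampling_order and the samplers mapping are hypothetical; in particular, the samplers here receive the values drawn earlier in the same pass purely to make the dependency visible, which is not how the real Dict passes information between subspaces.

```python
import random
import warnings


def resolve_sampling_order(keys, sampling_order=None):
    """Return keys in sampling order, appending any keys the order omits."""
    keys = list(keys)
    if sampling_order is None:
        return keys  # fall back to insertion order
    unknown = [k for k in sampling_order if k not in keys]
    if unknown:
        raise ValueError(f"sampling_order contains unknown keys: {unknown}")
    missing = [k for k in keys if k not in sampling_order]
    if missing:
        warnings.warn(f"appending keys missing from sampling_order: {missing}")
    return list(sampling_order) + missing


# Hypothetical samplers that can read values drawn earlier in the same pass
samplers = {
    "player_x": lambda state, rng: rng.randrange(state["world_width"]),
    "difficulty": lambda state, rng: rng.randrange(3),
    "world_width": lambda state, rng: rng.randrange(10, 100),
}

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # "player_x" is appended with a warning
    order = resolve_sampling_order(samplers, ["difficulty", "world_width"])

rng = random.Random(0)
state = {}
for key in order:
    state[key] = samplers[key](state, rng)
```

With plain insertion order, player_x would be drawn before world_width exists; the explicit order is what makes the dependent draw valid.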

check(debug=False)

Validate all contained spaces’ current values.

Checks each contained space using its check() method if available, or falls back to contains(value). Optionally logs warnings for failed checks.

Parameters:

debug (bool, optional) – If True, logs warnings for spaces that fail validation. Defaults to False.

Returns:

True if all contained spaces have valid values, False otherwise.

Return type:

bool

contains(x) → bool

Check if value is a valid member of this space.

Validates that x is a dictionary containing all required keys with values that satisfy each subspace’s constraints and the overall constraint function.

Parameters:

x – The value to check.

Returns:

True if x is a valid dict with all keys present and all values within subspace bounds, and x satisfies the constraint function; False otherwise.

Return type:

bool

property init_value

Initial values for all contained spaces.

Constructs initial value dictionary from contained spaces’ init_value properties. Falls back to sampling if a space lacks init_value.

Returns:

Dictionary mapping space keys to their initial values.

Return type:

dict

Type:

dict

names()

Return all space keys including nested ones.

Returns:

A list of all keys in the Dict space, with nested keys using dot notation.

For example, a nested dict with key “a” containing subspace “b” would produce “a.b”.

Return type:

list
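The dot-notation naming above can be sketched as a recursive walk over nested dictionaries. Plain dicts stand in for Dict spaces here, and reporting only leaf paths (not the intermediate key "world" itself) is an assumption about names(); dotted_names is a hypothetical helper.

```python
def dotted_names(tree, prefix=""):
    """Flatten nested dict keys into dotted leaf paths, in insertion order."""
    names = []
    for key, sub in tree.items():
        path = f"{prefix}{key}"
        if isinstance(sub, dict):
            names.extend(dotted_names(sub, prefix=f"{path}."))
        else:
            names.append(path)
    return names


layout = {
    "difficulty": "Discrete(3)",
    "world": {"width": "Discrete(100)", "height": "Discrete(100)"},
    "player_pos": "Box(2,)",
}
names = dotted_names(layout)
```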

reset()

Reset all contained spaces to their initial values.

Calls reset() on all contained spaces that have a reset method, then sets this space’s value to init_value.

sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)

Sample a random element from the Dict space.

Samples each subspace in the sampling order and ensures the result satisfies any constraint functions. Uses rejection sampling if constraints are present.

Parameters:
  • *args – Positional arguments passed to each subspace’s sample method.

  • max_tries (int, optional) – Maximum number of rejection sampling attempts. Defaults to 1000.

  • warn_after_s (float, optional) – Issue a warning if sampling takes longer than this many seconds. Set to None to disable warnings. Defaults to 5.0.

  • set_value (bool, optional) – Whether to set the internal value to the sampled value. Defaults to True.

  • **kwargs – Additional keyword arguments passed to each subspace’s sample method.

Returns:

A dictionary with keys matching the space definition and values sampled from their respective subspaces.

Return type:

dict

Raises:

RuntimeError – If a valid sample is not found within max_tries attempts.

property sampling_order

Set of dotted paths for all variables in sampling order.

Returns:

Set of strings representing dotted paths (e.g., ‘parent.child.key’) for all variables, including those in nested Dict spaces.

Return type:

set

Type:

set

update(keys)

Update specific keys in the Dict space by resampling them.

Samples new values for the specified keys while maintaining the sampling order. Uses dot notation for nested keys (e.g., “a.b” for nested dict).

Parameters:

keys (container) – A container (list, set, etc.) of key names to resample. Keys should use dot notation for nested spaces.

Raises:
  • ValueError – If a specified key is not found in the Dict space.

  • AssertionError – If the updated values violate the space constraints.
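Resampling a subset of dotted keys while keeping the global sampling order can be sketched on nested dictionaries. This is a hypothetical helper, not the method itself: deterministic lambdas stand in for subspace sampling, the order list is supplied explicitly, and the constraint check that raises AssertionError in the real method is omitted.

```python
def update(values, samplers, order, keys):
    """Resample only the dotted keys, visiting them in the given order."""
    unknown = set(keys) - set(order)
    if unknown:
        raise ValueError(f"keys not found: {sorted(unknown)}")
    for dotted in order:
        if dotted not in keys:
            continue  # untouched keys keep their current values
        *parents, leaf = dotted.split(".")
        vals, samp = values, samplers
        for part in parents:  # walk down to the innermost dict
            vals, samp = vals[part], samp[part]
        vals[leaf] = samp[leaf]()
    return values


values = {"difficulty": 0, "world": {"width": 50, "height": 50}}
samplers = {
    "difficulty": lambda: 2,
    "world": {"width": lambda: 80, "height": lambda: 30},
}
order = ["difficulty", "world.width", "world.height"]
update(values, samplers, order, {"world.width"})
```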

property value

Current values of all contained spaces.

Constructs value dictionary from contained spaces’ value properties.

Returns:

Dictionary mapping space keys to their current values.

Return type:

dict

Raises:

ValueError – If a contained space does not have a value property.

Type:

dict

class Discrete(*args, init_value=None, constrain_fn=None, **kwargs)

Bases: Discrete

Extended discrete space with state tracking and constraint support.

This class extends gymnasium.spaces.Discrete to add state management and optional constraint validation. Unlike the standard discrete space, this version maintains a current value and supports rejection sampling via a custom constraint function.

Variables:
  • init_value (int) – The initial value for the space.

  • value (int) – The current value of the space.

  • constrain_fn (callable) – Optional function that returns True if a value satisfies custom constraints.

Example

Create a discrete space that only accepts even numbers:

from stable_worldmodel.spaces import Discrete

space = Discrete(n=10, init_value=0, constrain_fn=lambda x: x % 2 == 0)
value = space.sample()  # Samples an even number and updates space.value
space.reset()  # Resets space.value back to 0 (init_value)

Note

The sample() method uses rejection sampling when a constraint function is provided, which may impact performance for difficult constraints.

__init__(*args, init_value=None, constrain_fn=None, **kwargs)

Initialize a Discrete space with state tracking.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.Discrete.

  • init_value (int, optional) – Initial value for the space. Defaults to None.

  • constrain_fn (callable, optional) – Function that takes an int and returns True if the value satisfies custom constraints. Defaults to None.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.Discrete.

check()

Validate the current space value.

Checks if the current value is within the space bounds and satisfies the constraint function. Logs a warning if the constraint fails.

Returns:

True if the current value is valid, False otherwise.

Return type:

bool

contains(x)

Check if value is valid and satisfies constraints.

Parameters:

x (int) – The value to check.

Returns:

True if x is within bounds and satisfies the constraint function, False otherwise.

Return type:

bool

property init_value

The initial value of the space, restored by reset().

Type:

int

reset()

Reset the space value to its initial value.

Sets the current value back to the init_value specified during initialization.

sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)

Sample a random value using rejection sampling for constraints.

Repeatedly samples values until one satisfies the constraint function or max_tries is reached. Optionally updates the space’s current value.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.Discrete.sample().

  • max_tries (int, optional) – Maximum number of sampling attempts before raising an error. Defaults to 1000.

  • warn_after_s (float, optional) – Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0.

  • set_value (bool, optional) – Whether to update the space’s current value with the sampled value. Defaults to True.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.Discrete.sample().

Returns:

A sampled value that satisfies the constraint function.

Return type:

int

Raises:

RuntimeError – If no valid sample is found after max_tries attempts.

property value

The current value of the space.

Type:

int

class MultiDiscrete(*args, init_value=None, constrain_fn=None, **kwargs)

Bases: MultiDiscrete

Extended multi-discrete space with state tracking and constraint support.

This class extends gymnasium.spaces.MultiDiscrete to add state management and optional constraint validation. It represents multiple discrete variables with potentially different ranges (nvec), where each variable maintains its own value and can be constrained.

Variables:
  • init_value (np.ndarray) – The initial values for all discrete variables.

  • value (np.ndarray) – The current values of all discrete variables.

  • constrain_fn (callable) – Optional function that returns True if the entire value array satisfies custom constraints.

Example

Create a multi-discrete space for game difficulty settings:

import numpy as np

from stable_worldmodel.spaces import MultiDiscrete

space = MultiDiscrete(
    nvec=[5, 3, 10],  # [enemy_count, speed_level, spawn_rate]
    init_value=np.array([2, 1, 5]),
)
settings = space.sample()  # Random difficulty configuration
space.reset()  # Resets to [2, 1, 5] (medium difficulty)

Note

Constraints are applied to the entire array, not individual elements. Use a constraint function that validates the complete state.

__init__(*args, init_value=None, constrain_fn=None, **kwargs)

Initialize a MultiDiscrete space with state tracking.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.MultiDiscrete.

  • init_value (np.ndarray, optional) – Initial values for the space. Must match the shape defined by nvec. Defaults to None.

  • constrain_fn (callable, optional) – Function that takes a numpy array and returns True if the values satisfy custom constraints. Defaults to None.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.MultiDiscrete.

check()

Validate the current space values.

Checks if the current values are within the space bounds and satisfy the constraint function. Logs a warning if the constraint fails.

Returns:

True if the current values are valid, False otherwise.

Return type:

bool

contains(x)

Check if values are valid and satisfy constraints.

Parameters:

x (np.ndarray) – The array of values to check.

Returns:

True if every element of x is within bounds and x satisfies the constraint function, False otherwise.

Return type:

bool

property init_value

The initial values of the space, restored by reset().

Type:

np.ndarray

reset()

Reset the space values to their initial values.

Sets the current values back to the init_value specified during initialization.

sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)

Sample random values using rejection sampling for constraints.

Repeatedly samples value arrays until one satisfies the constraint function or max_tries is reached. Optionally updates the space’s current values.

Parameters:
  • *args – Positional arguments passed to gymnasium.spaces.MultiDiscrete.sample().

  • max_tries (int, optional) – Maximum number of sampling attempts before raising an error. Defaults to 1000.

  • warn_after_s (float, optional) – Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0.

  • set_value (bool, optional) – Whether to update the space’s current values with the sampled values. Defaults to True.

  • **kwargs – Keyword arguments passed to gymnasium.spaces.MultiDiscrete.sample().

Returns:

A sampled array that satisfies the constraint function.

Return type:

np.ndarray

Raises:

RuntimeError – If no valid sample is found after max_tries attempts.
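The rejection-sampling loop described above can be sketched in plain NumPy. This is an illustration, not the library's implementation: sample_with_constraint is a hypothetical helper, the per-element ranges come from an nvec array (as in gymnasium.spaces.MultiDiscrete), the slow-sampling warning goes through the standard logging module, and the set_value behaviour is omitted.

```python
import logging
import time

import numpy as np

logger = logging.getLogger(__name__)


def sample_with_constraint(nvec, constrain_fn=None, rng=None,
                           max_tries=1000, warn_after_s=5.0):
    """Hypothetical helper: draw arrays with entries in [0, nvec) until
    constrain_fn accepts one (rejection sampling)."""
    rng = np.random.default_rng() if rng is None else rng
    nvec = np.asarray(nvec)
    start = time.monotonic()
    warned = False
    for _ in range(max_tries):
        # One candidate: element i is uniform over {0, ..., nvec[i] - 1}.
        candidate = rng.integers(0, nvec)
        if constrain_fn is None or constrain_fn(candidate):
            return candidate
        # Warn once if sampling takes longer than the threshold.
        if (warn_after_s is not None and not warned
                and time.monotonic() - start > warn_after_s):
            logger.warning("Constrained sampling is slow (> %.1fs)", warn_after_s)
            warned = True
    raise RuntimeError(f"No valid sample found after {max_tries} attempts")
```

For example, sample_with_constraint([3, 4, 6], lambda x: x.sum() <= 5) returns an array whose elements sum to at most 5. A constraint that rejects most candidates makes the loop slow, which is what the warn_after_s threshold is meant to surface.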

property value

The current values of the space.

Type:

np.ndarray

class RGBBox(shape=(3,), *args, init_value=None, **kwargs)[source]

Bases: Box

Specialized box space for RGB image data with automatic constraints.

This class extends Box to provide a convenient space for RGB images, automatically enforcing uint8 dtype and [0, 255] value ranges. It validates that the shape includes exactly 3 channels for RGB data.

Parameters:
  • shape (tuple) – Shape of the image. Must include a dimension of size 3 for the RGB channels. Common formats: (H, W, 3) or (3, H, W).

  • init_value (np.ndarray, optional) – Initial RGB image. Must match shape and be uint8 dtype.

  • *args – Additional positional arguments passed to Box.

  • **kwargs – Additional keyword arguments passed to Box.

Variables:
  • init_value (np.ndarray) – The initial RGB image.

  • value (np.ndarray) – The current RGB image.

Example

Create a space for 64x64 RGB images:

import numpy as np
from stable_worldmodel.spaces import RGBBox

space = RGBBox(
    shape=(64, 64, 3), init_value=np.zeros((64, 64, 3), dtype=np.uint8)
)
image = space.sample()  # Random RGB image
space.reset()  # Returns to black image
Raises:

AssertionError – If shape does not contain a dimension of size 3.

Note

This space is useful for vision-based environments where images need to be sampled or tracked as part of environment configuration. The low, high, and dtype parameters are automatically set and cannot be overridden.

stable_worldmodel.utils module

Utility functions for stable_worldmodel.

flatten_dict(d, parent_key='', sep='.')[source]

Flatten a nested dictionary into a single-level dictionary with concatenated keys.

The new keys follow a Hydra-like naming convention: the keys at each level of nesting are joined with the sep separator (. by default). Take care when flattened keys collide, as in the second example below: later entries overwrite earlier ones, so information is lost.

Parameters:
  • d (dict) – The nested dictionary to flatten.

  • parent_key (str, optional) – Prefix prepended to every flattened key. Defaults to an empty string.

  • sep (str, optional) – The separator to use between levels of nesting. Defaults to ‘.’.

Returns:

A flattened version of the input dictionary.

Return type:

dict

Examples

>>> info = {"a": {"b": {"c": 42, "d": 43}}, "e": 44}
>>> flatten_dict(info)
{'a.b.c': 42, 'a.b.d': 43, 'e': 44}
>>> flatten_dict({"a": {"b": 2}, "a.b": 3})
{'a.b': 3}
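A minimal recursive sketch of this flattening rule follows. flatten_dict_sketch is a hypothetical name, not the library function. It reproduces both examples above, including the collision in which the literal key "a.b" overwrites the flattened one.

```python
def flatten_dict_sketch(d, parent_key="", sep="."):
    """Recursively join nested keys with `sep` (illustrative re-implementation)."""
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into nested dicts, extending the key prefix.
            items.update(flatten_dict_sketch(value, new_key, sep))
        else:
            # Later entries with the same flattened key overwrite earlier ones.
            items[new_key] = value
    return items
```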
get_in(mapping: dict, path: Iterable[str]) Any[source]

Retrieve a value from a nested dictionary using a sequence of keys.

Parameters:
  • mapping (dict) – A nested dictionary.

  • path (Iterable[str]) – An iterable of keys representing the path to the desired value in mapping.

Returns:

The value located at the specified path in the nested dictionary.

Return type:

Any

Raises:

KeyError – If any key in the path does not exist in the mapping dict.

Examples

>>> variations = {"a": {"b": {"c": 42}}}
>>> get_in(variations, ["a", "b", "c"])
42
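Conceptually, get_in is a left fold of indexing over the path. The following standard-library sketch uses a hypothetical name, get_in_sketch, and raises the same KeyError when a key is missing:

```python
from functools import reduce
from operator import getitem
from typing import Any, Iterable


def get_in_sketch(mapping: dict, path: Iterable[str]) -> Any:
    """Follow `path` one key at a time: mapping[p0][p1]...[pn]."""
    # reduce applies getitem(acc, key) for each key, starting from mapping.
    return reduce(getitem, path, mapping)
```

In this sketch an empty path returns the mapping itself.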
pretraining(script_path: str, dataset_name: str, output_model_name: str, dump_object: bool = True, args: str = '') int[source]

Run a pretraining script as a subprocess with optional command-line arguments.

This function checks if the specified script exists, constructs a command to run it with the provided arguments, and executes the command in a subprocess.

Parameters:
  • script_path (str) – The path to the pretraining script to be executed.

  • dataset_name (str) – The name of the dataset to be used in pretraining.

  • output_model_name (str) – The name to save the output model.

  • dump_object (bool, optional) – Whether to dump the model object after training. Defaults to True.

  • args (str, optional) – A string of command-line arguments to pass to the script. Defaults to an empty string.

Returns:

The return code of the subprocess. A return code of 0 indicates success.

Return type:

int

Raises:
  • ValueError – If the specified script does not exist.

  • SystemExit – If the subprocess exits with a non-zero return code.
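The check-then-run pattern described above can be sketched with subprocess. The sketch is illustrative only: run_script is a hypothetical helper that launches the script with the current interpreter (sys.executable) and splits args with shlex.split. It does not show how the real function forwards dataset_name, output_model_name, or dump_object to the script.

```python
import os
import shlex
import subprocess
import sys


def run_script(script_path: str, args: str = "") -> int:
    """Hypothetical helper mirroring the documented contract."""
    if not os.path.isfile(script_path):
        raise ValueError(f"Script not found: {script_path}")
    # Build an argv list; shlex.split honours shell-style quoting in `args`.
    cmd = [sys.executable, script_path, *shlex.split(args)]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        # Propagate failure as SystemExit, as the Raises section describes.
        sys.exit(result.returncode)
    return result.returncode
```

Passing an argv list instead of a single shell string avoids shell injection through args.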

stable_worldmodel.world module

World environment manager.

This module provides the World class, a high-level manager for vectorized Gymnasium environments with integrated support for data collection, policy evaluation, video recording, and dataset management. It serves as the central orchestration layer for training and evaluating world models with domain randomization support.

The World class handles:
  • Vectorized environment creation and management

  • Policy attachment and execution

  • Episode data collection with automatic sharding

  • Video recording from live rollouts or stored datasets

  • Policy evaluation with comprehensive metrics

  • Visual domain randomization through variation spaces

Example

Basic usage for policy evaluation:

from stable_worldmodel import World
from stable_worldmodel.policy import RandomPolicy

# Create a world with 4 parallel environments
world = World(
    env_name="PushT-v1",
    num_envs=4,
    image_shape=(96, 96),
    max_episode_steps=200
)

# Attach a policy
policy = RandomPolicy()
world.set_policy(policy)

# Evaluate the policy
metrics = world.evaluate(episodes=100, seed=42)
print(f"Success rate: {metrics['success_rate']:.2f}%")

Data collection example:

# Collect demonstration data
world.record_dataset(
    dataset_name="pusht_demos",
    episodes=1000,
    seed=42,
    options={'variation': ['all']}
)
class World(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]

Bases: object

High-level manager for vectorized Gymnasium environments with integrated data collection.

World orchestrates multiple parallel environments, providing a unified interface for policy execution, data collection, evaluation, and visualization. It automatically handles environment resets, batched action execution, and comprehensive episode tracking with support for visual domain randomization.

The World class is the central component of the stable-worldmodel library, designed to streamline the workflow from data collection to policy training and evaluation. It wraps Gymnasium’s vectorized environments with additional functionality for world model research.

Variables:
  • envs (VectorEnv) – Vectorized Gymnasium environment wrapped with MegaWrapper and VariationWrapper for image processing and domain randomization.

  • seed (int) – Base random seed for reproducibility.

  • policy (Policy) – Currently attached policy that generates actions.

  • states (ndarray) – Current observation states for all environments.

  • rewards (ndarray) – Current rewards for all environments.

  • terminateds (ndarray) – Terminal flags indicating episode completion.

  • truncateds (ndarray) – Truncation flags indicating episode timeout.

  • infos (dict) – Auxiliary information returned by the environments, batched across all instances.

Properties:
  • num_envs (int) – Number of parallel environments.

  • observation_space (Space) – Batched observation space.

  • action_space (Space) – Batched action space.

  • variation_space (Dict) – Variation space for domain randomization (batched).

  • single_variation_space (Dict) – Variation space for a single environment.

  • single_action_space (Space) – Action space for a single environment.

  • single_observation_space (Space) – Observation space for a single environment.

Note

Environments use DISABLED autoreset mode to enable custom reset behavior with seed and options support, which is not provided by Gymnasium’s default autoreset mechanism.

__init__(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]

Initialize the World with vectorized environments.

Creates and configures a vectorized environment with the specified number of parallel instances. Applies image and goal transformations, sets up variation support, and configures autoreset behavior.

Parameters:
  • env_name (str) – Name of the Gymnasium environment to create. Must be a registered environment ID (e.g., ‘PushT-v0’, ‘CubeEnv-v0’).

  • num_envs (int) – Number of parallel environment instances to create. Higher values increase data collection throughput but require more memory.

  • image_shape (tuple) – Target shape for image observations as (height, width) or (height, width, channels). Images are resized to this shape.

  • goal_shape (tuple, optional) – Target shape for goal image observations. If None, goals are processed with the same shape as observations. Defaults to None.

  • goal_transform (Callable, optional) – Function to transform goal observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.

  • image_transform (Callable, optional) – Function to transform image observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.

  • seed (int, optional) – Base random seed for environment initialization. Each environment gets an offset seed. Defaults to 2349867.

  • max_episode_steps (int, optional) – Maximum number of steps per episode before truncation. Episodes terminate early on task success. Defaults to 100.

  • verbose (int, optional) – Verbosity level. 0 for silent, 1 for basic info, 2+ for detailed debugging information. Defaults to 1.

  • **kwargs – Additional keyword arguments passed to gym.make_vec() and subsequently to the underlying environment constructor.

Note

The MegaWrapper applies image transformations and resizing. The VariationWrapper enables domain randomization support. Autoreset is disabled to allow custom reset with seeds and options.

Example

Create a world with goal-conditioned observations:

world = World(
    env_name="PushT-v1",
    num_envs=8,
    image_shape=(96, 96),
    goal_shape=(64, 64),
    max_episode_steps=150,
    seed=42
)
property action_space

Batched action space for all environments.

Type:

Space

close(**kwargs)[source]

Close all environments and clean up resources.

Parameters:

**kwargs – Additional keyword arguments passed to the underlying vectorized environment’s close method.

Returns:

Return value from the underlying close method.

Return type:

Any

evaluate(episodes=10, eval_keys=None, seed=None, options=None)[source]

Evaluate the current policy over multiple episodes and return comprehensive metrics.

Runs the attached policy for a specified number of episodes, tracking success rates and optionally other metrics from environment info. Handles episode boundaries and ensures reproducibility through seeding.

Parameters:
  • episodes (int, optional) – Total number of episodes to evaluate. More episodes provide more statistically reliable metrics. Defaults to 10.

  • eval_keys (list of str, optional) – Additional info keys to track across episodes. Must be keys present in self.infos (e.g., ‘reward_total’, ‘steps_to_success’). Defaults to None (track only success rate).

  • seed (int, optional) – Base random seed for reproducible evaluation. Each episode gets an incremental offset. Defaults to None (non-deterministic).

  • options (dict, optional) – Reset options passed to environments (e.g., task selection, variation settings). Defaults to None.

Returns:

Dictionary containing evaluation metrics:
  • ‘success_rate’ (float): Percentage of successful episodes (0-100)

  • ‘episode_successes’ (ndarray): Boolean array of episode outcomes

  • ‘seeds’ (ndarray): Random seeds used for each episode

  • Additional keys from eval_keys if specified (ndarray)

Return type:

dict

Raises:

AssertionError – If a key in eval_keys is not found in infos, if the number of completed episodes does not match episodes, or if duplicate seeds are detected.

Note

  • Success is determined by the ‘terminateds’ flag (True = success)

  • ‘truncateds’ flag (timeout) is treated as failure

  • Seeds are validated for uniqueness to ensure independent episodes

  • Environments are manually reset with seeds (autoreset is bypassed)

  • All environments are evaluated in parallel for efficiency

Example

Basic evaluation:

metrics = world.evaluate(episodes=100, seed=42)
print(f"Success: {metrics['success_rate']:.1f}%")

Evaluate with additional metrics:

metrics = world.evaluate(
    episodes=100,
    eval_keys=['reward_total', 'episode_length'],
    seed=42,
    options={'task_id': 3}
)
print(f"Success: {metrics['success_rate']:.1f}%")
print(f"Avg reward: {metrics['reward_total'].mean():.2f}")

Evaluate across different variations:

for var_type in ['none', 'color', 'light', 'all']:
    options = {'variation': [var_type]} if var_type != 'none' else None
    metrics = world.evaluate(episodes=50, seed=42, options=options)
    print(f"{var_type}: {metrics['success_rate']:.1f}%")
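The metric bookkeeping described in the Returns and Note sections can be sketched with NumPy. summarize_episodes is a hypothetical helper that assumes per-episode seeds are the base seed plus the episode index. success_rate is the percentage of episodes that ended with terminated=True, and truncation counts as failure. The rollout loop that the real method runs is omitted.

```python
import numpy as np


def summarize_episodes(terminated, base_seed, extra=None):
    """Hypothetical helper: aggregate per-episode outcomes into a metrics dict."""
    successes = np.asarray(terminated, dtype=bool)
    # Assumed seeding scheme: base seed plus an incremental per-episode offset.
    seeds = base_seed + np.arange(len(successes))
    metrics = {
        "success_rate": 100.0 * successes.mean(),  # percentage, 0-100
        "episode_successes": successes,
        "seeds": seeds,
    }
    # Extra per-episode values (e.g. from eval_keys) are stored as arrays.
    for key, values in (extra or {}).items():
        metrics[key] = np.asarray(values)
    return metrics
```

For example, summarize_episodes([True, False, True, True], base_seed=42) reports a success_rate of 75.0 and seeds 42 through 45.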
property num_envs

Number of parallel environment instances.

Type:

int

property observation_space

Batched observation space for all environments.

Type:

Space

record_dataset(dataset_name, episodes=10, seed=None, cache_dir=None, options=None)[source]

Collect episodes with the current policy and save as a HuggingFace Dataset.

Executes the attached policy to collect demonstration or rollout data, automatically managing episode boundaries and saving all observations, actions, rewards, and auxiliary information as a sharded Arrow dataset. Images are stored as JPEG files, and the dataset records their paths.

The dataset is organized with the following structure:
  • dataset_name/
    • img/
      • {episode_idx}/
        • {step_idx}_{column_name}.jpeg

    • data-{shard}.arrow (Arrow shards)

    • dataset_info.json

    • state.json

Parameters:
  • dataset_name (str) – Name of the dataset. Used as subdirectory name in cache_dir. Should be descriptive (e.g., ‘pusht_expert_demos’).

  • episodes (int, optional) – Total number of complete episodes to collect. Incomplete episodes at the end are discarded. Defaults to 10.

  • seed (int, optional) – Base random seed for reproducibility. Each episode gets an incremental offset. Defaults to None (non-deterministic).

  • cache_dir (str or Path, optional) – Root directory for dataset storage. If None, uses swm.data.get_cache_dir(). Defaults to None.

  • options (dict, optional) – Reset options for environments. Use {‘variation’: [‘all’]} for full domain randomization. Defaults to None.

Dataset Schema:
Each row contains:
  • episode_idx (int32): Episode identifier

  • step_idx (int32): Step within episode (0-indexed)

  • episode_len (int32): Total length of the episode

  • policy (string): Policy type identifier

  • pixels (string): Relative path to observation image

  • goal (string, optional): Relative path to goal image

  • action (float array): Action taken at this step

  • reward (float): Reward received

  • Additional keys from environment infos

Note

  • Actions are shifted: action at step t leads to observation at step t+1

  • Last action in each episode is NaN (no action leads from final state)

  • Images are saved as JPEG for efficiency (may introduce compression artifacts)

  • Dataset is automatically sharded: 1 shard per 50 episodes

  • Only complete episodes are included in the final dataset
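The action alignment and sharding rule in this note can be sketched with NumPy. The sketch assumes that, before the shift, row t holds the action that produced observation t, with row 0 set to NaN because no action precedes the reset observation (as EverythingToInfoWrapper records after reset). Shifting left by one yields the documented convention. shift_actions and num_shards are hypothetical helpers, and rounding the shard count up is an assumption.

```python
import math

import numpy as np


def shift_actions(actions):
    """Align actions so row t is the action taken *from* observation t.

    Assumes row t of `actions` holds the action that produced observation t.
    """
    actions = np.asarray(actions, dtype=float)
    shifted = np.full_like(actions, np.nan)
    # Last row stays NaN: no action leads out of the final state.
    shifted[:-1] = actions[1:]
    return shifted


def num_shards(episodes, episodes_per_shard=50):
    """One shard per 50 episodes, rounding up (assumed rounding)."""
    return max(1, math.ceil(episodes / episodes_per_shard))
```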

Raises:

AssertionError – If required keys (‘pixels’, ‘episode_idx’, ‘step_idx’, ‘episode_len’) are missing from recorded data.

Example

Collect demonstration data with variations:

world.set_policy(expert_policy)
world.record_dataset(
    dataset_name="cube_expert_1k",
    episodes=1000,
    seed=42,
    options={'variation': ['all']}
)

Collect data for specific tasks:

world.record_dataset(
    dataset_name="pusht_task2_data",
    episodes=500,
    options={'task_id': 2}
)
record_video(video_path, max_steps=500, fps=30, viewname='pixels', seed=None, options=None)[source]

Record rollout videos for each environment under the current policy.

Executes policy rollouts in all environments and saves them as MP4 videos, one per environment. Videos show stacked observation and goal images when goals are available.

Parameters:
  • video_path (str or Path) – Directory path where videos will be saved. Created if it doesn’t exist. Videos are named ‘env_0.mp4’, ‘env_1.mp4’, etc.

  • max_steps (int, optional) – Maximum number of steps to record per episode. Recording stops earlier if any environment terminates. Defaults to 500.

  • fps (int, optional) – Frames per second for output videos. Higher values produce smoother playback but larger files. Defaults to 30.

  • viewname (str, optional) – Info key of the image view to record. Defaults to ‘pixels’.

  • seed (int, optional) – Random seed for reproducible rollouts. Defaults to None.

  • options (dict, optional) – Reset options passed to environments (e.g., variation settings). Defaults to None.

Note

  • Videos use libx264 codec for compatibility

  • If ‘goal’ is present in infos, frames show observation stacked above goal

  • All environments are recorded simultaneously

  • Recording stops when ANY environment terminates or truncates
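The frame layout and stopping rule in this note can be sketched with NumPy. make_frame and should_stop are hypothetical helpers. Vertical stacking assumes the observation and goal images share width and channel count. If goal_shape differs from image_shape, a resize step would be needed first.

```python
import numpy as np


def make_frame(obs, goal=None):
    """Stack the observation above the goal image (both H x W x C uint8)."""
    if goal is None:
        return obs
    # Vertical stacking requires equal width and channel count.
    return np.concatenate([obs, goal], axis=0)


def should_stop(terminateds, truncateds):
    """Recording stops as soon as ANY environment terminates or truncates."""
    return bool(np.any(np.logical_or(terminateds, truncateds)))
```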

Example

Record evaluation videos:

world.set_policy(trained_policy)
world.record_video(
    video_path="./eval_videos",
    max_steps=200,
    fps=30,
    seed=42
)
record_video_from_dataset(video_path, dataset_name, episode_idx, max_steps=500, fps=30, num_proc=4, viewname: str | list[str] = 'pixels', cache_dir=None)[source]

Replay stored dataset episodes and export them as MP4 videos.

Loads episodes from a previously recorded dataset and renders them as videos, useful for visualization, debugging, and qualitative evaluation of collected data.

Parameters:
  • video_path (str or Path) – Directory where videos will be saved. Videos are named ‘episode_{idx}.mp4’. Directory is created if it doesn’t exist.

  • dataset_name (str) – Name of the dataset to load (must exist in cache_dir).

  • episode_idx (int or list of int) – Episode index or list of episode indices to render. Each episode is saved as a separate video file.

  • max_steps (int, optional) – Maximum number of steps to render per episode. Useful for limiting video length. Defaults to 500.

  • fps (int, optional) – Frames per second for output videos. Defaults to 30.

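  • num_proc (int, optional) – Number of processes for parallel dataset filtering. Higher values speed up loading for large datasets. Defaults to 4.

  • viewname (str or list of str, optional) – Image column, or list of columns, to render. Defaults to ‘pixels’.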

  • cache_dir (str or Path, optional) – Root directory where dataset is stored. If None, uses swm.data.get_cache_dir(). Defaults to None.

Raises:

AssertionError – If dataset doesn’t exist in cache_dir, or if episode length inconsistencies are detected in the data.

Note

  • Images are loaded from JPEG files stored in the dataset

  • If ‘goal’ column exists, observation and goal are stacked vertically

  • Videos use libx264 codec for broad compatibility

  • Episodes are validated for consistency (length matches metadata)
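The episode selection and consistency check can be sketched on plain arrays using the episode_idx, step_idx, and episode_len columns from the record_dataset schema. select_episode is a hypothetical helper. The real method filters a HuggingFace Dataset with num_proc workers and loads the JPEG files, which this sketch omits.

```python
import numpy as np


def select_episode(episode_idx, step_idx, episode_len, target, max_steps=500):
    """Return row indices of one episode, in step order, truncated to max_steps."""
    episode_idx = np.asarray(episode_idx)
    rows = np.flatnonzero(episode_idx == target)
    # Order the episode's rows by their step index.
    rows = rows[np.argsort(np.asarray(step_idx)[rows], kind="stable")]
    # Consistency check: row count must match the stored episode length.
    expected = np.asarray(episode_len)[rows]
    assert len(rows) > 0 and np.all(expected == len(rows)), "inconsistent episode length"
    return rows[:max_steps]
```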

Example

Render specific episodes from a dataset:

world.record_video_from_dataset(
    video_path="./visualizations",
    dataset_name="cube_expert_1k",
    episode_idx=[0, 5, 10, 99],
    fps=30
)

Render a single episode:

world.record_video_from_dataset(
    video_path="./debug",
    dataset_name="failed_episodes",
    episode_idx=42,
    max_steps=100
)
reset(seed=None, options=None)[source]

Reset all environments to initial states.

Parameters:
  • seed (int, optional) – Random seed for reproducible resets. If None, uses non-deterministic seeding. Defaults to None.

  • options (dict, optional) – Dictionary of reset options passed to environments. Common keys include ‘variation’ for domain randomization. Defaults to None.

Updates:
  • self.states: Initial observations from all environments

  • self.infos: Initial auxiliary information

Example

Reset with domain randomization:

world.reset(seed=42, options={'variation': ['all']})

Reset specific variations:

world.reset(options={'variation': ['cube.color', 'light.intensity']})
set_policy(policy)[source]

Attach a policy to the world and configure it with environment context.

The policy will be used for action generation during step(), record_video(), record_dataset(), and evaluate() calls.

Parameters:

policy (Policy) – Policy instance that implements get_action(infos) method. The policy receives environment context through set_env() and optional seeding through set_seed().

Note

If the policy has a ‘seed’ attribute, it will be applied via set_seed(). The policy’s set_env() method receives the wrapped vectorized environment.
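The policy contract described here (set_env, an optional seed attribute applied through set_seed, and get_action(infos)) can be sketched as a duck-typed class. UniformRandomPolicy is hypothetical and is not the library's RandomPolicy. It assumes a Box-like single action space that exposes low, high, and shape, and a vectorized environment that exposes num_envs and single_action_space.

```python
import numpy as np


class UniformRandomPolicy:
    """Hypothetical duck-typed policy: set_env, set_seed, get_action(infos)."""

    def __init__(self, seed=None):
        self.seed = seed  # World calls set_seed() when this attribute exists
        self.env = None
        self.rng = np.random.default_rng(seed)

    def set_env(self, env):
        # World passes its wrapped vectorized environment here.
        self.env = env

    def set_seed(self, seed):
        self.rng = np.random.default_rng(seed)

    def get_action(self, infos):
        space = self.env.single_action_space
        # One action per parallel environment; `infos` is unused by this policy.
        return self.rng.uniform(space.low, space.high,
                                size=(self.env.num_envs, *space.shape))
```

A learned policy would read observations such as infos["pixels"] in get_action instead of ignoring them.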

property single_action_space

Action space for a single environment instance.

Type:

Space

property single_observation_space

Observation space for a single environment instance.

Type:

Space

property single_variation_space

Variation space for a single environment instance.

Type:

Dict

step()[source]

Advance all environments by one step using the current policy.

Queries the attached policy for actions based on current info, executes those actions in all environments, and updates internal state with the results.

Updates:
  • self.states: New observations from all environments

  • self.rewards: Rewards received from the step

  • self.terminateds: Episode termination flags (task success)

  • self.truncateds: Episode truncation flags (timeout)

  • self.infos: Auxiliary information dictionaries

Note

Requires a policy to be attached via set_policy() before calling. The policy’s get_action() method receives the current infos dict.

Raises:

AttributeError – If no policy has been set via set_policy().

property variation_space

Batched variation space for domain randomization across all environments.

Type:

Dict

stable_worldmodel.wrappers module

class AddPixelsWrapper(env, pixels_shape: tuple[int, int] = (84, 84), torchvision_transform: Callable | None = None)[source]

Bases: Wrapper

Adds rendered environment pixels to info dict with optional resizing and transforms.

Supports single images, dictionaries of images (multiview), or lists of images. Uses PIL for resizing and optional torchvision transforms.

Parameters:
  • env – The Gymnasium environment to wrap.

  • pixels_shape – Target (height, width) for resized images. Defaults to (84, 84).

  • torchvision_transform – Optional transform to apply to PIL images.

Info Keys Added:
  • pixels: Rendered image (single view).

  • pixels.{key}: Individual images (multiview dict).

  • pixels.{idx}: Individual images (multiview list).

  • render_time: Time taken to render in seconds.
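The mapping from render() output to the info keys above can be sketched as follows. pixels_to_info is a hypothetical helper. It omits the PIL resizing, the torchvision transform, and render_time, and it does not claim whether the real wrapper also fills pixels in the multiview cases.

```python
def pixels_to_info(rendered):
    """Map a render() result to info keys.

    Single array -> "pixels"; dict of views -> "pixels.{key}";
    list of views -> "pixels.{idx}".
    """
    if isinstance(rendered, dict):
        return {f"pixels.{key}": img for key, img in rendered.items()}
    if isinstance(rendered, (list, tuple)):
        return {f"pixels.{idx}": img for idx, img in enumerate(rendered)}
    return {"pixels": rendered}
```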

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class EnsureGoalInfoWrapper(env, check_reset, check_step: bool = False)[source]

Bases: Wrapper

Validates that ‘goal’ key is present in info dict during reset and/or step.

Useful for goal-conditioned environments to ensure goal information is provided.

Parameters:
  • env – The Gymnasium environment to wrap.

  • check_reset – If True, validates ‘goal’ key is in info after reset().

  • check_step – If True, validates ‘goal’ key is in info after step().

Raises:

RuntimeError – If ‘goal’ key is missing when validation is enabled.

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class EnsureImageShape(env, image_key, image_shape)[source]

Bases: Wrapper

Validates that an image in the info dict has the expected spatial dimensions.

Parameters:
  • env – The Gymnasium environment to wrap.

  • image_key – Key in info dict containing the image to validate.

  • image_shape – Expected (height, width) tuple for the image.

Raises:

RuntimeError – If the image shape doesn’t match the expected dimensions.

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class EnsureInfoKeysWrapper(env, required_keys: Iterable[str])[source]

Bases: Wrapper

Validates that required keys are present in the info dict after reset and step.

Supports regex patterns for flexible key matching. Raises RuntimeError if any required pattern has no matching key.

Parameters:
  • env – The Gymnasium environment to wrap.

  • required_keys – Iterable of regex patterns as strings. Each pattern must match at least one key in the info dict.

Raises:

RuntimeError – If any required pattern has no matching key in info dict.
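The regex validation can be sketched with the standard re module. check_required_keys is a hypothetical helper, and re.match semantics are an assumption. For anchored patterns such as ^pixels(?:\..*)?$, which MegaWrapper always adds, re.match and re.fullmatch behave the same.

```python
import re
from typing import Iterable


def check_required_keys(info: dict, required_keys: Iterable[str]) -> None:
    """Raise RuntimeError if some pattern matches no key in `info`."""
    keys = list(info)
    for pattern in required_keys:
        regex = re.compile(pattern)
        if not any(regex.match(key) for key in keys):
            raise RuntimeError(f"No info key matches required pattern {pattern!r}")
```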

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class EverythingToInfoWrapper(env)[source]

Bases: Wrapper

Moves all transition information into the info dict for unified data access.

Adds observation, reward, terminated, truncated, action, and step_idx to info. Optionally tracks environment variations when specified in reset options.

Parameters:

env – The Gymnasium environment to wrap.

Info Keys Added:
  • observation (or dict keys if obs is dict): Current observation.

  • reward: Reward value (NaN after reset).

  • terminated: Episode termination flag.

  • truncated: Episode truncation flag.

  • action: Action taken (a NaN-filled action after reset).

  • step_idx: Current step counter.

  • variation.{key}: Variation values if requested via reset options.

Note

Pass options={"variation": ["key1", "key2"]} (or ["all"]) to reset() to track variations.

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class MegaWrapper(env, image_shape: tuple[int, int] = (84, 84), pixels_transform: Callable | None = None, goal_transform: Callable | None = None, required_keys: Iterable | None = None, separate_goal: Iterable | None = True)[source]

Bases: Wrapper

Combines multiple wrappers for comprehensive environment preprocessing.

Applies in sequence: AddPixelsWrapper → EverythingToInfoWrapper → EnsureInfoKeysWrapper → EnsureGoalInfoWrapper → ResizeGoalWrapper.

This provides a complete preprocessing pipeline with rendered pixels, unified info dict, key validation, goal checking, and goal resizing.

Parameters:
  • env – The Gymnasium environment to wrap.

  • image_shape – Target (height, width) for pixels and goal. Defaults to (84, 84).

  • pixels_transform – Optional torchvision transform for rendered pixels.

  • goal_transform – Optional torchvision transform for goal images.

  • required_keys – Additional regex patterns for keys that must be in info. Pattern ^pixels(?:\..*)?$ is always added.

  • separate_goal – If True, validates ‘goal’ is present in info. Defaults to True.

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class ResizeGoalWrapper(env, pixels_shape: tuple[int, int] = (84, 84), torchvision_transform: Callable | None = None)[source]

Bases: Wrapper

Resizes goal images in info dict with optional transforms.

Applies PIL-based resizing and optional torchvision transforms to the ‘goal’ image in info dict during both reset and step.

Parameters:
  • env – The Gymnasium environment to wrap.

  • pixels_shape – Target (height, width) for resized goal images. Defaults to (84, 84).

  • torchvision_transform – Optional transform to apply to PIL goal images.

reset(*args, **kwargs)[source]

Uses the reset() of the env that can be overwritten to change the returned data.

step(action)[source]

Uses the step() of the env that can be overwritten to change the returned data.

class VariationWrapper(env, variation_mode: str | Space = 'same')[source]

Bases: VectorWrapper

Manages variation spaces for vectorized environments.

Handles batching of variation spaces across multiple environments, supporting either shared variations (same) or independent variations (different).

Parameters:
  • env – The vectorized Gymnasium environment to wrap.

  • variation_mode – Mode for handling variations across environments:
    • “same”: All environments share the same variation space (batched).

    • “different”: Each environment has independent variation spaces.

Raises:

ValueError – If variation_mode is invalid or sub-environment spaces don’t match.

Note

Base environments are expected to expose a variation_space attribute. If it is missing, the variation spaces are set to None.

property envs

Module contents

class PlanConfig(horizon: int, receding_horizon: int, history_len: int = 1, action_block: int = 1, warm_start: bool = True)[source]

Bases: object

Configuration for the planning process.

action_block: int = 1
history_len: int = 1
horizon: int
property plan_len
receding_horizon: int
warm_start: bool = True
class World(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]

Bases: object

High-level manager for vectorized Gymnasium environments with integrated data collection.

World orchestrates multiple parallel environments, providing a unified interface for policy execution, data collection, evaluation, and visualization. It automatically handles environment resets, batched action execution, and comprehensive episode tracking with support for visual domain randomization.

The World class is the central component of the stable-worldmodel library, designed to streamline the workflow from data collection to policy training and evaluation. It wraps Gymnasium’s vectorized environments with additional functionality for world model research.

Variables:
  • envs (VectorEnv) – Vectorized Gymnasium environment wrapped with MegaWrapper and VariationWrapper for image processing and domain randomization.

  • seed (int) – Base random seed for reproducibility.

  • policy (Policy) – Currently attached policy that generates actions.

  • states (ndarray) – Current observation states for all environments.

  • rewards (ndarray) – Current rewards for all environments.

  • terminateds (ndarray) – Terminal flags indicating episode completion.

  • truncateds (ndarray) – Truncation flags indicating episode timeout.

  • infos (dict) – Auxiliary information returned by the environments, batched across all instances.

Properties:
  • num_envs (int) – Number of parallel environments.

  • observation_space (Space) – Batched observation space.

  • action_space (Space) – Batched action space.

  • variation_space (Dict) – Variation space for domain randomization (batched).

  • single_variation_space (Dict) – Variation space for a single environment.

  • single_action_space (Space) – Action space for a single environment.

  • single_observation_space (Space) – Observation space for a single environment.

Note

Environments use DISABLED autoreset mode to enable custom reset behavior with seed and options support, which is not provided by Gymnasium’s default autoreset mechanism.

__init__(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]

Initialize the World with vectorized environments.

Creates and configures a vectorized environment with the specified number of parallel instances. Applies image and goal transformations, sets up variation support, and configures autoreset behavior.

Parameters:
  • env_name (str) – Name of the Gymnasium environment to create. Must be a registered environment ID (e.g., ‘PushT-v0’, ‘CubeEnv-v0’).

  • num_envs (int) – Number of parallel environment instances to create. Higher values increase data collection throughput but require more memory.

  • image_shape (tuple) – Target shape for image observations as (height, width) or (height, width, channels). Images are resized to this shape.

  • goal_shape (tuple, optional) – Target shape for goal image observations. If None, goals are processed with the same shape as observations. Defaults to None.

  • goal_transform (Callable, optional) – Function to transform goal observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.

  • image_transform (Callable, optional) – Function to transform image observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.

  • seed (int, optional) – Base random seed for environment initialization. Each environment gets an offset seed. Defaults to 2349867.

  • max_episode_steps (int, optional) – Maximum number of steps per episode before truncation. Episodes terminate early on task success. Defaults to 100.

  • verbose (int, optional) – Verbosity level. 0 for silent, 1 for basic info, 2+ for detailed debugging information. Defaults to 1.

  • **kwargs – Additional keyword arguments passed to gym.make_vec() and subsequently to the underlying environment constructor.
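The seed parameter above says each environment gets an offset seed. A minimal sketch, assuming a simple additive offset (base seed plus environment index); the exact scheme World uses is not documented here and may differ.

```python
import random


def per_env_seeds(seed: int, num_envs: int) -> list[int]:
    """Derive one seed per environment from the base seed.

    Assumes an additive offset (seed + env index); the offset actually
    used by World may differ.
    """
    return [seed + i for i in range(num_envs)]


seeds = per_env_seeds(2349867, num_envs=4)
rngs = [random.Random(s) for s in seeds]  # one independent RNG per env
```

Reusing the same base seed reproduces the same per-environment streams, which is what makes runs repeatable.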

Note

The MegaWrapper applies image transformations and resizing. The VariationWrapper enables domain randomization support. Autoreset is disabled to allow custom reset with seeds and options.

Example

Create a world with goal-conditioned observations:

world = World(
    env_name="PushT-v1",
    num_envs=8,
    image_shape=(96, 96),
    goal_shape=(64, 64),
    max_episode_steps=150,
    seed=42
)
property action_space

Batched action space for all environments.

Type:

Space

close(**kwargs)[source]

Close all environments and clean up resources.

Parameters:

**kwargs – Additional keyword arguments passed to the underlying vectorized environment’s close method.

Returns:

Return value from the underlying close method.

Return type:

Any

evaluate(episodes=10, eval_keys=None, seed=None, options=None)[source]

Evaluate the current policy over multiple episodes and return comprehensive metrics.

Runs the attached policy for a specified number of episodes, tracking success rates and optionally other metrics from environment info. Handles episode boundaries and ensures reproducibility through seeding.

Parameters:
  • episodes (int, optional) – Total number of episodes to evaluate. More episodes provide more statistically reliable metrics. Defaults to 10.

  • eval_keys (list of str, optional) – Additional info keys to track across episodes. Must be keys present in self.infos (e.g., ‘reward_total’, ‘steps_to_success’). Defaults to None (track only success rate).

  • seed (int, optional) – Base random seed for reproducible evaluation. Each episode gets an incremental offset. Defaults to None (non-deterministic).

  • options (dict, optional) – Reset options passed to environments (e.g., task selection, variation settings). Defaults to None.

Returns:

Dictionary containing evaluation metrics:
  • ‘success_rate’ (float): Percentage of successful episodes (0-100)

  • ‘episode_successes’ (ndarray): Boolean array of episode outcomes

  • ‘seeds’ (ndarray): Random seeds used for each episode

  • Additional keys from eval_keys if specified (ndarray)

Return type:

dict

Raises:

AssertionError – If an entry of eval_keys is not found in infos, if the number of recorded episodes does not match episodes, or if duplicate seeds are detected.

Note

  • Success is determined by the ‘terminateds’ flag (True = success)

  • ‘truncateds’ flag (timeout) is treated as failure

  • Seeds are validated for uniqueness to ensure independent episodes

  • Environments are manually reset with seeds (autoreset is bypassed)

  • All environments are evaluated in parallel for efficiency
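The notes above can be condensed into a small aggregation step. This is a sketch, not the library's code: summarize_episodes is a hypothetical helper that takes one termination flag per finished episode (truncation counts as failure), checks seed uniqueness, and builds a dict shaped like evaluate()'s return value. Plain lists stand in for the ndarrays the real method returns.

```python
def summarize_episodes(terminateds, seeds, extra=None):
    """Aggregate per-episode outcomes into an evaluate()-style metrics dict.

    terminateds: one bool per finished episode (True = success; an episode
        that ended by truncation/timeout should be passed as False).
    seeds: the seed used for each episode; must be unique.
    extra: optional {key: per-episode values}, analogous to eval_keys.
    """
    assert seeds, "no episodes to summarize"
    assert len(terminateds) == len(seeds), "episode count mismatch"
    assert len(set(seeds)) == len(seeds), "duplicate seeds detected"
    successes = [bool(t) for t in terminateds]
    metrics = {
        "success_rate": 100.0 * sum(successes) / len(successes),
        "episode_successes": successes,
        "seeds": list(seeds),
    }
    for key, values in (extra or {}).items():
        assert len(values) == len(seeds), f"{key}: episode count mismatch"
        metrics[key] = list(values)
    return metrics
```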

Example

Basic evaluation:

metrics = world.evaluate(episodes=100, seed=42)
print(f"Success: {metrics['success_rate']:.1f}%")

Evaluate with additional metrics:

metrics = world.evaluate(
    episodes=100,
    eval_keys=['reward_total', 'episode_length'],
    seed=42,
    options={'task_id': 3}
)
print(f"Success: {metrics['success_rate']:.1f}%")
print(f"Avg reward: {metrics['reward_total'].mean():.2f}")

Evaluate across different variations:

for var_type in ['none', 'color', 'light', 'all']:
    options = {'variation': [var_type]} if var_type != 'none' else None
    metrics = world.evaluate(episodes=50, seed=42, options=options)
    print(f"{var_type}: {metrics['success_rate']:.1f}%")
property num_envs

Number of parallel environment instances.

Type:

int

property observation_space

Batched observation space for all environments.

Type:

Space

record_dataset(dataset_name, episodes=10, seed=None, cache_dir=None, options=None)[source]

Collect episodes with the current policy and save as a HuggingFace Dataset.

Executes the attached policy to collect demonstration or rollout data, automatically managing episode boundaries and saving all observations, actions, rewards, and auxiliary information as a sharded dataset in Arrow format. Images are stored as separate JPEG files, and the dataset records their paths.

The dataset is organized with the following structure:
  • dataset_name/
    • img/
      • {episode_idx}/
        • {step_idx}_{column_name}.jpeg

    • data-{shard}.arrow (Arrow shards)

    • dataset_info.json

    • state.json
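The image layout above maps each stored frame to a deterministic path. A minimal sketch with hypothetical helpers; it assumes the dataset root is cache_dir/dataset_name and that the image paths stored in the dataset are relative to that root.

```python
from pathlib import Path


def image_relpath(episode_idx: int, step_idx: int, column: str) -> Path:
    """Relative path of one stored frame:
    img/{episode_idx}/{step_idx}_{column_name}.jpeg
    """
    return Path("img") / str(episode_idx) / f"{step_idx}_{column}.jpeg"


def image_abspath(cache_dir, dataset_name, episode_idx, step_idx, column):
    """Absolute path, assuming the dataset lives at cache_dir/dataset_name."""
    root = Path(cache_dir) / dataset_name
    return root / image_relpath(episode_idx, step_idx, column)
```

For example, the observation at step 17 of episode 3 would be img/3/17_pixels.jpeg.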

Parameters:
  • dataset_name (str) – Name of the dataset. Used as subdirectory name in cache_dir. Should be descriptive (e.g., ‘pusht_expert_demos’).

  • episodes (int, optional) – Total number of complete episodes to collect. Incomplete episodes at the end are discarded. Defaults to 10.

  • seed (int, optional) – Base random seed for reproducibility. Each episode gets an incremental offset. Defaults to None (non-deterministic).

  • cache_dir (str or Path, optional) – Root directory for dataset storage. If None, uses swm.data.get_cache_dir(). Defaults to None.

  • options (dict, optional) – Reset options for environments. Use {‘variation’: [‘all’]} for full domain randomization. Defaults to None.

Dataset Schema:
Each row contains:
  • episode_idx (int32): Episode identifier

  • step_idx (int32): Step within episode (0-indexed)

  • episode_len (int32): Total length of the episode

  • policy (string): Policy type identifier

  • pixels (string): Relative path to observation image

  • goal (string, optional): Relative path to goal image

  • action (float array): Action taken at this step

  • reward (float): Reward received

  • Additional keys from environment infos

Note

  • Actions are shifted: action at step t leads to observation at step t+1

  • Last action in each episode is NaN (no action leads from final state)

  • Images are saved as JPEG for efficiency (may introduce compression artifacts)

  • Dataset is automatically sharded: 1 shard per 50 episodes

  • Only complete episodes are included in the final dataset
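The action alignment and sharding rules above can be sketched as follows. episode_rows and num_shards are hypothetical helpers, not library functions. They assume episode_len equals the number of rows (observations) in the episode and that "1 shard per 50 episodes" rounds up.

```python
import math


def episode_rows(episode_idx, observations, actions, policy_name="expert"):
    """Build one row per observation for a single finished episode.

    observations: frame references o_0..o_T (T + 1 entries), e.g. image paths.
    actions: a_0..a_{T-1} (T entries); a_t moves o_t to o_{t+1}.
    The final row gets a NaN action since no action follows the last state.
    """
    assert len(actions) == len(observations) - 1
    episode_len = len(observations)  # assumed meaning of episode_len
    action_dim = len(actions[0]) if actions else 1
    rows = []
    for t, obs in enumerate(observations):
        action = actions[t] if t < len(actions) else [math.nan] * action_dim
        rows.append({
            "episode_idx": episode_idx,
            "step_idx": t,
            "episode_len": episode_len,
            "policy": policy_name,
            "pixels": obs,
            "action": list(action),
        })
    return rows


def num_shards(num_episodes, episodes_per_shard=50):
    """One shard per 50 episodes, rounded up (assumption), at least one."""
    return max(1, math.ceil(num_episodes / episodes_per_shard))
```

Under these assumptions, the 1000-episode example below would produce 20 shards.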

Raises:

AssertionError – If required keys (‘pixels’, ‘episode_idx’, ‘step_idx’, ‘episode_len’) are missing from recorded data.

Example

Collect demonstration data with variations:

world.set_policy(expert_policy)
world.record_dataset(
    dataset_name="cube_expert_1k",
    episodes=1000,
    seed=42,
    options={'variation': ['all']}
)

Collect data for specific tasks:

world.record_dataset(
    dataset_name="pusht_task2_data",
    episodes=500,
    options={'task_id': 2}
)
record_video(video_path, max_steps=500, fps=30, viewname='pixels', seed=None, options=None)[source]

Record rollout videos for each environment under the current policy.

Executes policy rollouts in all environments and saves them as MP4 videos, one per environment. Videos show stacked observation and goal images when goals are available.

Parameters:
  • video_path (str or Path) – Directory path where videos will be saved. Created if it doesn’t exist. Videos are named ‘env_0.mp4’, ‘env_1.mp4’, etc.

  • max_steps (int, optional) – Maximum number of steps to record per episode. Recording stops earlier if any environment terminates or truncates. Defaults to 500.

  • fps (int, optional) – Frames per second for output videos. Higher values produce smoother playback but larger files. Defaults to 30.

  • seed (int, optional) – Random seed for reproducible rollouts. Defaults to None.

  • options (dict, optional) – Reset options passed to environments (e.g., variation settings). Defaults to None.

Note

  • Videos use libx264 codec for compatibility

  • If ‘goal’ is present in infos, frames show observation stacked above goal

  • All environments are recorded simultaneously

  • Recording stops when ANY environment terminates or truncates
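Two of the notes above, the observation-over-goal layout and the stopping rule, are sketched below with plain Python lists (each image is a list of pixel rows). With NumPy arrays the stacking would be np.concatenate([obs, goal], axis=0). Both helpers are hypothetical, and treating step as the number of steps already recorded is an assumption.

```python
def stack_frame(obs, goal=None):
    """Stack the observation above the goal (images as lists of pixel rows).

    Vertical stacking appends the goal's rows after the observation's rows,
    so both images must have the same width.
    """
    if goal is None:
        return [list(row) for row in obs]
    assert len(obs[0]) == len(goal[0]), "obs and goal widths differ"
    return [list(row) for row in obs] + [list(row) for row in goal]


def should_stop(step, max_steps, terminateds, truncateds):
    """Recording ends at max_steps or as soon as ANY env is done."""
    return step >= max_steps or any(
        term or trunc for term, trunc in zip(terminateds, truncateds)
    )
```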

Example

Record evaluation videos:

world.set_policy(trained_policy)
world.record_video(
    video_path="./eval_videos",
    max_steps=200,
    fps=30,
    seed=42
)
record_video_from_dataset(video_path, dataset_name, episode_idx, max_steps=500, fps=30, num_proc=4, viewname: str | list[str] = 'pixels', cache_dir=None)[source]

Replay stored dataset episodes and export them as MP4 videos.

Loads episodes from a previously recorded dataset and renders them as videos, useful for visualization, debugging, and qualitative evaluation of collected data.

Parameters:
  • video_path (str or Path) – Directory where videos will be saved. Videos are named ‘episode_{idx}.mp4’. Directory is created if it doesn’t exist.

  • dataset_name (str) – Name of the dataset to load (must exist in cache_dir).

  • episode_idx (int or list of int) – Episode index or list of episode indices to render. Each episode is saved as a separate video file.

  • max_steps (int, optional) – Maximum number of steps to render per episode. Useful for limiting video length. Defaults to 500.

  • fps (int, optional) – Frames per second for output videos. Defaults to 30.

  • num_proc (int, optional) – Number of processes for parallel dataset filtering. Higher values speed up loading for large datasets. Defaults to 4.

  • cache_dir (str or Path, optional) – Root directory where dataset is stored. If None, uses swm.data.get_cache_dir(). Defaults to None.

Raises:

AssertionError – If dataset doesn’t exist in cache_dir, or if episode length inconsistencies are detected in the data.

Note

  • Images are loaded from JPEG files stored in the dataset

  • If ‘goal’ column exists, observation and goal are stacked vertically

  • Videos use libx264 codec for broad compatibility

  • Episodes are validated for consistency (length matches metadata)
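The validation step can be pictured as: select one episode's rows, order them by step_idx, check that the row count matches episode_len, then truncate to max_steps. collect_episode below is a hypothetical helper that uses the schema documented for record_dataset(). It assumes episode_len equals the number of rows stored for that episode.

```python
def collect_episode(rows, episode_idx, max_steps=500):
    """Gather one episode's rows in step order and check its length.

    rows: iterable of dicts with 'episode_idx', 'step_idx', 'episode_len'.
    Returns at most max_steps rows, ordered by step_idx.
    """
    steps = sorted(
        (r for r in rows if r["episode_idx"] == episode_idx),
        key=lambda r: r["step_idx"],
    )
    assert steps, f"episode {episode_idx} not found"
    expected = steps[0]["episode_len"]
    assert len(steps) == expected, (
        f"episode {episode_idx}: {len(steps)} rows, metadata says {expected}"
    )
    return steps[:max_steps]
```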

Example

Render specific episodes from a dataset:

world.record_video_from_dataset(
    video_path="./visualizations",
    dataset_name="cube_expert_1k",
    episode_idx=[0, 5, 10, 99],
    fps=30
)

Render a single episode:

world.record_video_from_dataset(
    video_path="./debug",
    dataset_name="failed_episodes",
    episode_idx=42,
    max_steps=100
)
reset(seed=None, options=None)[source]

Reset all environments to initial states.

Parameters:
  • seed (int, optional) – Random seed for reproducible resets. If None, uses non-deterministic seeding. Defaults to None.

  • options (dict, optional) – Dictionary of reset options passed to environments. Common keys include ‘variation’ for domain randomization. Defaults to None.

Updates:
  • self.states: Initial observations from all environments

  • self.infos: Initial auxiliary information

Example

Reset with domain randomization:

world.reset(seed=42, options={'variation': ['all']})

Reset specific variations:

world.reset(options={'variation': ['cube.color', 'light.intensity']})
set_policy(policy)[source]

Attach a policy to the world and configure it with environment context.

The policy will be used for action generation during step(), record_video(), record_dataset(), and evaluate() calls.

Parameters:

policy (Policy) – Policy instance that implements a get_action(infos) method. The policy receives environment context through set_env() and, optionally, a seed through set_seed().

Note

If the policy has a ‘seed’ attribute, it will be applied via set_seed(). The policy’s set_env() method receives the wrapped vectorized environment.
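A minimal policy that fits the interface described above might look like this. RandomPolicy and attach_policy are hypothetical, and SimpleNamespace stands in for the wrapped vectorized environment. Applying set_seed() only when seed is not None is an assumption, since the Note does not say how a None seed is handled.

```python
import random
from types import SimpleNamespace


class RandomPolicy:
    """Hypothetical policy exposing get_action(infos), set_env(env),
    and an optional seed attribute with set_seed()."""

    def __init__(self, seed=None):
        self.seed = seed
        self.env = None
        self.rng = random.Random()

    def set_env(self, env):
        self.env = env

    def set_seed(self, seed):
        self.rng.seed(seed)

    def get_action(self, infos):
        # One scalar action per parallel env; a real policy would use
        # env.action_space or run a model on infos.
        return [self.rng.uniform(-1.0, 1.0) for _ in range(self.env.num_envs)]


def attach_policy(env, policy):
    """Mimics the set_policy() behaviour described in the Note."""
    policy.set_env(env)
    if getattr(policy, "seed", None) is not None:  # assumed None handling
        policy.set_seed(policy.seed)
    return policy


policy = attach_policy(SimpleNamespace(num_envs=4), RandomPolicy(seed=0))
```

Two policies built with the same seed produce identical action sequences, which is the point of routing the seed through set_seed().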

property single_action_space

Action space for a single environment instance.

Type:

Space

property single_observation_space

Observation space for a single environment instance.

Type:

Space

property single_variation_space

Variation space for a single environment instance.

Type:

Dict

step()[source]

Advance all environments by one step using the current policy.

Queries the attached policy for actions based on current info, executes those actions in all environments, and updates internal state with the results.

Updates:
  • self.states: New observations from all environments

  • self.rewards: Rewards received from the step

  • self.terminateds: Episode termination flags (task success)

  • self.truncateds: Episode truncation flags (timeout)

  • self.infos: Auxiliary information dictionaries

Note

Requires a policy to be attached via set_policy() before calling. The policy’s get_action() method receives the current infos dict.

Raises:

AttributeError – If no policy has been set via set_policy().
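The bookkeeping performed by step() can be modeled with toy classes. This is a sketch of the documented behaviour, not the real World class. ToyVectorEnv, ConstantPolicy, and MiniWorld are hypothetical, and the toy success and time-limit rules exist only to produce some terminateds and truncateds.

```python
class ToyVectorEnv:
    """Hypothetical stand-in for the wrapped vectorized env."""

    def __init__(self, num_envs):
        self.num_envs = num_envs
        self.t = 0

    def step(self, actions):
        self.t += 1
        states = [self.t] * self.num_envs
        rewards = [float(a) for a in actions]
        terminateds = [r > 0.5 for r in rewards]  # toy "success" rule
        truncateds = [self.t >= 3] * self.num_envs  # toy 3-step time limit
        infos = {"step": self.t}
        return states, rewards, terminateds, truncateds, infos


class ConstantPolicy:
    """Hypothetical policy that always returns the same batch of actions."""

    def __init__(self, actions):
        self.actions = list(actions)

    def get_action(self, infos):
        return list(self.actions)


class MiniWorld:
    """Sketch of the step() bookkeeping; not the real World class."""

    def __init__(self, envs):
        self.envs = envs
        self.infos = {}

    def set_policy(self, policy):
        self.policy = policy

    def step(self):
        # self.policy only exists after set_policy(), so calling step()
        # first raises AttributeError, as documented above.
        actions = self.policy.get_action(self.infos)
        (self.states, self.rewards, self.terminateds,
         self.truncateds, self.infos) = self.envs.step(actions)
```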

property variation_space

Batched variation space for domain randomization across all environments.

Type:

Dict

pretraining(script_path: str, dataset_name: str, output_model_name: str, dump_object: bool = True, args: str = '') int[source]

Run a pretraining script as a subprocess with optional command-line arguments.

This function checks if the specified script exists, constructs a command to run it with the provided arguments, and executes the command in a subprocess.

Parameters:
  • script_path (str) – The path to the pretraining script to be executed.

  • dataset_name (str) – The name of the dataset to be used in pretraining.

  • output_model_name (str) – The name to save the output model.

  • dump_object (bool, optional) – Whether to dump the model object after training. Defaults to True.

  • args (str, optional) – A string of command-line arguments to pass to the script. Defaults to an empty string.

Returns:

The return code of the subprocess. A return code of 0 indicates success; non-zero return codes raise SystemExit instead of being returned.

Return type:

int

Raises:
  • ValueError – If the specified script does not exist.

  • SystemExit – If the subprocess exits with a non-zero return code.
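The documented flow (check the script, build a command, run it in a subprocess, raise on failure) can be sketched with the standard library. run_pretraining is a hypothetical helper. The flag names it passes to the script are invented for illustration; the real pretraining() may hand the dataset and model names to the script differently.

```python
import shlex
import subprocess
import sys
from pathlib import Path


def run_pretraining(script_path, dataset_name, output_model_name,
                    dump_object=True, args=""):
    """Check the script, build a command, run it, and raise SystemExit
    on a non-zero return code."""
    if not Path(script_path).is_file():
        raise ValueError(f"Script not found: {script_path}")
    cmd = [
        sys.executable, str(script_path),
        f"--dataset={dataset_name}",             # hypothetical flag
        f"--output-model={output_model_name}",   # hypothetical flag
        f"--dump-object={dump_object}",          # hypothetical flag
    ]
    cmd += shlex.split(args)  # split the args string like a POSIX shell
    result = subprocess.run(cmd)
    if result.returncode != 0:
        raise SystemExit(result.returncode)
    return result.returncode
```

Passing the command as a list (no shell=True) avoids shell injection through args, while shlex.split still honours quoting inside the args string.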