stable_worldmodel package
Subpackages
- stable_worldmodel.envs package
- Submodules
- stable_worldmodel.envs.image_positioning module
- stable_worldmodel.envs.ogbench_cube module
CubeEnv: CubeEnv.__init__(), CubeEnv.add_object_info(), CubeEnv.add_objects(), CubeEnv.compute_observation(), CubeEnv.compute_oracle_observation(), CubeEnv.compute_reward(), CubeEnv.get_reset_info(), CubeEnv.get_step_info(), CubeEnv.initialize_episode(), CubeEnv.modify_mjcf_model(), CubeEnv.post_compilation_objects(), CubeEnv.post_step(), CubeEnv.render(), CubeEnv.render_multiview(), CubeEnv.reset(), CubeEnv.set_new_target(), CubeEnv.set_tasks()
- stable_worldmodel.envs.ogbench_scene module
SceneEnv: SceneEnv.__init__(), SceneEnv.add_object_info(), SceneEnv.add_objects(), SceneEnv.compute_observation(), SceneEnv.compute_oracle_observation(), SceneEnv.compute_reward(), SceneEnv.get_reset_info(), SceneEnv.get_step_info(), SceneEnv.initialize_episode(), SceneEnv.modify_mjcf_model(), SceneEnv.post_compilation_objects(), SceneEnv.post_step(), SceneEnv.pre_step(), SceneEnv.render(), SceneEnv.reset(), SceneEnv.set_new_target(), SceneEnv.set_state(), SceneEnv.set_tasks()
- stable_worldmodel.envs.pusht module
PushT: PushT.add_I(), PushT.add_L(), PushT.add_Z(), PushT.add_box(), PushT.add_circle(), PushT.add_plus(), PushT.add_shape(), PushT.add_small_tee(), PushT.add_square(), PushT.add_tee(), PushT.close(), PushT.eval_state(), PushT.fix_action_sample(), PushT.metadata, PushT.render(), PushT.reset(), PushT.reward_range, PushT.step()
- stable_worldmodel.envs.simple_point_maze module
- stable_worldmodel.envs.two_room module
- stable_worldmodel.envs.utils module
- stable_worldmodel.envs.voidrun module
- Module contents
- stable_worldmodel.solver package
- Submodules
- stable_worldmodel.solver.cem module
- stable_worldmodel.solver.gd module
- stable_worldmodel.solver.mppi module
- stable_worldmodel.solver.nevergrad module
- stable_worldmodel.solver.old_mppi module
- stable_worldmodel.solver.random module
- stable_worldmodel.solver.solver module
- Module contents
- stable_worldmodel.wm package
- Submodules
- stable_worldmodel.wm.dinowm module
- stable_worldmodel.wm.dreamer module
- stable_worldmodel.wm.dummy module
- stable_worldmodel.wm.frame module
- stable_worldmodel.wm.tdmpc module
- Module contents
Submodules
stable_worldmodel.cli module
Command-line interface for stable-worldmodel.
This module provides CLI commands for managing and inspecting stable-worldmodel resources including datasets, models, and world environments. The CLI uses the Typer framework with Rich formatting for an enhanced terminal experience.
- Available commands:
list: List cached models, datasets, or worlds
show: Display detailed information about datasets or worlds
delete: Remove models or datasets from cache
- The CLI can be invoked via:
stable-worldmodel <command>
swm <command>
python -m stable_worldmodel.cli <command>
Typical usage examples:
List all cached datasets:
$ swm list dataset
Show information about a specific world:
$ swm show world swm/SimplePointMaze-v0
Delete a cached dataset:
$ swm delete dataset my-dataset
Show version:
$ swm --version
- common(version: bool | None = None)[source]
Common options for all stable-worldmodel commands.
This callback provides global options that apply to all CLI commands. Currently only supports the --version flag.
- Parameters:
version (Optional[bool], optional) – If True, display version and exit. Defaults to None.
- delete(kind: str, names: list[str])[source]
Delete models or datasets from cache directory.
Permanently removes specified models or datasets from the local cache directory. Requires user confirmation before deletion. This operation cannot be undone.
- Parameters:
kind (str) – Type of resource to delete. Must be either 'model' (delete cached world models) or 'dataset' (delete cached datasets).
names (List[str]) – One or more names of resources to delete. All specified names must exist in cache.
- Raises:
typer.Abort – If kind is invalid, specified names not found in cache, or user cancels the confirmation prompt.
Example
Delete a single dataset:
$ swm delete dataset my-dataset
Delete multiple models:
$ swm delete model model1 model2
Warning
This operation permanently deletes data and cannot be undone. The command will prompt for confirmation before proceeding.
- display_dataset_info(info: dict[str, Any]) None[source]
Display dataset information in a formatted panel with tables.
Prints a Rich panel showing dataset metadata including columns, episode count, shapes, and variation information in a two-table layout.
- Parameters:
info (Dict[str, Any]) – Dataset metadata dictionary containing keys: 'name' (dataset name), 'columns' (list of data columns), 'num_episodes' (total number of episodes), 'num_steps' (total number of steps), 'obs_shape' (observation tensor shape), 'action_shape' (action tensor shape), 'goal_shape' (goal tensor shape), and 'variation' (variation space structure).
Example
>>> from stable_worldmodel import data
>>> info = data.dataset_info("simple-pointmaze")
>>> display_dataset_info(info)
- display_world_info(info: dict[str, Any]) None[source]
Display world environment information in a formatted panel.
Prints a Rich panel showing the world’s observation space, action space, and variation space in a hierarchical, color-coded format.
- Parameters:
info (Dict[str, Any]) – World metadata dictionary containing keys: 'name' (world environment name), 'observation_space' (observation space structure), 'action_space' (action space structure), and 'variation' (variation space structure).
Example
>>> from stable_worldmodel import data
>>> info = data.world_info("swm/SimplePointMaze-v0")
>>> display_world_info(info)
- list_cmd(kind: str)[source]
List cached stable-worldmodel resources.
Displays a table of all cached models, datasets, or worlds stored in the stable-worldmodel cache directory. Useful for seeing what resources are available locally.
- Parameters:
kind (str) – Type of resource to list. Must be one of 'model' (cached world models), 'dataset' (cached datasets), or 'world' (registered world environments).
- Raises:
typer.Abort – If kind is not ‘model’, ‘dataset’, or ‘world’.
Example
List all cached datasets:
$ swm list dataset
List all available worlds:
$ swm list world
- show(kind: str, names: list[str] | None = None, all: bool = False)[source]
Show detailed information about datasets or worlds.
Displays comprehensive information about specified datasets or world environments including their structure, spaces, and variations. Information is presented in formatted panels with hierarchical trees and tables.
- Parameters:
kind (str) – Type of resource to show. Must be either 'dataset' (show dataset information) or 'world' (show world environment information).
names (Optional[List[str]], optional) – Specific names to display. Can provide multiple names. Defaults to None.
all (bool, optional) – If True, show information for all cached resources of the specified kind. Defaults to False.
- Raises:
typer.Abort – If kind is invalid, no names provided without the --all flag, or if specified names are not found in cache.
Example
Show a specific dataset:
$ swm show dataset simple-pointmaze
Show multiple datasets:
$ swm show dataset dataset1 dataset2
Show all cached datasets:
$ swm show dataset --all
Show a specific world:
$ swm show world swm/SimplePointMaze-v0
stable_worldmodel.data module
Data management and dataset utilities for Stable World Model.
This module provides utilities for managing datasets, models, and world information. It includes functionality for loading multi-step trajectories, querying cached data, and retrieving metadata about environments.
- The module supports:
Multi-step trajectory datasets with frame skipping
Dataset and model cache management
World environment introspection
Gymnasium space metadata extraction
- class SpaceInfo[source]
Bases: TypedDict
Type specification for Gymnasium space metadata.
- Variables:
shape (tuple[int, ...]) – Dimensions of the space.
type (str) – Class name of the space (e.g., ‘Box’, ‘Discrete’).
dtype (str) – Data type of the space elements.
low (Any) – Lower bounds for Box spaces.
high (Any) – Upper bounds for Box spaces.
n (int) – Number of discrete values for Discrete spaces.
- class StepsDataset(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)[source]
Bases: HFDataset
Dataset for loading multi-step trajectory sequences.
This dataset loads sequences of consecutive steps from episodic data, supporting frame skipping and automatic image loading from file paths. Inherits from stable_pretraining’s HFDataset.
- Variables:
data_dir (Path) – Directory containing the dataset.
num_steps (int) – Number of steps per sequence.
frameskip (int) – Number of steps to skip between frames.
episodes (np.ndarray) – Array of unique episode indices.
episode_slices (dict) – Mapping from episode index to dataset indices.
cum_slices (np.ndarray) – Cumulative sum of valid samples per episode.
idx_to_ep (np.ndarray) – Mapping from sample index to episode.
img_cols (set) – Set of column names containing image paths.
- __init__(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)[source]
Initialize the StepsDataset.
- Parameters:
path (str) – Name or path of the dataset within cache directory.
*args – Additional arguments passed to parent class.
num_steps (int, optional) – Number of consecutive steps per sample. Defaults to 2.
frameskip (int, optional) – Number of steps between sampled frames. Defaults to 1.
cache_dir (str, optional) – Cache directory path. Defaults to None (uses default).
**kwargs – Additional keyword arguments passed to parent class.
- Raises:
AssertionError – If required columns are missing from dataset.
ValueError – If episodes are too short for the requested num_steps and frameskip.
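The episode-indexing attributes described above (cum_slices as a cumulative sum of valid samples, idx_to_ep as a sample-to-episode map) suggest a simple counting scheme. As a hedged sketch of that arithmetic, not the library's actual implementation (valid_samples_per_episode is a hypothetical helper): a sample spanning num_steps frames at frameskip spacing covers (num_steps - 1) * frameskip + 1 steps, so an episode of length L admits L minus that span plus one start positions.

```python
import numpy as np


def valid_samples_per_episode(lengths, num_steps, frameskip):
    """Number of valid start indices per episode for sequences of
    `num_steps` frames spaced `frameskip` steps apart."""
    span = (num_steps - 1) * frameskip + 1  # steps covered by one sample
    counts = np.asarray(lengths) - span + 1
    if (counts <= 0).any():
        raise ValueError("an episode is too short for num_steps/frameskip")
    return counts


# Episodes of length 5, 8, and 6; samples of 3 frames with frameskip=2
counts = valid_samples_per_episode([5, 8, 6], num_steps=3, frameskip=2)
cum_slices = np.cumsum(counts)  # boundaries for global-index lookup
idx_to_ep = np.repeat(np.arange(len(counts)), counts)  # sample -> episode
```

With these arrays, mapping a flat sample index to its episode is a constant-time array lookup rather than a search over episodes.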
- get_episode_slice(episode_idx, episode_indices)[source]
Get dataset indices for a specific episode.
- Parameters:
episode_idx (int) – Episode index to retrieve.
episode_indices (array-like) – Array of episode indices for all steps.
- Returns:
Indices of steps belonging to the specified episode.
- Return type:
np.ndarray
- Raises:
ValueError – If episode is too short for num_steps and frameskip.
- infer_img_path_columns()[source]
Infer which dataset columns contain image file paths.
Checks the first dataset element to identify string columns with common image file extensions.
- Returns:
Set of column names containing image file paths.
- Return type:
set
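The check described above (string columns with common image file extensions in the first dataset element) can be sketched as follows; this is an illustrative standalone version, and the extension list is an assumption, not the library's exact set.

```python
def infer_img_path_columns(row):
    """Identify columns whose values look like image file paths:
    string values ending in a common image extension."""
    exts = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp")
    return {
        col
        for col, val in row.items()
        if isinstance(val, str) and val.lower().endswith(exts)
    }


# First element of a hypothetical dataset
row = {"obs": "ep0/frame_000.png", "action": [0.1, -0.3], "goal": "ep0/goal.jpg"}
cols = infer_img_path_columns(row)  # {'obs', 'goal'}
```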
- class VariationInfo[source]
Bases: TypedDict
Type specification for environment variation metadata.
- Variables:
has_variation (bool) – Whether the environment supports variations.
type (str | None) – Class name of the variation space if it exists.
names (list[str] | None) – List of variation parameter names.
- class WorldInfo[source]
Bases: TypedDict
Type specification for world environment information.
- Variables:
name (str) – Name/ID of the world environment.
observation_space (stable_worldmodel.data.SpaceInfo) – Metadata about the observation space.
action_space (stable_worldmodel.data.SpaceInfo) – Metadata about the action space.
variation (stable_worldmodel.data.VariationInfo) – Information about environment variations.
config (dict[str, Any]) – Additional configuration parameters.
- action_space: SpaceInfo
- observation_space: SpaceInfo
- variation: VariationInfo
- dataset_info(name)[source]
Get metadata about a cached dataset.
- Parameters:
name (str) – Name of the dataset.
- Returns:
- Dictionary containing dataset metadata including:
name: Dataset name
num_episodes: Number of unique episodes
num_steps: Total number of steps
columns: List of column names
obs_shape: Shape of observation images
action_shape: Shape of action vectors
goal_shape: Shape of goal images
variation: Dict with variation information
- Return type:
dict
- Raises:
ValueError – If dataset is not found in cache.
AssertionError – If required columns are missing.
- delete_dataset(name)[source]
Delete a cached dataset and its associated files.
- Parameters:
name (str) – Name of the dataset to delete.
Note
Prints success or error messages to console.
- delete_model(name)[source]
Delete cached model checkpoint files.
Removes all checkpoint files (weights and object files) matching the given model name.
- Parameters:
name (str) – Name of the model to delete.
Note
Prints success or error messages to console for each file deleted.
- get_cache_dir() Path[source]
Get the cache directory for stable_worldmodel data.
The cache directory can be customized via the STABLEWM_HOME environment variable. If not set, defaults to ~/.stable_worldmodel.
- Returns:
Path to the cache directory. Directory is created if it doesn’t exist.
- Return type:
Path
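The resolution logic described above can be sketched in a few lines of stdlib code; this is a minimal standalone version mirroring the documented behavior (STABLEWM_HOME override, ~/.stable_worldmodel default, directory created on demand), not the library's source.

```python
import os
from pathlib import Path


def get_cache_dir() -> Path:
    """Resolve the cache directory, honoring STABLEWM_HOME if set."""
    root = os.environ.get("STABLEWM_HOME", "~/.stable_worldmodel")
    path = Path(root).expanduser()
    path.mkdir(parents=True, exist_ok=True)  # create if missing
    return path
```

Setting STABLEWM_HOME before first use redirects all dataset and model caching to the chosen directory.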
- is_image(x)[source]
Check if input is a valid image array.
- Parameters:
x – Input to check.
- Returns:
- True if x is a uint8 numpy array with shape (H, W, C) where
C is 1 (grayscale), 3 (RGB), or 4 (RGBA).
- Return type:
bool
- list_datasets()[source]
List all cached datasets.
- Returns:
Names of all dataset directories in the cache.
- Return type:
list[str]
- list_models()[source]
List all cached model checkpoints.
Searches for files matching the pattern <name>_weights*.ckpt or <name>_object.ckpt and returns unique model names.
- Returns:
Sorted list of model names found in cache.
- Return type:
list[str]
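Extracting unique model names from checkpoint file names matching the documented patterns (<name>_weights*.ckpt or <name>_object.ckpt) can be sketched with a regular expression; model_names and the file names below are illustrative, not part of the API.

```python
import re


def model_names(filenames):
    """Unique model names from checkpoints named
    <name>_weights*.ckpt or <name>_object.ckpt."""
    pattern = re.compile(r"^(.+?)_(?:weights.*|object)\.ckpt$")
    names = {m.group(1) for f in filenames if (m := pattern.match(f))}
    return sorted(names)


files = [
    "dino_weights_epoch3.ckpt",
    "dino_object.ckpt",
    "tdmpc_weights.ckpt",
    "notes.txt",
]
result = model_names(files)  # ['dino', 'tdmpc']
```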
- list_worlds()[source]
List all registered world environments.
- Returns:
Sorted list of world environment IDs.
- Return type:
list[str]
- world_info(name: str, *, image_shape: tuple[int, int] = (224, 224), render_mode: str = 'rgb_array') WorldInfo[source]
Get metadata about a world environment.
Creates a temporary world instance to extract observation space, action space, and variation information. Results are cached for efficiency.
- Parameters:
name – ID of the world environment.
image_shape – Desired image shape for rendering. Defaults to (224, 224).
render_mode – Rendering mode for the environment. Defaults to “rgb_array”.
- Returns:
Dictionary containing world metadata including spaces and variations.
- Return type:
WorldInfo
- Raises:
ValueError – If world name is not registered.
stable_worldmodel.policy module
- AutoCostModel(model_name, cache_dir=None)[source]
- class BasePolicy(**kwargs)[source]
Bases: object
Base class for agent policies.
- get_action(obs, **kwargs)[source]
Get action from the policy given the observation.
- set_env(env)[source]
- class ExpertPolicy(**kwargs)[source]
Bases: BasePolicy
Expert Policy.
- get_action(obs, goal_obs, **kwargs)[source]
Get action from the policy given the observation.
- class PlanConfig(horizon: int, receding_horizon: int, history_len: int = 1, action_block: int = 1, warm_start: bool = True)[source]
Bases: object
Configuration for the planning process.
- class RandomPolicy(seed=None, **kwargs)[source]
Bases: BasePolicy
Random Policy.
- get_action(obs, **kwargs)[source]
Get action from the policy given the observation.
- set_seed(seed)[source]
- class Transformable(*args, **kwargs)[source]
Bases: Protocol
Protocol for input transformation.
- inverse_transform() Tensor[source]
Revert a pre-processed input to its original representation.
- transform() Tensor[source]
Pre-process an input into a tensor.
- class WorldModelPolicy(solver: Solver, config: PlanConfig, process: dict[str, Transformable] | None = None, transform: dict[str, callable] | None = None, **kwargs)[source]
Bases: BasePolicy
World Model Policy using a planning solver.
- get_action(info_dict, **kwargs)[source]
Get action from the policy given the observation.
- set_env(env)[source]
stable_worldmodel.spaces module
Extended Gymnasium spaces with state tracking and constraint support.
- class Box(*args, init_value=None, constrain_fn=None, **kwargs)[source]
Bases: Box
Extended continuous box space with state tracking and constraint support.
This class extends gymnasium.spaces.Box to add state management and optional constraint validation. It represents bounded continuous values with configurable shape, dtype, and custom constraints.
- Variables:
init_value (np.ndarray) – The initial value for the space.
value (np.ndarray) – The current value of the space.
constrain_fn (callable) – Optional function that returns True if a value satisfies custom constraints beyond the box boundaries.
Example
Create a 2D position space constrained to a circle:
import numpy as np

def in_circle(pos):
    return np.linalg.norm(pos) <= 1.0

space = Box(
    low=np.array([-1.0, -1.0]),
    high=np.array([1.0, 1.0]),
    init_value=np.array([0.0, 0.0]),
    constrain_fn=in_circle,
)
position = space.sample()  # Only samples within unit circle
Note
The constraint function enables complex geometric or relational constraints beyond simple box boundaries.
- __init__(*args, init_value=None, constrain_fn=None, **kwargs)[source]
Initialize a Box space with state tracking.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.Box.
init_value (np.ndarray, optional) – Initial value for the space. Must match the shape and dtype of the box. Defaults to None.
constrain_fn (callable, optional) – Function that takes a numpy array and returns True if the value satisfies custom constraints beyond the box boundaries. Defaults to None.
**kwargs – Keyword arguments passed to gymnasium.spaces.Box.
- check()[source]
Validate the current space value.
Checks if the current value is within the box bounds and satisfies the constraint function. Logs a warning if the constraint fails.
- Returns:
True if the current value is valid, False otherwise.
- Return type:
bool
- contains(x)[source]
Check if value is valid and satisfies constraints.
- Parameters:
x (np.ndarray) – The value to check.
- Returns:
- True if x is within box bounds and satisfies the constraint
function, False otherwise.
- Return type:
bool
- reset()[source]
Reset the space value to its initial value.
Sets the current value back to the init_value specified during initialization.
- sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)[source]
Sample a random value using rejection sampling for constraints.
Repeatedly samples values until one satisfies the constraint function or max_tries is reached. Optionally updates the space’s current value.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.Box.sample().
max_tries (int, optional) – Maximum number of sampling attempts before raising an error. Defaults to 1000.
warn_after_s (float, optional) – Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0.
set_value (bool, optional) – Whether to update the space’s current value with the sampled value. Defaults to True.
**kwargs – Keyword arguments passed to gymnasium.spaces.Box.sample().
- Returns:
A sampled array that satisfies the constraint function.
- Return type:
np.ndarray
- Raises:
RuntimeError – If no valid sample is found after max_tries attempts.
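The rejection-sampling loop described above can be sketched standalone; this illustrates the technique with plain NumPy (sample_with_constraint is a hypothetical helper, not the library's method, and the seeded generator is an assumption for reproducibility).

```python
import numpy as np


def sample_with_constraint(low, high, constrain_fn, max_tries=1000, rng=None):
    """Uniformly sample in [low, high], rejecting candidates until one
    satisfies `constrain_fn` or `max_tries` is exhausted."""
    rng = rng if rng is not None else np.random.default_rng(0)
    for _ in range(max_tries):
        candidate = rng.uniform(low, high)
        if constrain_fn(candidate):
            return candidate
    raise RuntimeError(f"no valid sample after {max_tries} tries")


# 2D point constrained to the unit circle
point = sample_with_constraint(
    low=np.array([-1.0, -1.0]),
    high=np.array([1.0, 1.0]),
    constrain_fn=lambda p: np.linalg.norm(p) <= 1.0,
)
```

Rejection sampling is simple but degrades as the constraint's acceptance region shrinks relative to the box, which is why the warn_after_s threshold exists.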
- class Dict(*args, init_value=None, constrain_fn=None, sampling_order=None, **kwargs)[source]
Bases: Dict
Extended dictionary space with ordered sampling and nested support.
This class extends gymnasium.spaces.Dict to add state management, constraint validation, and explicit sampling order control. It composes multiple spaces into a hierarchical structure where dependencies between variables can be handled through ordered sampling.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.Dict.
init_value (dict, optional) – Initial values for the space. If None, derived from init_value of contained spaces.
constrain_fn (callable, optional) – Function that returns True if the complete dictionary satisfies custom constraints.
sampling_order (list, optional) – Explicit order for sampling keys. If None, uses insertion order. Missing keys are appended.
**kwargs – Additional keyword arguments passed to Dict.
- Variables:
init_value (dict) – Initial values for all contained spaces.
value (dict) – Current values of all contained spaces.
constrain_fn (callable) – Constraint validation function.
sampling_order (set) – Set of dotted paths for all variables in order.
Example
Create a nested space with sampling order dependencies:
from stable_worldmodel import spaces
import numpy as np

config = spaces.Dict(
    {
        "difficulty": spaces.Discrete(n=3, init_value=0),
        "world": spaces.Dict(
            {
                "width": spaces.Discrete(n=100, init_value=50),
                "height": spaces.Discrete(n=100, init_value=50),
            }
        ),
        "player_pos": spaces.Box(
            low=np.array([0, 0]),
            high=np.array([99, 99]),
            init_value=np.array([25, 25]),
        ),
    },
    sampling_order=["difficulty", "world", "player_pos"],
)
# Sample respects order
state = config.sample()
Note
Sampling order is crucial when variables have dependencies. For example, sample world size before sampling positions within it. Nested Dict spaces recursively apply their own sampling orders.
Accessing values in constraint functions: When implementing constrain_fn for Dict spaces, always use self.value['key']['key2'] instead of self['key']['key2'].value. The .value property recursively builds the complete value dictionary from the top level down, ensuring all nested values are up-to-date and correctly structured. Direct subspace access with .value only retrieves that specific subspace's value without the full context.
Note that direct subspace access (e.g., self['key'].value) is perfectly fine for regular operations outside of constraint functions, such as reading individual subspace values or debugging. The recommendation to use top-level .value applies specifically to constraint functions where you need the complete, consistent state of all nested spaces.
Example of proper constraint function usage:
# Example: In a class with Dict space attribute
class Environment:
    def __init__(self):
        self.config_space = spaces.Dict({...})

    def validate_config(self):
        # ✓ CORRECT: Access via .value at top level
        values = self.config_space.value
        return values["player_pos"][0] < values["world"]["width"]

    def validate_wrong(self):
        # ✗ AVOID: Direct subspace access
        return (
            self.config_space["player_pos"].value[0]
            < self.config_space["world"]["width"].value
        )
- __init__(*args, init_value=None, constrain_fn=None, sampling_order=None, **kwargs)[source]
Initialize a Dict space with state tracking and sampling order.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.Dict.
init_value (dict, optional) – Initial values for the space. If None, derived from init_value of contained spaces. Defaults to None.
constrain_fn (callable, optional) – Function that takes a dict and returns True if the complete dictionary satisfies custom constraints. Defaults to None.
sampling_order (list, optional) – Explicit order for sampling keys. If None, uses insertion order. Missing keys are appended with warning. Defaults to None.
**kwargs – Keyword arguments passed to gymnasium.spaces.Dict.
- Raises:
ValueError – If sampling_order contains keys not present in spaces.
- check(debug=False)[source]
Validate all contained spaces’ current values.
Checks each contained space using its check() method if available, or falls back to contains(value). Optionally logs warnings for failed checks.
- Parameters:
debug (bool, optional) – If True, logs warnings for spaces that fail validation. Defaults to False.
- Returns:
True if all contained spaces have valid values, False otherwise.
- Return type:
bool
- contains(x) bool[source]
Check if value is a valid member of this space.
Validates that x is a dictionary containing all required keys with values that satisfy each subspace’s constraints and the overall constraint function.
- Parameters:
x – The value to check.
- Returns:
- True if x is a valid dict with all keys present, all values
within subspace bounds, and satisfies the constraint function. False otherwise.
- Return type:
bool
- property init_value
Initial values for all contained spaces.
Constructs initial value dictionary from contained spaces’ init_value properties. Falls back to sampling if a space lacks init_value.
- Returns:
Dictionary mapping space keys to their initial values.
- Return type:
dict
- Type:
dict
- names()[source]
Return all space keys including nested ones.
- Returns:
- A list of all keys in the Dict space, with nested keys using dot notation.
For example, a nested dict with key “a” containing subspace “b” would produce “a.b”.
- Return type:
list
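The dot-notation flattening described above can be illustrated standalone; this hedged sketch walks a plain nested dict standing in for nested Dict spaces (dotted_names is a hypothetical helper, not the library's method).

```python
def dotted_names(space, prefix=""):
    """Flatten a nested dict-of-spaces into dotted key paths,
    e.g. {'a': {'b': ...}} -> ['a.b']."""
    names = []
    for key, sub in space.items():
        path = f"{prefix}{key}"
        if isinstance(sub, dict):
            # Recurse into nested dicts, extending the dotted prefix
            names.extend(dotted_names(sub, prefix=path + "."))
        else:
            names.append(path)
    return names


nested = {"difficulty": None, "world": {"width": None, "height": None}}
paths = dotted_names(nested)  # ['difficulty', 'world.width', 'world.height']
```

These dotted paths are the same form expected by update() for resampling nested keys.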
- reset()[source]
Reset all contained spaces to their initial values.
Calls reset() on all contained spaces that have a reset method, then sets this space’s value to init_value.
- sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)[source]
Sample a random element from the Dict space.
Samples each subspace in the sampling order and ensures the result satisfies any constraint functions. Uses rejection sampling if constraints are present.
- Parameters:
*args – Positional arguments passed to each subspace’s sample method.
max_tries (int, optional) – Maximum number of rejection sampling attempts. Defaults to 1000.
warn_after_s (float, optional) – Issue a warning if sampling takes longer than this many seconds. Set to None to disable warnings. Defaults to 5.0.
set_value (bool, optional) – Whether to set the internal value to the sampled value. Defaults to True.
**kwargs – Additional keyword arguments passed to each subspace’s sample method.
- Returns:
- A dictionary with keys matching the space definition and values sampled
from their respective subspaces.
- Return type:
dict
- Raises:
RuntimeError – If a valid sample is not found within max_tries attempts.
- property sampling_order
Set of dotted paths for all variables in sampling order.
- Returns:
- Set of strings representing dotted paths (e.g., ‘parent.child.key’)
for all variables including nested Dict spaces.
- Return type:
set
- Type:
set
- update(keys)[source]
Update specific keys in the Dict space by resampling them.
Samples new values for the specified keys while maintaining the sampling order. Uses dot notation for nested keys (e.g., “a.b” for nested dict).
- Parameters:
keys (container) – A container (list, set, etc.) of key names to resample. Keys should use dot notation for nested spaces.
- Raises:
ValueError – If a specified key is not found in the Dict space.
AssertionError – If the updated values violate the space constraints.
- class Discrete(*args, init_value=None, constrain_fn=None, **kwargs)[source]
Bases: Discrete
Extended discrete space with state tracking and constraint support.
This class extends gymnasium.spaces.Discrete to add state management and optional constraint validation. Unlike the standard discrete space, this version maintains a current value and supports rejection sampling via a custom constraint function.
- Variables:
init_value (int) – The initial value for the space.
value (int) – The current value of the space.
constrain_fn (callable) – Optional function that returns True if a value satisfies custom constraints.
Example
Create a discrete space that only accepts even numbers:
space = Discrete(n=10, init_value=0, constrain_fn=lambda x: x % 2 == 0)
value = space.sample()  # Samples even number and updates space.value
space.reset()  # Resets space.value back to 0 (init_value)
Note
The sample() method uses rejection sampling when a constraint function is provided, which may impact performance for difficult constraints.
- __init__(*args, init_value=None, constrain_fn=None, **kwargs)[source]
Initialize a Discrete space with state tracking.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.Discrete.
init_value (int, optional) – Initial value for the space. Defaults to None.
constrain_fn (callable, optional) – Function that takes an int and returns True if the value satisfies custom constraints. Defaults to None.
**kwargs – Keyword arguments passed to gymnasium.spaces.Discrete.
- check()[source]
Validate the current space value.
Checks if the current value is within the space bounds and satisfies the constraint function. Logs a warning if the constraint fails.
- Returns:
True if the current value is valid, False otherwise.
- Return type:
bool
- contains(x)[source]
Check if value is valid and satisfies constraints.
- Parameters:
x (int) – The value to check.
- Returns:
- True if x is within bounds and satisfies the constraint
function, False otherwise.
- Return type:
bool
- reset()[source]
Reset the space value to its initial value.
Sets the current value back to the init_value specified during initialization.
- sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)[source]
Sample a random value using rejection sampling for constraints.
Repeatedly samples values until one satisfies the constraint function or max_tries is reached. Optionally updates the space’s current value.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.Discrete.sample().
max_tries (int, optional) – Maximum number of sampling attempts before raising an error. Defaults to 1000.
warn_after_s (float, optional) – Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0.
set_value (bool, optional) – Whether to update the space’s current value with the sampled value. Defaults to True.
**kwargs – Keyword arguments passed to gymnasium.spaces.Discrete.sample().
- Returns:
A sampled value that satisfies the constraint function.
- Return type:
int
- Raises:
RuntimeError – If no valid sample is found after max_tries attempts.
- class MultiDiscrete(*args, init_value=None, constrain_fn=None, **kwargs)[source]
Bases: MultiDiscrete
Extended multi-discrete space with state tracking and constraint support.
This class extends gymnasium.spaces.MultiDiscrete to add state management and optional constraint validation. It represents multiple discrete variables with potentially different ranges (nvec), where each variable maintains its own value and can be constrained.
- Variables:
init_value (np.ndarray) – The initial values for all discrete variables.
value (np.ndarray) – The current values of all discrete variables.
constrain_fn (callable) – Optional function that returns True if the entire value array satisfies custom constraints.
Example
Create a multi-discrete space for game difficulty settings:
import numpy as np

space = MultiDiscrete(
    nvec=[5, 3, 10],  # [enemy_count, speed_level, spawn_rate]
    init_value=np.array([2, 1, 5]),
)
settings = space.sample()  # Random difficulty configuration
space.reset()  # Resets to [2, 1, 5] (medium difficulty)
Note
Constraints are applied to the entire array, not individual elements. Use a constraint function that validates the complete state.
- __init__(*args, init_value=None, constrain_fn=None, **kwargs)[source]
Initialize a MultiDiscrete space with state tracking.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.MultiDiscrete.
init_value (np.ndarray, optional) – Initial values for the space. Must match the shape defined by nvec. Defaults to None.
constrain_fn (callable, optional) – Function that takes a numpy array and returns True if the values satisfy custom constraints. Defaults to None.
**kwargs – Keyword arguments passed to gymnasium.spaces.MultiDiscrete.
- check()[source]
Validate the current space values.
Checks if the current values are within the space bounds and satisfy the constraint function. Logs a warning if the constraint fails.
- Returns:
True if the current values are valid, False otherwise.
- Return type:
bool
- contains(x)[source]
Check if values are valid and satisfy constraints.
- Parameters:
x (np.ndarray) – The array of values to check.
- Returns:
- True if x is within bounds for all elements and satisfies
the constraint function, False otherwise.
- Return type:
bool
- reset()[source]
Reset the space values to their initial values.
Sets the current values back to the init_value specified during initialization.
- sample(*args, max_tries=1000, warn_after_s=5.0, set_value=True, **kwargs)[source]
Sample random values using rejection sampling for constraints.
Repeatedly samples value arrays until one satisfies the constraint function or max_tries is reached. Optionally updates the space’s current values.
- Parameters:
*args – Positional arguments passed to gymnasium.spaces.MultiDiscrete.sample().
max_tries (int, optional) – Maximum number of sampling attempts before raising an error. Defaults to 1000.
warn_after_s (float, optional) – Time threshold in seconds after which to log a warning about slow sampling. Set to None to disable. Defaults to 5.0.
set_value (bool, optional) – Whether to update the space’s current values with the sampled values. Defaults to True.
**kwargs – Keyword arguments passed to gymnasium.spaces.MultiDiscrete.sample().
- Returns:
A sampled array that satisfies the constraint function.
- Return type:
np.ndarray
- Raises:
RuntimeError – If no valid sample is found after max_tries attempts.
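The rejection-sampling loop described above is easy to illustrate outside the class. The following standalone sketch mirrors the documented nvec, constrain_fn, and max_tries parameters using only numpy; it illustrates the sampling strategy and is not the library's implementation:

```python
import numpy as np

def sample_with_constraint(nvec, constrain_fn, max_tries=1000, rng=None):
    """Rejection-sample one multi-discrete value until the constraint holds."""
    rng = rng or np.random.default_rng()
    for _ in range(max_tries):
        candidate = rng.integers(0, nvec)  # one value per variable, each in [0, nvec_i)
        if constrain_fn(candidate):  # the constraint sees the whole array at once
            return candidate
    raise RuntimeError(f"no valid sample after {max_tries} tries")

# Example: keep total difficulty (enemy_count + speed_level) below 5
nvec = np.array([5, 3, 10])
settings = sample_with_constraint(nvec, lambda v: v[0] + v[1] < 5)
```

Because invalid candidates are simply discarded, a constraint that rejects most of the space makes sampling slow, which is why the class warns after warn_after_s seconds.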
- class RGBBox(shape=(3,), *args, init_value=None, **kwargs)[source]
Bases: Box
Specialized box space for RGB image data with automatic constraints.
This class extends Box to provide a convenient space for RGB images, automatically enforcing uint8 dtype and [0, 255] value ranges. It validates that the shape includes exactly 3 channels for RGB data.
- Parameters:
shape (tuple) – Shape of the image. Must include a dimension of size 3 for the RGB channels. Common formats: (H, W, 3) or (3, H, W).
init_value (np.ndarray, optional) – Initial RGB image. Must match shape and be uint8 dtype.
*args – Additional positional arguments passed to Box.
**kwargs – Additional keyword arguments passed to Box.
- Variables:
init_value (np.ndarray) – The initial RGB image.
value (np.ndarray) – The current RGB image.
Example
Create a space for 64x64 RGB images:
import numpy as np
space = RGBBox(
    shape=(64, 64, 3),
    init_value=np.zeros((64, 64, 3), dtype=np.uint8),
)
image = space.sample()  # Random RGB image
space.reset()  # Returns to black image
- Raises:
AssertionError – If shape does not contain a dimension of size 3.
Note
This space is useful for vision-based environments where images need to be sampled or tracked as part of environment configuration. The low, high, and dtype parameters are automatically set and cannot be overridden.
stable_worldmodel.utils module
Utility functions for stable_worldmodel.
- flatten_dict(d, parent_key='', sep='.')[source]
Flatten a nested dictionary into a single-level dictionary with concatenated keys.
The naming convention for the new keys is similar to Hydra’s, using a . separator to denote levels of nesting. Attention is needed when flattening dictionaries with overlapping keys, as this may lead to information loss.
- Parameters:
d (dict) – The nested dictionary to flatten.
parent_key (str, optional) – The base key to use for the flattened keys.
sep (str, optional) – The separator to use between levels of nesting. Defaults to ‘.’.
- Returns:
A flattened version of the input dictionary.
- Return type:
dict
Examples
>>> info = {"a": {"b": {"c": 42, "d": 43}}, "e": 44}
>>> flatten_dict(info)
{'a.b.c': 42, 'a.b.d': 43, 'e': 44}
>>> flatten_dict({"a": {"b": 2}, "a.b": 3})
{'a.b': 3}
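The behavior shown in these examples can be reproduced with a short recursive helper; this is an illustrative sketch, not the library's actual code:

```python
def flatten_dict(d, parent_key="", sep="."):
    """Flatten nested dicts, joining key levels with `sep`.

    When flattened keys collide (e.g. {"a": {"b": 2}, "a.b": 3}),
    later entries overwrite earlier ones, so information can be lost.
    """
    items = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten_dict(value, new_key, sep))
        else:
            items[new_key] = value
    return items
```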
- get_in(mapping: dict, path: Iterable[str]) Any[source]
Retrieve a value from a nested dictionary using a sequence of keys.
- Parameters:
mapping (dict) – A nested dictionary.
path (Iterable[str]) – An iterable of keys representing the path to the desired value in mapping.
- Returns:
The value located at the specified path in the nested dictionary.
- Return type:
Any
- Raises:
KeyError – If any key in the path does not exist in the mapping dict.
Examples
>>> variations = {"a": {"b": {"c": 42}}}
>>> get_in(variations, ["a", "b", "c"])
42
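Functionally, get_in is a left fold of dictionary indexing over the key path; a minimal equivalent (illustration only, not the library code):

```python
from functools import reduce
from operator import getitem

def get_in(mapping, path):
    """Walk `path` through nested dicts; a missing key raises KeyError."""
    return reduce(getitem, path, mapping)

value = get_in({"a": {"b": {"c": 42}}}, ["a", "b", "c"])  # 42
```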
- pretraining(script_path: str, dataset_name: str, output_model_name: str, dump_object: bool = True, args: str = '') int[source]
Run a pretraining script as a subprocess with optional command-line arguments.
This function checks if the specified script exists, constructs a command to run it with the provided arguments, and executes the command in a subprocess.
- Parameters:
script_path (str) – The path to the pretraining script to be executed.
dataset_name (str) – The name of the dataset to be used in pretraining.
output_model_name (str) – The name to save the output model.
dump_object (bool, optional) – Whether to dump the model object after training. Defaults to True.
args (str, optional) – A string of command-line arguments to pass to the script. Defaults to an empty string.
- Returns:
The return code of the subprocess. A return code of 0 indicates success.
- Return type:
int
- Raises:
ValueError – If the specified script does not exist.
SystemExit – If the subprocess exits with a non-zero return code.
stable_worldmodel.world module
World environment manager.
This module provides the World class, a high-level manager for vectorized Gymnasium environments with integrated support for data collection, policy evaluation, video recording, and dataset management. It serves as the central orchestration layer for training and evaluating world models with domain randomization support.
- The World class handles:
Vectorized environment creation and management
Policy attachment and execution
Episode data collection with automatic sharding
Video recording from live rollouts or stored datasets
Policy evaluation with comprehensive metrics
Visual domain randomization through variation spaces
Example
Basic usage for policy evaluation:
from stable_worldmodel import World
from stable_worldmodel.policy import RandomPolicy
# Create a world with 4 parallel environments
world = World(
env_name="PushT-v1",
num_envs=4,
image_shape=(96, 96),
max_episode_steps=200
)
# Attach a policy
policy = RandomPolicy()
world.set_policy(policy)
# Evaluate the policy
metrics = world.evaluate(episodes=100, seed=42)
print(f"Success rate: {metrics['success_rate']:.2f}%")
Data collection example:
# Collect demonstration data
world.record_dataset(
dataset_name="pusht_demos",
episodes=1000,
seed=42,
options={'variation': ['all']}
)
- class World(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]
Bases: object
High-level manager for vectorized Gymnasium environments with integrated data collection.
World orchestrates multiple parallel environments, providing a unified interface for policy execution, data collection, evaluation, and visualization. It automatically handles environment resets, batched action execution, and comprehensive episode tracking with support for visual domain randomization.
The World class is the central component of the stable-worldmodel library, designed to streamline the workflow from data collection to policy training and evaluation. It wraps Gymnasium’s vectorized environments with additional functionality for world model research.
- Variables:
envs (VectorEnv) – Vectorized Gymnasium environment wrapped with MegaWrapper and VariationWrapper for image processing and domain randomization.
seed (int) – Base random seed for reproducibility.
policy (Policy) – Currently attached policy that generates actions.
states (ndarray) – Current observation states for all environments.
rewards (ndarray) – Current rewards for all environments.
terminateds (ndarray) – Terminal flags indicating episode completion.
truncateds (ndarray) – Truncation flags indicating episode timeout.
infos (dict) – Dictionary containing major information from environments.
- Properties:
num_envs (int): Number of parallel environments.
observation_space (Space): Batched observation space.
action_space (Space): Batched action space.
variation_space (Dict): Variation space for domain randomization (batched).
single_variation_space (Dict): Variation space for a single environment.
single_action_space (Space): Action space for a single environment.
single_observation_space (Space): Observation space for a single environment.
Note
Environments use DISABLED autoreset mode to enable custom reset behavior with seed and options support, which is not provided by Gymnasium’s default autoreset mechanism.
- __init__(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]
Initialize the World with vectorized environments.
Creates and configures a vectorized environment with the specified number of parallel instances. Applies image and goal transformations, sets up variation support, and configures autoreset behavior.
- Parameters:
env_name (str) – Name of the Gymnasium environment to create. Must be a registered environment ID (e.g., ‘PushT-v0’, ‘CubeEnv-v0’).
num_envs (int) – Number of parallel environment instances to create. Higher values increase data collection throughput but require more memory.
image_shape (tuple) – Target shape for image observations as (height, width) or (height, width, channels). Images are resized to this shape.
goal_shape (tuple, optional) – Target shape for goal image observations. If None, goals are processed with the same shape as observations. Defaults to None.
goal_transform (Callable, optional) – Function to transform goal observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.
image_transform (Callable, optional) – Function to transform image observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.
seed (int, optional) – Base random seed for environment initialization. Each environment gets an offset seed. Defaults to 2349867.
max_episode_steps (int, optional) – Maximum number of steps per episode before truncation. Episodes terminate early on task success. Defaults to 100.
verbose (int, optional) – Verbosity level. 0 for silent, 1 for basic info, 2+ for detailed debugging information. Defaults to 1.
**kwargs – Additional keyword arguments passed to gym.make_vec() and subsequently to the underlying environment constructor.
Note
The MegaWrapper applies image transformations and resizing. The VariationWrapper enables domain randomization support. Autoreset is disabled to allow custom reset with seeds and options.
Example
Create a world with goal-conditioned observations:
world = World(
    env_name="PushT-v1",
    num_envs=8,
    image_shape=(96, 96),
    goal_shape=(64, 64),
    max_episode_steps=150,
    seed=42,
)
- close(**kwargs)[source]
Close all environments and clean up resources.
- Parameters:
**kwargs – Additional keyword arguments passed to the underlying vectorized environment’s close method.
- Returns:
Return value from the underlying close method.
- Return type:
Any
- evaluate(episodes=10, eval_keys=None, seed=None, options=None)[source]
Evaluate the current policy over multiple episodes and return comprehensive metrics.
Runs the attached policy for a specified number of episodes, tracking success rates and optionally other metrics from environment info. Handles episode boundaries and ensures reproducibility through seeding.
- Parameters:
episodes (int, optional) – Total number of episodes to evaluate. More episodes provide more statistically reliable metrics. Defaults to 10.
eval_keys (list of str, optional) – Additional info keys to track across episodes. Must be keys present in self.infos (e.g., ‘reward_total’, ‘steps_to_success’). Defaults to None (track only success rate).
seed (int, optional) – Base random seed for reproducible evaluation. Each episode gets an incremental offset. Defaults to None (non-deterministic).
options (dict, optional) – Reset options passed to environments (e.g., task selection, variation settings). Defaults to None.
- Returns:
- Dictionary containing evaluation metrics:
’success_rate’ (float): Percentage of successful episodes (0-100)
’episode_successes’ (ndarray): Boolean array of episode outcomes
’seeds’ (ndarray): Random seeds used for each episode
Additional keys from eval_keys if specified (ndarray)
- Return type:
dict
- Raises:
AssertionError – If eval_key is not found in infos, if episode count mismatch occurs, or if duplicate seeds are detected.
Note
Success is determined by the ‘terminateds’ flag (True = success)
‘truncateds’ flag (timeout) is treated as failure
Seeds are validated for uniqueness to ensure independent episodes
Environments are manually reset with seeds (autoreset is bypassed)
All environments are evaluated in parallel for efficiency
Example
Basic evaluation:
metrics = world.evaluate(episodes=100, seed=42)
print(f"Success: {metrics['success_rate']:.1f}%")
Evaluate with additional metrics:
metrics = world.evaluate(
    episodes=100,
    eval_keys=['reward_total', 'episode_length'],
    seed=42,
    options={'task_id': 3},
)
print(f"Success: {metrics['success_rate']:.1f}%")
print(f"Avg reward: {metrics['reward_total'].mean():.2f}")
Evaluate across different variations:
for var_type in ['none', 'color', 'light', 'all']:
    options = {'variation': [var_type]} if var_type != 'none' else None
    metrics = world.evaluate(episodes=50, seed=42, options=options)
    print(f"{var_type}: {metrics['success_rate']:.1f}%")
- record_dataset(dataset_name, episodes=10, seed=None, cache_dir=None, options=None)[source]
Collect episodes with the current policy and save as a HuggingFace Dataset.
Executes the attached policy to collect demonstration or rollout data, automatically managing episode boundaries and saving all observations, actions, rewards, and auxiliary information as a sharded Parquet dataset. Images are stored as JPEG files with paths in the dataset.
The dataset is organized with the following structure:
dataset_name/
    img/
        {episode_idx}/
            {step_idx}_{column_name}.jpeg
    data-{shard}.arrow (Parquet shards)
    dataset_info.json
    state.json
- Parameters:
dataset_name (str) – Name of the dataset. Used as subdirectory name in cache_dir. Should be descriptive (e.g., ‘pusht_expert_demos’).
episodes (int, optional) – Total number of complete episodes to collect. Incomplete episodes at the end are discarded. Defaults to 10.
seed (int, optional) – Base random seed for reproducibility. Each episode gets an incremental offset. Defaults to None (non-deterministic).
cache_dir (str or Path, optional) – Root directory for dataset storage. If None, uses swm.data.get_cache_dir(). Defaults to None.
options (dict, optional) – Reset options for environments. Use {‘variation’: [‘all’]} for full domain randomization. Defaults to None.
- Dataset Schema:
- Each row contains:
episode_idx (int32): Episode identifier
step_idx (int32): Step within episode (0-indexed)
episode_len (int32): Total length of the episode
policy (string): Policy type identifier
pixels (string): Relative path to observation image
goal (string, optional): Relative path to goal image
action (float array): Action taken at this step
reward (float): Reward received
Additional keys from environment infos
Note
Actions are shifted: action at step t leads to observation at step t+1
Last action in each episode is NaN (no action leads from final state)
Images are saved as JPEG for efficiency (may introduce compression artifacts)
Dataset is automatically sharded: 1 shard per 50 episodes
Only complete episodes are included in the final dataset
- Raises:
AssertionError – If required keys (‘pixels’, ‘episode_idx’, ‘step_idx’, ‘episode_len’) are missing from recorded data.
Example
Collect demonstration data with variations:
world.set_policy(expert_policy)
world.record_dataset(
    dataset_name="cube_expert_1k",
    episodes=1000,
    seed=42,
    options={'variation': ['all']},
)
Collect data for specific tasks:
world.record_dataset(
    dataset_name="pusht_task2_data",
    episodes=500,
    options={'task_id': 2},
)
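Because of the action-shifting convention noted above, the final row of every episode carries a NaN action, and a consumer of the dataset typically filters those rows before training. A sketch with numpy (the step_idx and episode_len columns follow the documented schema; the array values are made up for illustration):

```python
import numpy as np

# Toy stand-in for one loaded episode of length 3: the action stored at
# step t is the one that leads to the observation at step t + 1, so the
# final step has no action (NaN).
step_idx = np.array([0, 1, 2])
episode_len = 3
actions = np.array([[0.1, 0.2], [0.3, 0.4], [np.nan, np.nan]])

# Keep only transitions with a valid action, i.e. everything but the last step.
valid = step_idx < episode_len - 1
transitions = actions[valid]
```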
- record_video(video_path, max_steps=500, fps=30, viewname='pixels', seed=None, options=None)[source]
Record rollout videos for each environment under the current policy.
Executes policy rollouts in all environments and saves them as MP4 videos, one per environment. Videos show stacked observation and goal images when goals are available.
- Parameters:
video_path (str or Path) – Directory path where videos will be saved. Created if it doesn’t exist. Videos are named ‘env_0.mp4’, ‘env_1.mp4’, etc.
max_steps (int, optional) – Maximum number of steps to record per episode. Recording stops earlier if any environment terminates. Defaults to 500.
fps (int, optional) – Frames per second for output videos. Higher values produce smoother playback but larger files. Defaults to 30.
viewname (str, optional) – Info key of the image view to record. Defaults to ‘pixels’.
seed (int, optional) – Random seed for reproducible rollouts. Defaults to None.
options (dict, optional) – Reset options passed to environments (e.g., variation settings). Defaults to None.
Note
Videos use libx264 codec for compatibility
If ‘goal’ is present in infos, frames show observation stacked above goal
All environments are recorded simultaneously
Recording stops when ANY environment terminates or truncates
Example
Record evaluation videos:
world.set_policy(trained_policy)
world.record_video(
    video_path="./eval_videos",
    max_steps=200,
    fps=30,
    seed=42,
)
- record_video_from_dataset(video_path, dataset_name, episode_idx, max_steps=500, fps=30, num_proc=4, viewname: str | list[str] = 'pixels', cache_dir=None)[source]
Replay stored dataset episodes and export them as MP4 videos.
Loads episodes from a previously recorded dataset and renders them as videos, useful for visualization, debugging, and qualitative evaluation of collected data.
- Parameters:
video_path (str or Path) – Directory where videos will be saved. Videos are named ‘episode_{idx}.mp4’. Directory is created if it doesn’t exist.
dataset_name (str) – Name of the dataset to load (must exist in cache_dir).
episode_idx (int or list of int) – Episode index or list of episode indices to render. Each episode is saved as a separate video file.
max_steps (int, optional) – Maximum number of steps to render per episode. Useful for limiting video length. Defaults to 500.
fps (int, optional) – Frames per second for output videos. Defaults to 30.
num_proc (int, optional) – Number of processes for parallel dataset filtering. Higher values speed up loading for large datasets. Defaults to 4.
viewname (str or list of str, optional) – Info key(s) of the image view(s) to render. Defaults to ‘pixels’.
cache_dir (str or Path, optional) – Root directory where dataset is stored. If None, uses swm.data.get_cache_dir(). Defaults to None.
- Raises:
AssertionError – If dataset doesn’t exist in cache_dir, or if episode length inconsistencies are detected in the data.
Note
Images are loaded from JPEG files stored in the dataset
If ‘goal’ column exists, observation and goal are stacked vertically
Videos use libx264 codec for broad compatibility
Episodes are validated for consistency (length matches metadata)
Example
Render specific episodes from a dataset:
world.record_video_from_dataset(
    video_path="./visualizations",
    dataset_name="cube_expert_1k",
    episode_idx=[0, 5, 10, 99],
    fps=30,
)
Render a single episode:
world.record_video_from_dataset(
    video_path="./debug",
    dataset_name="failed_episodes",
    episode_idx=42,
    max_steps=100,
)
- reset(seed=None, options=None)[source]
Reset all environments to initial states.
- Parameters:
seed (int, optional) – Random seed for reproducible resets. If None, uses non-deterministic seeding. Defaults to None.
options (dict, optional) – Dictionary of reset options passed to environments. Common keys include ‘variation’ for domain randomization. Defaults to None.
- Updates:
self.states: Initial observations from all environments
self.infos: Initial auxiliary information
Example
Reset with domain randomization:
world.reset(seed=42, options={'variation': ['all']})
Reset specific variations:
world.reset(options={'variation': ['cube.color', 'light.intensity']})
- set_policy(policy)[source]
Attach a policy to the world and configure it with environment context.
The policy will be used for action generation during step(), record_video(), record_dataset(), and evaluate() calls.
- Parameters:
policy (Policy) – Policy instance that implements get_action(infos) method. The policy receives environment context through set_env() and optional seeding through set_seed().
Note
If the policy has a ‘seed’ attribute, it will be applied via set_seed(). The policy’s set_env() method receives the wrapped vectorized environment.
- step()[source]
Advance all environments by one step using the current policy.
Queries the attached policy for actions based on current info, executes those actions in all environments, and updates internal state with the results.
- Updates:
self.states: New observations from all environments
self.rewards: Rewards received from the step
self.terminateds: Episode termination flags (task success)
self.truncateds: Episode truncation flags (timeout)
self.infos: Auxiliary information dictionaries
Note
Requires a policy to be attached via set_policy() before calling. The policy’s get_action() method receives the current infos dict.
- Raises:
AttributeError – If no policy has been set via set_policy().
stable_worldmodel.wrappers module
- class AddPixelsWrapper(env, pixels_shape: tuple[int, int] = (84, 84), torchvision_transform: Callable | None = None)[source]
Bases: Wrapper
Adds rendered environment pixels to info dict with optional resizing and transforms.
Supports single images, dictionaries of images (multiview), or lists of images. Uses PIL for resizing and optional torchvision transforms.
- Parameters:
env – The Gymnasium environment to wrap.
pixels_shape – Target (height, width) for resized images. Defaults to (84, 84).
torchvision_transform – Optional transform to apply to PIL images.
- Info Keys Added:
pixels: Rendered image (single view).
pixels.{key}: Individual images (multiview dict).
pixels.{idx}: Individual images (multiview list).
render_time: Time taken to render in seconds.
- class EnsureGoalInfoWrapper(env, check_reset, check_step: bool = False)[source]
Bases: Wrapper
Validates that ‘goal’ key is present in info dict during reset and/or step.
Useful for goal-conditioned environments to ensure goal information is provided.
- Parameters:
env – The Gymnasium environment to wrap.
check_reset – If True, validates ‘goal’ key is in info after reset().
check_step – If True, validates ‘goal’ key is in info after step().
- Raises:
RuntimeError – If ‘goal’ key is missing when validation is enabled.
- class EnsureImageShape(env, image_key, image_shape)[source]
Bases: Wrapper
Validates that an image in the info dict has the expected spatial dimensions.
- Parameters:
env – The Gymnasium environment to wrap.
image_key – Key in info dict containing the image to validate.
image_shape – Expected (height, width) tuple for the image.
- Raises:
RuntimeError – If the image shape doesn’t match the expected dimensions.
- class EnsureInfoKeysWrapper(env, required_keys: Iterable[str])[source]
Bases: Wrapper
Validates that required keys are present in the info dict after reset and step.
Supports regex patterns for flexible key matching. Raises RuntimeError if any required pattern has no matching key.
- Parameters:
env – The Gymnasium environment to wrap.
required_keys – Iterable of regex patterns as strings. Each pattern must match at least one key in the info dict.
- Raises:
RuntimeError – If any required pattern has no matching key in info dict.
- class EverythingToInfoWrapper(env)[source]
Bases: Wrapper
Moves all transition information into the info dict for unified data access.
Adds observation, reward, terminated, truncated, action, and step_idx to info. Optionally tracks environment variations when specified in reset options.
- Parameters:
env – The Gymnasium environment to wrap.
- Info Keys Added:
observation (or dict keys if obs is dict): Current observation.
reward: Reward value (NaN after reset).
terminated: Episode termination flag.
truncated: Episode truncation flag.
action: Action taken (NaN sample after reset).
step_idx: Current step counter.
variation.{key}: Variation values if requested via reset options.
Note
Pass options={“variation”: [“key1”, “key2”]} or [“all”] to reset() to track variations.
- class MegaWrapper(env, image_shape: tuple[int, int] = (84, 84), pixels_transform: Callable | None = None, goal_transform: Callable | None = None, required_keys: Iterable | None = None, separate_goal: Iterable | None = True)[source]
Bases: Wrapper
Combines multiple wrappers for comprehensive environment preprocessing.
Applies in sequence: AddPixelsWrapper → EverythingToInfoWrapper → EnsureInfoKeysWrapper → EnsureGoalInfoWrapper → ResizeGoalWrapper.
This provides a complete preprocessing pipeline with rendered pixels, unified info dict, key validation, goal checking, and goal resizing.
- Parameters:
env – The Gymnasium environment to wrap.
image_shape – Target (height, width) for pixels and goal. Defaults to (84, 84).
pixels_transform – Optional torchvision transform for rendered pixels.
goal_transform – Optional torchvision transform for goal images.
required_keys – Additional regex patterns for keys that must be in info. The pattern ^pixels(?:\..*)?$ is always added.
separate_goal – If True, validates ‘goal’ is present in info. Defaults to True.
- class ResizeGoalWrapper(env, pixels_shape: tuple[int, int] = (84, 84), torchvision_transform: Callable | None = None)[source]
Bases: Wrapper
Resizes goal images in info dict with optional transforms.
Applies PIL-based resizing and optional torchvision transforms to the ‘goal’ image in info dict during both reset and step.
- Parameters:
env – The Gymnasium environment to wrap.
pixels_shape – Target (height, width) for resized goal images. Defaults to (84, 84).
torchvision_transform – Optional transform to apply to PIL goal images.
- class VariationWrapper(env, variation_mode: str | Space = 'same')[source]
Bases: VectorWrapper
Manages variation spaces for vectorized environments.
Handles batching of variation spaces across multiple environments, supporting either shared variations (same) or independent variations (different).
- Parameters:
env – The vectorized Gymnasium environment to wrap.
variation_mode – Mode for handling variations across environments:
- “same”: All environments share the same variation space (batched).
- “different”: Each environment has independent variation spaces.
- Raises:
ValueError – If variation_mode is invalid or sub-environment spaces don’t match.
Note
Base environment must have a variation_space attribute. If missing, variation spaces are set to None.
Module contents
- class PlanConfig(horizon: int, receding_horizon: int, history_len: int = 1, action_block: int = 1, warm_start: bool = True)[source]
Bases: object
Configuration for the planning process.
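To see how these fields typically fit together in receding-horizon (MPC-style) planning, here is a standalone sketch; the field names and defaults come from the signature above, while the comments and the loop are generic MPC assumptions, not the library's planner:

```python
from dataclasses import dataclass

@dataclass
class PlanConfig:
    horizon: int               # number of future actions the solver optimizes
    receding_horizon: int      # how many of those actions are actually executed
    history_len: int = 1       # past observations fed to the world model (assumed)
    action_block: int = 1      # consecutive steps sharing one action (assumed)
    warm_start: bool = True    # seed each solve with the previous plan (assumed)

cfg = PlanConfig(horizon=10, receding_horizon=3)

# Generic receding-horizon loop: plan `horizon` steps, execute the first
# `receding_horizon` of them, then replan from the new state.
executed = []
for replan in range(2):
    plan = [f"a{replan}_{t}" for t in range(cfg.horizon)]  # placeholder plan
    executed.extend(plan[: cfg.receding_horizon])
```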
- class World(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]
Bases: object
High-level manager for vectorized Gymnasium environments with integrated data collection.
World orchestrates multiple parallel environments, providing a unified interface for policy execution, data collection, evaluation, and visualization. It automatically handles environment resets, batched action execution, and comprehensive episode tracking with support for visual domain randomization.
The World class is the central component of the stable-worldmodel library, designed to streamline the workflow from data collection to policy training and evaluation. It wraps Gymnasium’s vectorized environments with additional functionality for world model research.
- Variables:
envs (VectorEnv) – Vectorized Gymnasium environment wrapped with MegaWrapper and VariationWrapper for image processing and domain randomization.
seed (int) – Base random seed for reproducibility.
policy (Policy) – Currently attached policy that generates actions.
states (ndarray) – Current observation states for all environments.
rewards (ndarray) – Current rewards for all environments.
terminateds (ndarray) – Terminal flags indicating episode completion.
truncateds (ndarray) – Truncation flags indicating episode timeout.
infos (dict) – Dictionary containing major information from environments.
- Properties:
num_envs (int): Number of parallel environments.
observation_space (Space): Batched observation space.
action_space (Space): Batched action space.
variation_space (Dict): Variation space for domain randomization (batched).
single_variation_space (Dict): Variation space for a single environment.
single_action_space (Space): Action space for a single environment.
single_observation_space (Space): Observation space for a single environment.
Note
Environments use DISABLED autoreset mode to enable custom reset behavior with seed and options support, which is not provided by Gymnasium’s default autoreset mechanism.
- __init__(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]
Initialize the World with vectorized environments.
Creates and configures a vectorized environment with the specified number of parallel instances. Applies image and goal transformations, sets up variation support, and configures autoreset behavior.
- Parameters:
env_name (str) – Name of the Gymnasium environment to create. Must be a registered environment ID (e.g., ‘PushT-v0’, ‘CubeEnv-v0’).
num_envs (int) – Number of parallel environment instances to create. Higher values increase data collection throughput but require more memory.
image_shape (tuple) – Target shape for image observations as (height, width) or (height, width, channels). Images are resized to this shape.
goal_shape (tuple, optional) – Target shape for goal image observations. If None, goals are processed with the same shape as observations. Defaults to None.
goal_transform (Callable, optional) – Function to transform goal observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.
image_transform (Callable, optional) – Function to transform image observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.
seed (int, optional) – Base random seed for environment initialization. Each environment gets an offset seed. Defaults to 2349867.
max_episode_steps (int, optional) – Maximum number of steps per episode before truncation. Episodes terminate early on task success. Defaults to 100.
verbose (int, optional) – Verbosity level. 0 for silent, 1 for basic info, 2+ for detailed debugging information. Defaults to 1.
**kwargs – Additional keyword arguments passed to gym.make_vec() and subsequently to the underlying environment constructor.
Note
The MegaWrapper applies image transformations and resizing. The VariationWrapper enables domain randomization support. Autoreset is disabled to allow custom reset with seeds and options.
Example
Create a world with goal-conditioned observations:
world = World(
    env_name="PushT-v1",
    num_envs=8,
    image_shape=(96, 96),
    goal_shape=(64, 64),
    max_episode_steps=150,
    seed=42,
)
- close(**kwargs)[source]
Close all environments and clean up resources.
- Parameters:
**kwargs – Additional keyword arguments passed to the underlying vectorized environment’s close method.
- Returns:
Return value from the underlying close method.
- Return type:
Any
- evaluate(episodes=10, eval_keys=None, seed=None, options=None)[source]
Evaluate the current policy over multiple episodes and return comprehensive metrics.
Runs the attached policy for a specified number of episodes, tracking success rates and optionally other metrics from environment info. Handles episode boundaries and ensures reproducibility through seeding.
- Parameters:
episodes (int, optional) – Total number of episodes to evaluate. More episodes provide more statistically reliable metrics. Defaults to 10.
eval_keys (list of str, optional) – Additional info keys to track across episodes. Must be keys present in self.infos (e.g., ‘reward_total’, ‘steps_to_success’). Defaults to None (track only success rate).
seed (int, optional) – Base random seed for reproducible evaluation. Each episode gets an incremental offset. Defaults to None (non-deterministic).
options (dict, optional) – Reset options passed to environments (e.g., task selection, variation settings). Defaults to None.
- Returns:
- Dictionary containing evaluation metrics:
’success_rate’ (float): Percentage of successful episodes (0-100)
’episode_successes’ (ndarray): Boolean array of episode outcomes
’seeds’ (ndarray): Random seeds used for each episode
Additional keys from eval_keys if specified (ndarray)
- Return type:
dict
- Raises:
AssertionError – If an entry of eval_keys is not found in infos, if an episode count mismatch occurs, or if duplicate seeds are detected.
Note
Success is determined by the ‘terminateds’ flag (True = success)
‘truncateds’ flag (timeout) is treated as failure
Seeds are validated for uniqueness to ensure independent episodes
Environments are manually reset with seeds (autoreset is bypassed)
All environments are evaluated in parallel for efficiency
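The success bookkeeping described in the Note above can be sketched with plain NumPy. This is an illustrative sketch, not the library's internals; the variable names mirror the documented metric keys:

```python
import numpy as np

# Simulated per-episode outcome flags, as the Note describes:
# 'terminateds' marks task success, 'truncateds' marks a timeout (failure).
terminateds = np.array([True, False, True, True, False])
truncateds = np.array([False, True, False, False, True])

# An episode counts as a success only if it terminated without truncating.
episode_successes = terminateds & ~truncateds

# success_rate is reported as a percentage in [0, 100].
success_rate = 100.0 * episode_successes.mean()
print(success_rate)  # 60.0

# Seeds get an incremental offset per episode and are validated for
# uniqueness so that episodes are independent.
seeds = 42 + np.arange(len(terminateds))
assert len(np.unique(seeds)) == len(seeds)
```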
Example
Basic evaluation:
metrics = world.evaluate(episodes=100, seed=42)
print(f"Success: {metrics['success_rate']:.1f}%")

Evaluate with additional metrics:

metrics = world.evaluate(
    episodes=100,
    eval_keys=['reward_total', 'episode_length'],
    seed=42,
    options={'task_id': 3},
)
print(f"Success: {metrics['success_rate']:.1f}%")
print(f"Avg reward: {metrics['reward_total'].mean():.2f}")

Evaluate across different variations:

for var_type in ['none', 'color', 'light', 'all']:
    options = {'variation': [var_type]} if var_type != 'none' else None
    metrics = world.evaluate(episodes=50, seed=42, options=options)
    print(f"{var_type}: {metrics['success_rate']:.1f}%")
- record_dataset(dataset_name, episodes=10, seed=None, cache_dir=None, options=None)[source]
Collect episodes with the current policy and save as a HuggingFace Dataset.
Executes the attached policy to collect demonstration or rollout data, automatically managing episode boundaries and saving all observations, actions, rewards, and auxiliary information as a sharded Parquet dataset. Images are stored as JPEG files with paths in the dataset.
- The dataset is organized with the following structure:
dataset_name/
    img/
        {episode_idx}/
            {step_idx}_{column_name}.jpeg
    data-{shard}.arrow (Parquet shards)
    dataset_info.json
    state.json
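The image-path layout above can be expressed as a small helper. This is an illustrative sketch (the function name is hypothetical, not part of the library):

```python
def image_relpath(episode_idx: int, step_idx: int, column_name: str) -> str:
    """Build the relative JPEG path following the documented layout:
    img/{episode_idx}/{step_idx}_{column_name}.jpeg
    """
    return f"img/{episode_idx}/{step_idx}_{column_name}.jpeg"

# The 'pixels' column of each dataset row stores such a relative path.
print(image_relpath(3, 17, "pixels"))  # img/3/17_pixels.jpeg
```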
- Parameters:
dataset_name (str) – Name of the dataset. Used as subdirectory name in cache_dir. Should be descriptive (e.g., ‘pusht_expert_demos’).
episodes (int, optional) – Total number of complete episodes to collect. Incomplete episodes at the end are discarded. Defaults to 10.
seed (int, optional) – Base random seed for reproducibility. Each episode gets an incremental offset. Defaults to None (non-deterministic).
cache_dir (str or Path, optional) – Root directory for dataset storage. If None, uses swm.data.get_cache_dir(). Defaults to None.
options (dict, optional) – Reset options for environments. Use {‘variation’: [‘all’]} for full domain randomization. Defaults to None.
- Dataset Schema:
- Each row contains:
episode_idx (int32): Episode identifier
step_idx (int32): Step within episode (0-indexed)
episode_len (int32): Total length of the episode
policy (string): Policy type identifier
pixels (string): Relative path to observation image
goal (string, optional): Relative path to goal image
action (float array): Action taken at this step
reward (float): Reward received
Additional keys from environment infos
Note
Actions are shifted: action at step t leads to observation at step t+1
Last action in each episode is NaN (no action leads from final state)
Images are saved as JPEG for efficiency (may introduce compression artifacts)
Dataset is automatically sharded: 1 shard per 50 episodes
Only complete episodes are included in the final dataset
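Two of the conventions in the Note above, action shifting and sharding, can be sketched as follows (an assumption-laden illustration, not the library's recording code):

```python
import math

# Action shifting: the action stored at step t is the one whose effect
# is observed at step t+1, so the final row of an episode has no action
# and is filled with NaN.
actions = [0.1, 0.2, 0.3]           # actions taken during a 4-step episode
shifted = actions + [float("nan")]  # NaN padding for the terminal state
print(shifted)  # [0.1, 0.2, 0.3, nan]

# Sharding: the dataset is written as one Parquet shard per 50 episodes.
def num_shards(episodes: int) -> int:
    return math.ceil(episodes / 50)

print(num_shards(1000))  # 20
print(num_shards(51))    # 2
```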
- Raises:
AssertionError – If required keys (‘pixels’, ‘episode_idx’, ‘step_idx’, ‘episode_len’) are missing from recorded data.
Example
Collect demonstration data with variations:
world.set_policy(expert_policy)
world.record_dataset(
    dataset_name="cube_expert_1k",
    episodes=1000,
    seed=42,
    options={'variation': ['all']},
)

Collect data for specific tasks:

world.record_dataset(
    dataset_name="pusht_task2_data",
    episodes=500,
    options={'task_id': 2},
)
- record_video(video_path, max_steps=500, fps=30, viewname='pixels', seed=None, options=None)[source]
Record rollout videos for each environment under the current policy.
Executes policy rollouts in all environments and saves them as MP4 videos, one per environment. Videos show stacked observation and goal images when goals are available.
- Parameters:
video_path (str or Path) – Directory path where videos will be saved. Created if it doesn’t exist. Videos are named ‘env_0.mp4’, ‘env_1.mp4’, etc.
max_steps (int, optional) – Maximum number of steps to record per episode. Recording stops earlier if any environment terminates. Defaults to 500.
fps (int, optional) – Frames per second for output videos. Higher values produce smoother playback but larger files. Defaults to 30.
viewname (str, optional) – Info key of the observation image to record. Defaults to ‘pixels’.
seed (int, optional) – Random seed for reproducible rollouts. Defaults to None.
options (dict, optional) – Reset options passed to environments (e.g., variation settings). Defaults to None.
Note
Videos use libx264 codec for compatibility
If ‘goal’ is present in infos, frames show observation stacked above goal
All environments are recorded simultaneously
Recording stops when ANY environment terminates or truncates
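The observation-over-goal frame composition mentioned in the Note can be sketched with NumPy (frame sizes here are hypothetical):

```python
import numpy as np

# Hypothetical 64x64 RGB observation and goal images.
obs = np.zeros((64, 64, 3), dtype=np.uint8)
goal = np.full((64, 64, 3), 255, dtype=np.uint8)

# When a goal is present in infos, each video frame shows the
# observation stacked vertically above the goal image.
frame = np.concatenate([obs, goal], axis=0)
print(frame.shape)  # (128, 64, 3)
```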
Example
Record evaluation videos:
world.set_policy(trained_policy)
world.record_video(
    video_path="./eval_videos",
    max_steps=200,
    fps=30,
    seed=42,
)
- record_video_from_dataset(video_path, dataset_name, episode_idx, max_steps=500, fps=30, num_proc=4, viewname: str | list[str] = 'pixels', cache_dir=None)[source]
Replay stored dataset episodes and export them as MP4 videos.
Loads episodes from a previously recorded dataset and renders them as videos, useful for visualization, debugging, and qualitative evaluation of collected data.
- Parameters:
video_path (str or Path) – Directory where videos will be saved. Videos are named ‘episode_{idx}.mp4’. Directory is created if it doesn’t exist.
dataset_name (str) – Name of the dataset to load (must exist in cache_dir).
episode_idx (int or list of int) – Episode index or list of episode indices to render. Each episode is saved as a separate video file.
max_steps (int, optional) – Maximum number of steps to render per episode. Useful for limiting video length. Defaults to 500.
fps (int, optional) – Frames per second for output videos. Defaults to 30.
num_proc (int, optional) – Number of processes for parallel dataset filtering. Higher values speed up loading for large datasets. Defaults to 4.
viewname (str or list of str, optional) – Dataset column(s) containing the image view(s) to render. Defaults to ‘pixels’.
cache_dir (str or Path, optional) – Root directory where dataset is stored. If None, uses swm.data.get_cache_dir(). Defaults to None.
- Raises:
AssertionError – If dataset doesn’t exist in cache_dir, or if episode length inconsistencies are detected in the data.
Note
Images are loaded from JPEG files stored in the dataset
If ‘goal’ column exists, observation and goal are stacked vertically
Videos use libx264 codec for broad compatibility
Episodes are validated for consistency (length matches metadata)
Example
Render specific episodes from a dataset:
world.record_video_from_dataset(
    video_path="./visualizations",
    dataset_name="cube_expert_1k",
    episode_idx=[0, 5, 10, 99],
    fps=30,
)

Render a single episode:

world.record_video_from_dataset(
    video_path="./debug",
    dataset_name="failed_episodes",
    episode_idx=42,
    max_steps=100,
)
- reset(seed=None, options=None)[source]
Reset all environments to initial states.
- Parameters:
seed (int, optional) – Random seed for reproducible resets. If None, uses non-deterministic seeding. Defaults to None.
options (dict, optional) – Dictionary of reset options passed to environments. Common keys include ‘variation’ for domain randomization. Defaults to None.
- Updates:
self.states: Initial observations from all environments
self.infos: Initial auxiliary information
Example
Reset with domain randomization:
world.reset(seed=42, options={'variation': ['all']})

Reset specific variations:

world.reset(options={'variation': ['cube.color', 'light.intensity']})
- set_policy(policy)[source]
Attach a policy to the world and configure it with environment context.
The policy will be used for action generation during step(), record_video(), record_dataset(), and evaluate() calls.
- Parameters:
policy (Policy) – Policy instance that implements get_action(infos) method. The policy receives environment context through set_env() and optional seeding through set_seed().
Note
If the policy has a ‘seed’ attribute, it will be applied via set_seed(). The policy’s set_env() method receives the wrapped vectorized environment.
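A minimal stub satisfying the interface described above might look like this. The method names get_action(), set_env(), and set_seed() and the seed attribute come from the documentation; the random-action body is purely illustrative:

```python
import numpy as np

class RandomPolicy:
    """Illustrative policy exposing the interface World expects."""

    seed = 0  # if present, World applies it via set_seed()

    def set_env(self, env):
        # Receives the wrapped vectorized environment.
        self.env = env

    def set_seed(self, seed):
        self.rng = np.random.default_rng(seed)

    def get_action(self, infos):
        # Called once per step with the current infos dict; here we
        # just sample a batch of 2-D actions for 4 parallel envs.
        return self.rng.uniform(-1.0, 1.0, size=(4, 2))

policy = RandomPolicy()
policy.set_seed(policy.seed)
print(policy.get_action({}).shape)  # (4, 2)
```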
- step()[source]
Advance all environments by one step using the current policy.
Queries the attached policy for actions based on current info, executes those actions in all environments, and updates internal state with the results.
- Updates:
self.states: New observations from all environments
self.rewards: Rewards received from the step
self.terminateds: Episode termination flags (task success)
self.truncateds: Episode truncation flags (timeout)
self.infos: Auxiliary information dictionaries
Note
Requires a policy to be attached via set_policy() before calling. The policy’s get_action() method receives the current infos dict.
- Raises:
AttributeError – If no policy has been set via set_policy().
- pretraining(script_path: str, dataset_name: str, output_model_name: str, dump_object: bool = True, args: str = '') → int[source]
Run a pretraining script as a subprocess with optional command-line arguments.
This function checks if the specified script exists, constructs a command to run it with the provided arguments, and executes the command in a subprocess.
- Parameters:
script_path (str) – The path to the pretraining script to be executed.
dataset_name (str) – The name of the dataset to be used in pretraining.
output_model_name (str) – The name to save the output model.
dump_object (bool, optional) – Whether to dump the model object after training. Defaults to True.
args (str, optional) – A string of command-line arguments to pass to the script. Defaults to an empty string.
- Returns:
The return code of the subprocess. A return code of 0 indicates success.
- Return type:
int
- Raises:
ValueError – If the specified script does not exist.
SystemExit – If the subprocess exits with a non-zero return code.
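The check-then-run flow described above can be sketched roughly as follows. This is a hypothetical stand-in, not the library function; in particular, the command-line flag names are assumptions:

```python
import subprocess
import sys
from pathlib import Path

def run_pretraining(script_path, dataset_name, output_model_name,
                    dump_object=True, args=""):
    """Sketch of the documented flow: verify the script exists, build the
    command, run it in a subprocess, and propagate failures."""
    # Fail early if the script is missing, as the docs describe.
    if not Path(script_path).exists():
        raise ValueError(f"Script not found: {script_path}")

    # Flag names below (--dataset-name, etc.) are illustrative assumptions.
    cmd = [sys.executable, script_path,
           "--dataset-name", dataset_name,
           "--output-model-name", output_model_name]
    if dump_object:
        cmd.append("--dump-object")
    cmd.extend(args.split())

    result = subprocess.run(cmd)
    if result.returncode != 0:
        # Mirror the documented SystemExit on non-zero exit.
        raise SystemExit(result.returncode)
    return result.returncode  # 0 indicates success
```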