World
- class World(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]
Bases: object

High-level manager for vectorized Gymnasium environments with integrated data collection.
World orchestrates multiple parallel environments, providing a unified interface for policy execution, data collection, evaluation, and visualization. It automatically handles environment resets, batched action execution, and comprehensive episode tracking with support for visual domain randomization.
The World class is the central component of the stable-worldmodel library, designed to streamline the workflow from data collection to policy training and evaluation. It wraps Gymnasium’s vectorized environments with additional functionality for world model research.
- Variables:
envs (VectorEnv) – Vectorized Gymnasium environment wrapped with MegaWrapper and VariationWrapper for image processing and domain randomization.
seed (int) – Base random seed for reproducibility.
policy (Policy) – Currently attached policy that generates actions.
states (ndarray) – Current observation states for all environments.
rewards (ndarray) – Current rewards for all environments.
terminateds (ndarray) – Terminal flags indicating episode completion.
truncateds (ndarray) – Truncation flags indicating episode timeout.
infos (dict) – Dictionary of auxiliary information returned by the environments.
- Properties:
num_envs (int): Number of parallel environments.
observation_space (Space): Batched observation space.
action_space (Space): Batched action space.
variation_space (Dict): Variation space for domain randomization (batched).
single_variation_space (Dict): Variation space for a single environment.
single_action_space (Space): Action space for a single environment.
single_observation_space (Space): Observation space for a single environment.
Note
Environments use DISABLED autoreset mode to enable custom reset behavior with seed and options support, which is not provided by Gymnasium’s default autoreset mechanism.
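Example
A typical workflow creates a World, attaches a policy, then collects data and evaluates it. A minimal sketch (ExpertPolicy stands in for any Policy implementation you provide):
world = World(env_name="PushT-v1", num_envs=8, image_shape=(96, 96), seed=42)
world.set_policy(ExpertPolicy())  # any Policy implementing get_action(infos)
world.record_dataset("pusht_demos", episodes=100, seed=42)
metrics = world.evaluate(episodes=50, seed=123)
print(f"Success: {metrics['success_rate']:.1f}%")
world.close()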
- __init__(env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs)[source]
Initialize the World with vectorized environments.
Creates and configures a vectorized environment with the specified number of parallel instances. Applies image and goal transformations, sets up variation support, and configures autoreset behavior.
- Parameters:
env_name (str) – Name of the Gymnasium environment to create. Must be a registered environment ID (e.g., ‘PushT-v0’, ‘CubeEnv-v0’).
num_envs (int) – Number of parallel environment instances to create. Higher values increase data collection throughput but require more memory.
image_shape (tuple) – Target shape for image observations as (height, width) or (height, width, channels). Images are resized to this shape.
goal_shape (tuple, optional) – Target shape for goal image observations. If None, goals are processed with the same shape as observations. Defaults to None.
goal_transform (Callable, optional) – Function to transform goal observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.
image_transform (Callable, optional) – Function to transform image observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None.
seed (int, optional) – Base random seed for environment initialization. Each environment gets an offset seed. Defaults to 2349867.
max_episode_steps (int, optional) – Maximum number of steps per episode before truncation. Episodes terminate early on task success. Defaults to 100.
verbose (int, optional) – Verbosity level. 0 for silent, 1 for basic info, 2+ for detailed debugging information. Defaults to 1.
**kwargs – Additional keyword arguments passed to gym.make_vec() and subsequently to the underlying environment constructor.
Note
The MegaWrapper applies image transformations and resizing. The VariationWrapper enables domain randomization support. Autoreset is disabled to allow custom reset with seeds and options.
Example
Create a world with goal-conditioned observations:
world = World(
    env_name="PushT-v1",
    num_envs=8,
    image_shape=(96, 96),
    goal_shape=(64, 64),
    max_episode_steps=150,
    seed=42,
)
- close(**kwargs)[source]
Close all environments and clean up resources.
- Parameters:
**kwargs – Additional keyword arguments passed to the underlying vectorized environment’s close method.
- Returns:
Return value from the underlying close method.
- Return type:
Any
- evaluate(episodes=10, eval_keys=None, seed=None, options=None)[source]
Evaluate the current policy over multiple episodes and return comprehensive metrics.
Runs the attached policy for a specified number of episodes, tracking success rates and optionally other metrics from environment info. Handles episode boundaries and ensures reproducibility through seeding.
- Parameters:
episodes (int, optional) – Total number of episodes to evaluate. More episodes provide more statistically reliable metrics. Defaults to 10.
eval_keys (list of str, optional) – Additional info keys to track across episodes. Must be keys present in self.infos (e.g., ‘reward_total’, ‘steps_to_success’). Defaults to None (track only success rate).
seed (int, optional) – Base random seed for reproducible evaluation. Each episode gets an incremental offset. Defaults to None (non-deterministic).
options (dict, optional) – Reset options passed to environments (e.g., task selection, variation settings). Defaults to None.
- Returns:
- Dictionary containing evaluation metrics:
‘success_rate’ (float): Percentage of successful episodes (0-100)
‘episode_successes’ (ndarray): Boolean array of episode outcomes
‘seeds’ (ndarray): Random seeds used for each episode
Additional keys from eval_keys if specified (ndarray)
- Return type:
dict
- Raises:
AssertionError – If any entry in eval_keys is not found in infos, if an episode count mismatch occurs, or if duplicate seeds are detected.
Note
Success is determined by the ‘terminateds’ flag (True = success)
‘truncateds’ flag (timeout) is treated as failure
Seeds are validated for uniqueness to ensure independent episodes
Environments are manually reset with seeds (autoreset is bypassed)
All environments are evaluated in parallel for efficiency
Example
Basic evaluation:
metrics = world.evaluate(episodes=100, seed=42)
print(f"Success: {metrics['success_rate']:.1f}%")

Evaluate with additional metrics:

metrics = world.evaluate(
    episodes=100,
    eval_keys=['reward_total', 'episode_length'],
    seed=42,
    options={'task_id': 3},
)
print(f"Success: {metrics['success_rate']:.1f}%")
print(f"Avg reward: {metrics['reward_total'].mean():.2f}")

Evaluate across different variations:

for var_type in ['none', 'color', 'light', 'all']:
    options = {'variation': [var_type]} if var_type != 'none' else None
    metrics = world.evaluate(episodes=50, seed=42, options=options)
    print(f"{var_type}: {metrics['success_rate']:.1f}%")
- record_dataset(dataset_name, episodes=10, seed=None, cache_dir=None, options=None)[source]
Collect episodes with the current policy and save as a HuggingFace Dataset.
Executes the attached policy to collect demonstration or rollout data, automatically managing episode boundaries and saving all observations, actions, rewards, and auxiliary information as a sharded HuggingFace dataset on disk. Images are stored as JPEG files, with their relative paths recorded in the dataset.
- The dataset is organized with the following structure:

dataset_name/
    img/
        {episode_idx}/
            {step_idx}_{column_name}.jpeg
    data-{shard}.arrow (Arrow shards)
    dataset_info.json
    state.json
- Parameters:
dataset_name (str) – Name of the dataset. Used as subdirectory name in cache_dir. Should be descriptive (e.g., ‘pusht_expert_demos’).
episodes (int, optional) – Total number of complete episodes to collect. Incomplete episodes at the end are discarded. Defaults to 10.
seed (int, optional) – Base random seed for reproducibility. Each episode gets an incremental offset. Defaults to None (non-deterministic).
cache_dir (str or Path, optional) – Root directory for dataset storage. If None, uses swm.data.get_cache_dir(). Defaults to None.
options (dict, optional) – Reset options for environments. Use {‘variation’: [‘all’]} for full domain randomization. Defaults to None.
- Dataset Schema:
- Each row contains:
episode_idx (int32): Episode identifier
step_idx (int32): Step within episode (0-indexed)
episode_len (int32): Total length of the episode
policy (string): Policy type identifier
pixels (string): Relative path to observation image
goal (string, optional): Relative path to goal image
action (float array): Action taken at this step
reward (float): Reward received
Additional keys from environment infos
Note
Actions are shifted: action at step t leads to observation at step t+1
Last action in each episode is NaN (no action leads from final state)
Images are saved as JPEG for efficiency (may introduce compression artifacts)
Dataset is automatically sharded: 1 shard per 50 episodes
Only complete episodes are included in the final dataset
- Raises:
AssertionError – If required keys (‘pixels’, ‘episode_idx’, ‘step_idx’, ‘episode_len’) are missing from recorded data.
Example
Collect demonstration data with variations:
world.set_policy(expert_policy)
world.record_dataset(
    dataset_name="cube_expert_1k",
    episodes=1000,
    seed=42,
    options={'variation': ['all']},
)

Collect data for specific tasks:

world.record_dataset(
    dataset_name="pusht_task2_data",
    episodes=500,
    options={'task_id': 2},
)
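The recorded data can be inspected with the Hugging Face datasets library. A minimal sketch, assuming the package imports as stable_worldmodel (aliased swm, as referenced above) and that swm.data.get_cache_dir() returns a path-like object; adjust the path handling to your setup:

from datasets import load_from_disk
import stable_worldmodel as swm  # assumed import name

# record_dataset saves under <cache_dir>/<dataset_name>
dataset = load_from_disk(str(swm.data.get_cache_dir() / "cube_expert_1k"))
print(dataset.column_names)  # episode_idx, step_idx, episode_len, pixels, action, reward, ...
first_episode = dataset.filter(lambda row: row["episode_idx"] == 0)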
- record_video(video_path, max_steps=500, fps=30, viewname='pixels', seed=None, options=None)[source]
Record rollout videos for each environment under the current policy.
Executes policy rollouts in all environments and saves them as MP4 videos, one per environment. Videos show stacked observation and goal images when goals are available.
- Parameters:
video_path (str or Path) – Directory path where videos will be saved. Created if it doesn’t exist. Videos are named ‘env_0.mp4’, ‘env_1.mp4’, etc.
max_steps (int, optional) – Maximum number of steps to record per episode. Recording stops earlier if any environment terminates. Defaults to 500.
fps (int, optional) – Frames per second for output videos. Higher values produce smoother playback but larger files. Defaults to 30.
seed (int, optional) – Random seed for reproducible rollouts. Defaults to None.
options (dict, optional) – Reset options passed to environments (e.g., variation settings). Defaults to None.
Note
Videos use libx264 codec for compatibility
If ‘goal’ is present in infos, frames show observation stacked above goal
All environments are recorded simultaneously
Recording stops when ANY environment terminates or truncates
Example
Record evaluation videos:
world.set_policy(trained_policy)
world.record_video(
    video_path="./eval_videos",
    max_steps=200,
    fps=30,
    seed=42,
)
- record_video_from_dataset(video_path, dataset_name, episode_idx, max_steps=500, fps=30, num_proc=4, viewname: str | list[str] = 'pixels', cache_dir=None)[source]
Replay stored dataset episodes and export them as MP4 videos.
Loads episodes from a previously recorded dataset and renders them as videos, useful for visualization, debugging, and qualitative evaluation of collected data.
- Parameters:
video_path (str or Path) – Directory where videos will be saved. Videos are named ‘episode_{idx}.mp4’. Directory is created if it doesn’t exist.
dataset_name (str) – Name of the dataset to load (must exist in cache_dir).
episode_idx (int or list of int) – Episode index or list of episode indices to render. Each episode is saved as a separate video file.
max_steps (int, optional) – Maximum number of steps to render per episode. Useful for limiting video length. Defaults to 500.
fps (int, optional) – Frames per second for output videos. Defaults to 30.
num_proc (int, optional) – Number of processes for parallel dataset filtering. Higher values speed up loading for large datasets. Defaults to 4.
cache_dir (str or Path, optional) – Root directory where dataset is stored. If None, uses swm.data.get_cache_dir(). Defaults to None.
- Raises:
AssertionError – If dataset doesn’t exist in cache_dir, or if episode length inconsistencies are detected in the data.
Note
Images are loaded from JPEG files stored in the dataset
If ‘goal’ column exists, observation and goal are stacked vertically
Videos use libx264 codec for broad compatibility
Episodes are validated for consistency (length matches metadata)
Example
Render specific episodes from a dataset:
world.record_video_from_dataset(
    video_path="./visualizations",
    dataset_name="cube_expert_1k",
    episode_idx=[0, 5, 10, 99],
    fps=30,
)

Render a single episode:

world.record_video_from_dataset(
    video_path="./debug",
    dataset_name="failed_episodes",
    episode_idx=42,
    max_steps=100,
)
- reset(seed=None, options=None)[source]
Reset all environments to initial states.
- Parameters:
seed (int, optional) – Random seed for reproducible resets. If None, uses non-deterministic seeding. Defaults to None.
options (dict, optional) – Dictionary of reset options passed to environments. Common keys include ‘variation’ for domain randomization. Defaults to None.
- Updates:
self.states: Initial observations from all environments
self.infos: Initial auxiliary information
Example
Reset with domain randomization:
world.reset(seed=42, options={'variation': ['all']})

Reset specific variations:

world.reset(options={'variation': ['cube.color', 'light.intensity']})
- set_policy(policy)[source]
Attach a policy to the world and configure it with environment context.
The policy will be used for action generation during step(), record_video(), record_dataset(), and evaluate() calls.
- Parameters:
policy (Policy) – Policy instance that implements get_action(infos) method. The policy receives environment context through set_env() and optional seeding through set_seed().
Note
If the policy has a ‘seed’ attribute, it will be applied via set_seed(). The policy’s set_env() method receives the wrapped vectorized environment.
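Example
A minimal sketch of a custom policy attached to the world. RandomPolicy is hypothetical, and it assumes the Policy base class is importable from stable_worldmodel and that set_env() receives the vectorized environment as documented above:

import stable_worldmodel as swm  # assumed import name

class RandomPolicy(swm.Policy):  # assumes the Policy base class is exposed here
    def set_env(self, envs):
        # Keep a handle to the wrapped vectorized environment
        self.envs = envs

    def get_action(self, infos):
        # Ignore infos and sample one batched action from the vectorized action space
        return self.envs.action_space.sample()

world.set_policy(RandomPolicy())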
- step()[source]
Advance all environments by one step using the current policy.
Queries the attached policy for actions based on current info, executes those actions in all environments, and updates internal state with the results.
- Updates:
self.states: New observations from all environments
self.rewards: Rewards received from the step
self.terminateds: Episode termination flags (task success)
self.truncateds: Episode truncation flags (timeout)
self.infos: Auxiliary information dictionaries
Note
Requires a policy to be attached via set_policy() before calling. The policy’s get_action() method receives the current infos dict.
- Raises:
AttributeError – If no policy has been set via set_policy().
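Example
A manual rollout loop driven by step() (a minimal sketch; trained_policy is a placeholder for any attached Policy):

world.set_policy(trained_policy)
world.reset(seed=0)
for _ in range(100):
    world.step()
    # Stop once any environment succeeds (terminated) or times out (truncated)
    if world.terminateds.any() or world.truncateds.any():
        break
print(world.rewards)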