Source code for stable_worldmodel.world

"""World environment manager.

This module provides the World class, a high-level manager for vectorized Gymnasium
environments with integrated support for data collection, policy evaluation, video
recording, and dataset management. It serves as the central orchestration layer for
training and evaluating world models with domain randomization support.

The World class handles:
    - Vectorized environment creation and management
    - Policy attachment and execution
    - Episode data collection with automatic sharding
    - Video recording from live rollouts or stored datasets
    - Policy evaluation with comprehensive metrics
    - Visual domain randomization through variation spaces

Example:
    Basic usage for policy evaluation::

        from stable_worldmodel import World
        from stable_worldmodel.policy import RandomPolicy

        # Create a world with 4 parallel environments
        world = World(
            env_name="PushT-v1",
            num_envs=4,
            image_shape=(96, 96),
            max_episode_steps=200
        )

        # Attach a policy
        policy = RandomPolicy()
        world.set_policy(policy)

        # Evaluate the policy
        metrics = world.evaluate(episodes=100, seed=42)
        print(f"Success rate: {metrics['success_rate']:.2f}%")

    Data collection example::

        # Collect demonstration data
        world.record_dataset(
            dataset_name="pusht_demos",
            episodes=1000,
            seed=42,
            options={'variation': ['all']}
        )

Todo:
    * Add real-time metric visualization during evaluation

.. _Gymnasium:
   https://gymnasium.farama.org/
"""

from collections.abc import Callable
from pathlib import Path

import datasets
import gymnasium as gym
import imageio.v3 as iio
import numpy as np
from datasets import Dataset, Features, Value, load_from_disk
from loguru import logger as logging
from PIL import Image
from rich import print

import stable_worldmodel as swm
from stable_worldmodel.data import is_image

from .wrappers import MegaWrapper, VariationWrapper


[docs] class World: """High-level manager for vectorized Gymnasium environments with integrated data collection. World orchestrates multiple parallel environments, providing a unified interface for policy execution, data collection, evaluation, and visualization. It automatically handles environment resets, batched action execution, and comprehensive episode tracking with support for visual domain randomization. The World class is the central component of the stable-worldmodel library, designed to streamline the workflow from data collection to policy training and evaluation. It wraps Gymnasium's vectorized environments with additional functionality for world model research. Attributes: envs (VectorEnv): Vectorized Gymnasium environment wrapped with MegaWrapper and VariationWrapper for image processing and domain randomization. seed (int): Base random seed for reproducibility. policy (Policy): Currently attached policy that generates actions. states (ndarray): Current observation states for all environments. rewards (ndarray): Current rewards for all environments. terminateds (ndarray): Terminal flags indicating episode completion. truncateds (ndarray): Truncation flags indicating episode timeout. infos (dict): Dictionary containing major information from environments. Properties: num_envs (int): Number of parallel environments. observation_space (Space): Batched observation space. action_space (Space): Batched action space. variation_space (Dict): Variation space for domain randomization (batched). single_variation_space (Dict): Variation space for a single environment. single_action_space (Space): Action space for a single environment. single_observation_space (Space): Observation space for a single environment. Note: Environments use DISABLED autoreset mode to enable custom reset behavior with seed and options support, which is not provided by Gymnasium's default autoreset mechanism. """
[docs] def __init__( self, env_name: str, num_envs: int, image_shape: tuple, goal_shape: tuple | None = None, goal_transform: Callable | None = None, image_transform: Callable | None = None, seed: int = 2349867, max_episode_steps: int = 100, verbose: int = 1, **kwargs, ): """Initialize the World with vectorized environments. Creates and configures a vectorized environment with the specified number of parallel instances. Applies image and goal transformations, sets up variation support, and configures autoreset behavior. Args: env_name (str): Name of the Gymnasium environment to create. Must be a registered environment ID (e.g., 'PushT-v0', 'CubeEnv-v0'). num_envs (int): Number of parallel environment instances to create. Higher values increase data collection throughput but require more memory. image_shape (tuple): Target shape for image observations as (height, width) or (height, width, channels). Images are resized to this shape. goal_shape (tuple, optional): Target shape for goal image observations. If None, goals are processed with the same shape as observations. Defaults to None. goal_transform (Callable, optional): Function to transform goal observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None. image_transform (Callable, optional): Function to transform image observations. Should accept and return numpy arrays. Applied after resizing. Defaults to None. seed (int, optional): Base random seed for environment initialization. Each environment gets an offset seed. Defaults to 2349867. max_episode_steps (int, optional): Maximum number of steps per episode before truncation. Episodes terminate early on task success. Defaults to 100. verbose (int, optional): Verbosity level. 0 for silent, 1 for basic info, 2+ for detailed debugging information. Defaults to 1. **kwargs: Additional keyword arguments passed to gym.make_vec() and subsequently to the underlying environment constructor. Note: The MegaWrapper applies image transformations and resizing. The VariationWrapper enables domain randomization support. Autoreset is disabled to allow custom reset with seeds and options. Example: Create a world with goal-conditioned observations:: world = World( env_name="PushT-v1", num_envs=8, image_shape=(96, 96), goal_shape=(64, 64), max_episode_steps=150, seed=42 ) """ self.envs = gym.make_vec( env_name, num_envs=num_envs, vectorization_mode="sync", wrappers=[lambda x: MegaWrapper(x, image_shape, image_transform, goal_shape, goal_transform)], max_episode_steps=max_episode_steps, **kwargs, ) self.envs = VariationWrapper(self.envs) self.envs.unwrapped.autoreset_mode = gym.vector.AutoresetMode.DISABLED if verbose > 0: logging.info(f"🌍🌍🌍 World {env_name} initialized 🌍🌍🌍") logging.info("đŸ•šī¸ đŸ•šī¸ đŸ•šī¸ Action space đŸ•šī¸ đŸ•šī¸ đŸ•šī¸") logging.info(f"{self.envs.action_space}") logging.info("đŸ‘ī¸ đŸ‘ī¸ đŸ‘ī¸ Observation space đŸ‘ī¸ đŸ‘ī¸ đŸ‘ī¸") logging.info(f"{self.envs.observation_space}") if self.envs.variation_space is not None: logging.info("âš—ī¸ âš—ī¸ âš—ī¸ Variation space âš—ī¸ âš—ī¸ âš—ī¸") print(self.envs.variation_space) else: logging.warning("No variation space provided!") self.seed = seed
@property def num_envs(self): """int: Number of parallel environment instances.""" return self.envs.num_envs @property def observation_space(self): """Space: Batched observation space for all environments.""" return self.envs.observation_space @property def action_space(self): """Space: Batched action space for all environments.""" return self.envs.action_space @property def variation_space(self): """Dict: Batched variation space for domain randomization across all environments.""" return self.envs.variation_space @property def single_variation_space(self): """Dict: Variation space for a single environment instance.""" return self.envs.single_variation_space @property def single_action_space(self): """Space: Action space for a single environment instance.""" return self.envs.single_action_space @property def single_observation_space(self): """Space: Observation space for a single environment instance.""" return self.envs.single_observation_space
[docs] def close(self, **kwargs): """Close all environments and clean up resources. Args: **kwargs: Additional keyword arguments passed to the underlying vectorized environment's close method. Returns: Any: Return value from the underlying close method. """ return self.envs.close(**kwargs)
[docs] def step(self): """Advance all environments by one step using the current policy. Queries the attached policy for actions based on current info, executes those actions in all environments, and updates internal state with the results. Updates: - self.states: New observations from all environments - self.rewards: Rewards received from the step - self.terminateds: Episode termination flags (task success) - self.truncateds: Episode truncation flags (timeout) - self.infos: Auxiliary information dictionaries Note: Requires a policy to be attached via set_policy() before calling. The policy's get_action() method receives the current infos dict. Raises: AttributeError: If no policy has been set via set_policy(). """ # note: reset happens before because of auto-reset, should fix that actions = self.policy.get_action(self.infos) self.states, self.rewards, self.terminateds, self.truncateds, self.infos = self.envs.step(actions)
[docs] def reset(self, seed=None, options=None): """Reset all environments to initial states. Args: seed (int, optional): Random seed for reproducible resets. If None, uses non-deterministic seeding. Defaults to None. options (dict, optional): Dictionary of reset options passed to environments. Common keys include 'variation' for domain randomization. Defaults to None. Updates: - self.states: Initial observations from all environments - self.infos: Initial auxiliary information Example: Reset with domain randomization:: world.reset(seed=42, options={'variation': ['all']}) Reset specific variations:: world.reset(options={'variation': ['cube.color', 'light.intensity']}) """ self.states, self.infos = self.envs.reset(seed=seed, options=options)
[docs] def set_policy(self, policy): """Attach a policy to the world and configure it with environment context. The policy will be used for action generation during step(), record_video(), record_dataset(), and evaluate() calls. Args: policy (Policy): Policy instance that implements get_action(infos) method. The policy receives environment context through set_env() and optional seeding through set_seed(). Note: If the policy has a 'seed' attribute, it will be applied via set_seed(). The policy's set_env() method receives the wrapped vectorized environment. """ self.policy = policy self.policy.set_env(self.envs) if hasattr(self.policy, "seed") and self.policy.seed is not None: self.policy.set_seed(self.policy.seed)
[docs] def record_video( self, video_path, max_steps=500, fps=30, viewname="pixels", seed=None, options=None, ): """Record rollout videos for each environment under the current policy. Executes policy rollouts in all environments and saves them as MP4 videos, one per environment. Videos show stacked observation and goal images when goals are available. Args: video_path (str or Path): Directory path where videos will be saved. Created if it doesn't exist. Videos are named 'env_0.mp4', 'env_1.mp4', etc. max_steps (int, optional): Maximum number of steps to record per episode. Recording stops earlier if any environment terminates. Defaults to 500. fps (int, optional): Frames per second for output videos. Higher values produce smoother playback but larger files. Defaults to 30. seed (int, optional): Random seed for reproducible rollouts. Defaults to None. options (dict, optional): Reset options passed to environments (e.g., variation settings). Defaults to None. Note: - Videos use libx264 codec for compatibility - If 'goal' is present in infos, frames show observation stacked above goal - All environments are recorded simultaneously - Recording stops when ANY environment terminates or truncates Example: Record evaluation videos:: world.set_policy(trained_policy) world.record_video( video_path="./eval_videos", max_steps=200, fps=30, seed=42 ) """ import imageio viewname = [viewname] if isinstance(viewname, str) else viewname out = [ imageio.get_writer( Path(video_path) / f"env_{i}.mp4", "output.mp4", fps=fps, codec="libx264", ) for i in range(self.num_envs) ] self.reset(seed, options) for i, o in enumerate(out): frame = np.vstack([self.infos[v_name][i] for v_name in viewname]) if "goal" in self.infos: frame = np.vstack([frame, self.infos["goal"][i]]) o.append_data(frame) for _ in range(max_steps): self.step() if np.any(self.terminateds) or np.any(self.truncateds): break for i, o in enumerate(out): frame = np.vstack([self.infos[v_name][i] for v_name in viewname]) if "goal" in self.infos: frame = np.vstack([frame, self.infos["goal"][i]]) o.append_data(frame) [o.close() for o in out] print(f"Video saved to {video_path}")
[docs] def record_dataset(self, dataset_name, episodes=10, seed=None, cache_dir=None, options=None): """Collect episodes with the current policy and save as a HuggingFace Dataset. Executes the attached policy to collect demonstration or rollout data, automatically managing episode boundaries and saving all observations, actions, rewards, and auxiliary information as a sharded Parquet dataset. Images are stored as JPEG files with paths in the dataset. The dataset is organized with the following structure: - dataset_name/ - img/ - {episode_idx}/ - {step_idx}_{column_name}.jpeg - data-{shard}.arrow (Parquet shards) - dataset_info.json - state.json Args: dataset_name (str): Name of the dataset. Used as subdirectory name in cache_dir. Should be descriptive (e.g., 'pusht_expert_demos'). episodes (int, optional): Total number of complete episodes to collect. Incomplete episodes at the end are discarded. Defaults to 10. seed (int, optional): Base random seed for reproducibility. Each episode gets an incremental offset. Defaults to None (non-deterministic). cache_dir (str or Path, optional): Root directory for dataset storage. If None, uses swm.data.get_cache_dir(). Defaults to None. options (dict, optional): Reset options for environments. Use {'variation': ['all']} for full domain randomization. Defaults to None. Dataset Schema: Each row contains: - episode_idx (int32): Episode identifier - step_idx (int32): Step within episode (0-indexed) - episode_len (int32): Total length of the episode - policy (string): Policy type identifier - pixels (string): Relative path to observation image - goal (string, optional): Relative path to goal image - action (float array): Action taken at this step - reward (float): Reward received - Additional keys from environment infos Note: - Actions are shifted: action at step t leads to observation at step t+1 - Last action in each episode is NaN (no action leads from final state) - Images are saved as JPEG for efficiency (may introduce compression artifacts) - Dataset is automatically sharded: 1 shard per 50 episodes - Only complete episodes are included in the final dataset Raises: AssertionError: If required keys ('pixels', 'episode_idx', 'step_idx', 'episode_len') are missing from recorded data. Example: Collect demonstration data with variations:: world.set_policy(expert_policy) world.record_dataset( dataset_name="cube_expert_1k", episodes=1000, seed=42, options={'variation': ['all']} ) Collect data for specific tasks:: world.record_dataset( dataset_name="pusht_task2_data", episodes=500, options={'task_id': 2} ) """ cache_dir = cache_dir or swm.data.get_cache_dir() dataset_path = Path(cache_dir, dataset_name) dataset_path.mkdir(parents=True, exist_ok=True) recorded_episodes = 0 self.terminateds = np.zeros(self.num_envs) self.truncateds = np.zeros(self.num_envs) episode_idx = np.arange(self.num_envs) self.reset(seed, options) # <- incr global seed by num_envs root_seed = seed + self.num_envs if seed is not None else None records = {key: list(value) for key, value in self.infos.items() if key[0] != "_"} records["episode_idx"] = list(episode_idx) records["policy"] = [self.policy.type] * self.num_envs while True: self.step() # start new episode for done envs for i in range(self.num_envs): if self.terminateds[i] or self.truncateds[i]: # re-reset env with seed and options (no supported by auto-reset) new_seed = root_seed + recorded_episodes if seed is not None else None # determine new episode idx next_ep_idx = episode_idx.max() + 1 episode_idx[i] = next_ep_idx recorded_episodes += 1 self.envs.unwrapped._autoreset_envs = np.zeros((self.num_envs,)) _, infos = self.envs.envs[i].reset(seed=new_seed, options=options) for k, v in infos.items(): self.infos[k][i] = np.asarray(v) if recorded_episodes >= episodes: break for key in self.infos: if key[0] == "_": continue # shift actions if key == "action": n_action = len(self.infos[key]) last_episode = records["episode_idx"][-n_action:] action_mask = (last_episode == episode_idx)[:, None] # override last actions of continuing episodes records[key][-n_action:] = np.where( action_mask, self.infos[key], np.nan, ) # add new dummy action action_shape = np.shape(self.infos[key][0]) action_dtype = self.single_action_space.dtype dummy_block = [np.full(action_shape, np.nan, dtype=action_dtype) for _ in range(self.num_envs)] records[key].extend(dummy_block) else: records[key].extend(list(self.infos[key])) records["episode_idx"].extend(list(episode_idx)) records["policy"].extend([self.policy.type] * self.num_envs) # add the episode length counts = np.bincount(np.array(records["episode_idx"]), minlength=max(records["episode_idx"]) + 1) records["episode_len"] = [int(counts[ep]) for ep in records["episode_idx"]] ######################## # Save dataset to disk # ######################## assert "pixels" in records, "pixels key is required in records" assert "episode_idx" in records, "episode_idx key is required in records" assert "step_idx" in records, "step_idx key is required in records" assert "episode_len" in records, "episode_len key is required in records" # Create the dataset directory structure dataset_path.mkdir(parents=True, exist_ok=True) # save all jpeg images image_cols = {col for col in records if is_image(records[col][0])} # pre-create all directories for ep_idx in set(records["episode_idx"]): img_folder = dataset_path / "img" / f"{ep_idx}" img_folder.mkdir(parents=True, exist_ok=True) # dump all data for i in range(len(records["episode_idx"])): ep_idx = records["episode_idx"][i] step_idx = records["step_idx"][i] for img_col in image_cols: img = records[img_col][i] img_folder = dataset_path / "img" / f"{ep_idx}" img_path = img_folder / f"{step_idx}_{img_col.replace('.', '_')}.jpeg" iio.imwrite(img_path, img) # replace image in records with relative path records[img_col][i] = str(img_path.relative_to(dataset_path)) def determine_features(records): features = { "episode_idx": Value("int32"), "step_idx": Value("int32"), "episode_len": Value("int32"), } for col_name in records: if col_name in features: continue first_elem = records[col_name][0] if type(first_elem) is str: features[col_name] = Value("string") elif isinstance(first_elem, np.ndarray): if first_elem.ndim == 1: state_feature = datasets.Sequence( feature=Value(dtype=first_elem.dtype.name), length=len(first_elem), ) elif 2 <= first_elem.ndim <= 6: feature_cls = getattr(datasets, f"Array{first_elem.ndim}D") state_feature = feature_cls(shape=first_elem.shape, dtype=first_elem.dtype.name) else: state_feature = Value(first_elem.dtype.name) features[col_name] = state_feature elif isinstance(first_elem, (np.generic)): features[col_name] = Value(first_elem.dtype.name) else: features[col_name] = Value(type(first_elem).__name__) return Features(features) records_feat = determine_features(records) records_ds = Dataset.from_dict(records, features=records_feat) # flush incomplete episodes # get episodes that are currently running (not done) incomplete_episodes = episode_idx[~(self.terminateds | self.truncateds)] # keep only episodes that are NOT in the incomplete list keep_mask = ~np.isin(records_ds["episode_idx"], incomplete_episodes) records_ds = records_ds.select(np.nonzero(keep_mask)[0]) # flush all extra episodes saved (keep only first N episodes) episodes_to_keep = np.unique(records_ds["episode_idx"])[:episodes] keep_mask = np.isin(records_ds["episode_idx"], episodes_to_keep) records_ds = records_ds.select(np.nonzero(keep_mask)[0]) # save dataset records_path = dataset_path # / "records" num_chunks = episodes // 50 records_path.mkdir(parents=True, exist_ok=True) records_ds.save_to_disk(records_path, num_shards=num_chunks or 1) print(f"Dataset saved to {dataset_path} with {episodes} episodes!")
[docs] def record_video_from_dataset( self, video_path, dataset_name, episode_idx, max_steps=500, fps=30, num_proc=4, viewname: str | list[str] = "pixels", cache_dir=None, ): """Replay stored dataset episodes and export them as MP4 videos. Loads episodes from a previously recorded dataset and renders them as videos, useful for visualization, debugging, and qualitative evaluation of collected data. Args: video_path (str or Path): Directory where videos will be saved. Videos are named 'episode_{idx}.mp4'. Directory is created if it doesn't exist. dataset_name (str): Name of the dataset to load (must exist in cache_dir). episode_idx (int or list of int): Episode index or list of episode indices to render. Each episode is saved as a separate video file. max_steps (int, optional): Maximum number of steps to render per episode. Useful for limiting video length. Defaults to 500. fps (int, optional): Frames per second for output videos. Defaults to 30. num_proc (int, optional): Number of processes for parallel dataset filtering. Higher values speed up loading for large datasets. Defaults to 4. cache_dir (str or Path, optional): Root directory where dataset is stored. If None, uses swm.data.get_cache_dir(). Defaults to None. Raises: AssertionError: If dataset doesn't exist in cache_dir, or if episode length inconsistencies are detected in the data. Note: - Images are loaded from JPEG files stored in the dataset - If 'goal' column exists, observation and goal are stacked vertically - Videos use libx264 codec for broad compatibility - Episodes are validated for consistency (length matches metadata) Example: Render specific episodes from a dataset:: world.record_video_from_dataset( video_path="./visualizations", dataset_name="cube_expert_1k", episode_idx=[0, 5, 10, 99], fps=30 ) Render a single episode:: world.record_video_from_dataset( video_path="./debug", dataset_name="failed_episodes", episode_idx=42, max_steps=100 ) """ import imageio cache_dir = cache_dir or swm.data.get_cache_dir() dataset_path = Path(cache_dir, dataset_name) assert dataset_path.is_dir(), f"Dataset {dataset_name} not found in cache dir {swm.data.get_cache_dir()}" episode_idx = [episode_idx] if isinstance(episode_idx, int) else episode_idx viewname = [viewname] if isinstance(viewname, str) else viewname out = [ imageio.get_writer( Path(video_path) / f"episode_{i}.mp4", "output.mp4", fps=fps, codec="libx264", ) for i in episode_idx ] dataset = load_from_disk(dataset_path).with_format("numpy") for i, o in zip(episode_idx, out): episode = dataset.filter(lambda ex: ex["episode_idx"] == i, num_proc=num_proc) episode = episode.sort("step_idx") episode_len = len(episode) assert len(set(episode["episode_len"])) == 1, ( "'episode_len' contains different values for the same episode" ) assert len(episode) == episode["episode_len"][0], ( f"Episode {i} has {len(episode)} steps, but 'episode_len' is {episode['episode_len'][0]}" ) for step_idx in range(min(episode_len, max_steps)): frame = [] for view in viewname: img_path = Path(dataset_path, episode[step_idx][view]) frame.append(np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)) frame = np.vstack(frame) # should try hstack? if "goal" in episode.column_names: goal_path = Path(dataset_path, episode[step_idx]["goal"]) goal = Image.open(goal_path) goal = np.array(goal.convert("RGB"), dtype=np.uint8) frame = np.vstack([frame, goal]) o.append_data(frame) [o.close() for o in out] print(f"Video saved to {video_path}")
[docs] def evaluate(self, episodes=10, eval_keys=None, seed=None, options=None): """Evaluate the current policy over multiple episodes and return comprehensive metrics. Runs the attached policy for a specified number of episodes, tracking success rates and optionally other metrics from environment info. Handles episode boundaries and ensures reproducibility through seeding. Args: episodes (int, optional): Total number of episodes to evaluate. More episodes provide more statistically reliable metrics. Defaults to 10. eval_keys (list of str, optional): Additional info keys to track across episodes. Must be keys present in self.infos (e.g., 'reward_total', 'steps_to_success'). Defaults to None (track only success rate). seed (int, optional): Base random seed for reproducible evaluation. Each episode gets an incremental offset. Defaults to None (non-deterministic). options (dict, optional): Reset options passed to environments (e.g., task selection, variation settings). Defaults to None. Returns: dict: Dictionary containing evaluation metrics: - 'success_rate' (float): Percentage of successful episodes (0-100) - 'episode_successes' (ndarray): Boolean array of episode outcomes - 'seeds' (ndarray): Random seeds used for each episode - Additional keys from eval_keys if specified (ndarray) Raises: AssertionError: If eval_key is not found in infos, if episode count mismatch occurs, or if duplicate seeds are detected. Note: - Success is determined by the 'terminateds' flag (True = success) - 'truncateds' flag (timeout) is treated as failure - Seeds are validated for uniqueness to ensure independent episodes - Environments are manually reset with seeds (autoreset is bypassed) - All environments are evaluated in parallel for efficiency Example: Basic evaluation:: metrics = world.evaluate(episodes=100, seed=42) print(f"Success: {metrics['success_rate']:.1f}%") Evaluate with additional metrics:: metrics = world.evaluate( episodes=100, eval_keys=['reward_total', 'episode_length'], seed=42, options={'task_id': 3} ) print(f"Success: {metrics['success_rate']:.1f}%") print(f"Avg reward: {metrics['reward_total'].mean():.2f}") Evaluate across different variations:: for var_type in ['none', 'color', 'light', 'all']: options = {'variation': [var_type]} if var_type != 'none' else None metrics = world.evaluate(episodes=50, seed=42, options=options) print(f"{var_type}: {metrics['success_rate']:.1f}%") """ options = options or {} metrics = { "success_rate": 0, "episode_successes": np.zeros(episodes), "seeds": np.empty(episodes, dtype=int), } if eval_keys: for key in eval_keys: metrics[key] = np.zeros(episodes) self.terminateds = np.zeros(self.num_envs) self.truncateds = np.zeros(self.num_envs) episode_idx = np.arange(self.num_envs) self.reset(seed, options) root_seed = seed + self.num_envs if seed is not None else None eval_ep_count = 0 eval_done = False while True: self.step() # start new episode for done envs for i in range(self.num_envs): if self.terminateds[i] or self.truncateds[i]: # record eval info ep_idx = episode_idx[i] metrics["episode_successes"][ep_idx] = self.terminateds[i] metrics["seeds"][ep_idx] = self.envs.envs[i].unwrapped.np_random_seed if eval_keys: for key in eval_keys: assert key in self.infos, f"key {key} not found in infos" metrics[key][ep_idx] = self.infos[key][i] # determine new episode idx # re-reset env with seed and options (no supported by auto-reset) new_seed = root_seed + eval_ep_count if seed is not None else None next_ep_idx = episode_idx.max() + 1 episode_idx[i] = next_ep_idx eval_ep_count += 1 # break if enough episodes evaluated if eval_ep_count >= episodes: eval_done = True break self.envs.unwrapped._autoreset_envs = np.zeros((self.num_envs,)) _, infos = self.envs.envs[i].reset(seed=new_seed, options=options) for k, v in infos.items(): if k not in self.infos: continue # Convert to array and extract scalar to preserve dtype self.infos[k][i] = np.asarray(v) if eval_done: break # compute success rate metrics["success_rate"] = float(np.sum(metrics["episode_successes"])) / episodes * 100.0 assert eval_ep_count == episodes, f"eval_ep_count {eval_ep_count} != episodes {episodes}" assert np.unique(metrics["seeds"]).shape[0] == episodes, "Some episode seeds are identical!" return metrics