Source code for stable_worldmodel.data

"""Data management and dataset utilities for Stable World Model.

This module provides utilities for managing datasets, models, and world
information. It includes functionality for loading multi-step trajectories,
querying cached data, and retrieving metadata about environments.

The module supports:
    - Multi-step trajectory datasets with frame skipping
    - Dataset and model cache management
    - World environment introspection
    - Gymnasium space metadata extraction
"""

import os
import re
import shutil
from functools import lru_cache
from pathlib import Path
from typing import Any, TypedDict

import gymnasium as gym
import numpy as np
import PIL
import stable_pretraining as spt
import torch
from datasets import load_from_disk
from rich import print

import stable_worldmodel as swm


class StepsDataset(spt.data.HFDataset):
    """Dataset for loading multi-step trajectory sequences.

    This dataset loads sequences of consecutive steps from episodic data,
    supporting frame skipping and automatic image loading from file paths.
    Inherits from stable_pretraining's HFDataset.

    Attributes:
        data_dir (Path): Directory containing the dataset.
        num_steps (int): Number of steps per sequence.
        frameskip (int): Number of steps to skip between frames.
        episodes (np.ndarray): Array of unique episode indices.
        episode_slices (dict): Mapping from episode index to dataset indices.
        cum_slices (np.ndarray): Cumulative sum of valid samples per episode.
        idx_to_ep (np.ndarray): Mapping from sample index to episode.
        img_cols (set): Set of column names containing image paths.
    """

    def __init__(
        self,
        path,
        *args,
        num_steps=2,
        frameskip=1,
        cache_dir=None,
        **kwargs,
    ):
        """Initialize the StepsDataset.

        Args:
            path (str): Name or path of the dataset within cache directory.
            *args: Additional arguments passed to parent class.
            num_steps (int, optional): Number of consecutive steps per sample.
                Defaults to 2.
            frameskip (int, optional): Number of steps between sampled frames.
                Defaults to 1.
            cache_dir (str, optional): Cache directory path. Defaults to None
                (uses default).
            **kwargs: Additional keyword arguments passed to parent class.

        Raises:
            AssertionError: If required columns are missing from dataset.
            ValueError: If episodes are too short for the requested num_steps
                and frameskip.
        """
        data_dir = Path(cache_dir or swm.data.get_cache_dir(), path)
        super().__init__(str(data_dir), *args, **kwargs)

        self.data_dir = data_dir
        self.num_steps = num_steps
        self.frameskip = frameskip

        assert "episode_idx" in self.dataset.column_names, "Dataset must have 'episode_idx' column"
        assert "step_idx" in self.dataset.column_names, "Dataset must have 'step_idx' column"
        assert "action" in self.dataset.column_names, "Dataset must have 'action' column"

        self.dataset.set_format("torch")

        # get number of episodes
        ep_indices = self.dataset["episode_idx"][:]
        self.episodes = np.unique(ep_indices)

        # get dataset indices of each episode
        self.episode_slices = {e: self.get_episode_slice(e, ep_indices) for e in self.episodes}

        # number of valid samples per episode and cumulative start offsets
        valid_samples_per_ep = [
            max(0, len(ep_slice) - self.num_steps * self.frameskip + 1)
            for ep_slice in self.episode_slices.values()
        ]
        self.cum_slices = np.cumsum([0] + valid_samples_per_ep)

        # map from sample index to its episode
        self.idx_to_ep = np.searchsorted(self.cum_slices, torch.arange(len(self)), side="right") - 1

        self.img_cols = self.infer_img_path_columns()

    def get_episode_slice(self, episode_idx, episode_indices):
        """Get dataset indices for a specific episode.

        Args:
            episode_idx (int): Episode index to retrieve.
            episode_indices (array-like): Array of episode indices for all steps.

        Returns:
            np.ndarray: Indices of steps belonging to the specified episode.

        Raises:
            ValueError: If episode is too short for num_steps and frameskip.
        """
        indices = np.flatnonzero(episode_indices == episode_idx)

        if len(indices) <= (self.num_steps * self.frameskip):
            raise ValueError(
                f"Episode {episode_idx} is too short ({len(indices)} steps) for "
                f"{self.num_steps} steps with {self.frameskip} frameskip"
            )

        return indices

    def __len__(self):
        """Get total number of valid step sequences in the dataset.

        Returns:
            int: Number of valid sequences across all episodes.
        """
        return int(self.cum_slices[-1])

    def __getitem__(self, idx):
        """Get a multi-step sequence at the given index.

        Args:
            idx (int): Index of the sequence to retrieve.

        Returns:
            dict: Dictionary containing the sequence data with keys for
                observations, actions, and other dataset fields. Images are
                loaded as PIL Images and subsampled by frameskip; actions keep
                every raw step and are reshaped to (num_steps, -1).
        """
        ep = self.idx_to_ep[idx]
        episode_indices = self.episode_slices[ep]
        offset = idx - self.cum_slices[ep]

        start = offset
        stop = start + self.num_steps * self.frameskip
        idx_slice = episode_indices[start:stop]

        steps = self.dataset[idx_slice]

        for k, v in steps.items():
            if k == "action":
                continue
            v = v[:: self.frameskip]
            steps[k] = v
            if k in self.img_cols:
                steps[k] = [PIL.Image.open(self.data_dir / img_path) for img_path in v]

        if self.transform:
            steps = self.transform(steps)

        # stack images into a single tensor
        for k in self.img_cols:
            steps[k] = torch.stack(steps[k])

        # reshape action
        steps["action"] = steps["action"].reshape(self.num_steps, -1)

        return steps

    def infer_img_path_columns(self):
        """Infer which dataset columns contain image file paths.

        Checks the first dataset element to identify string columns with
        common image file extensions.

        Returns:
            set: Set of column names containing image file paths.
        """
        IMG_EXTENSIONS = (".jpeg", ".png", ".jpg")
        img_cols = set()
        first_elem = self.dataset[0]

        for col in self.dataset.column_names:
            if isinstance(first_elem[col], str) and first_elem[col].endswith(IMG_EXTENSIONS):
                img_cols.add(col)

        return img_cols

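
# Illustrative sketch, not part of the module: windowing an episodic dataset
# with StepsDataset. The dataset name "pusht_demo" and the "pixels" column are
# hypothetical, and the `transform` keyword is assumed to be forwarded to
# spt.data.HFDataset (as the use of self.transform above suggests).
def _example_steps_dataset():
    def to_tensors(steps):
        # convert PIL frames into uint8 tensors so __getitem__ can stack them
        for key, value in steps.items():
            if isinstance(value, list) and value and isinstance(value[0], PIL.Image.Image):
                steps[key] = [torch.as_tensor(np.asarray(img)) for img in value]
        return steps

    dataset = StepsDataset("pusht_demo", num_steps=4, frameskip=2, transform=to_tensors)
    sample = dataset[0]

    # each sample spans num_steps * frameskip raw steps; image columns are
    # subsampled to num_steps frames, actions are reshaped to (num_steps, -1)
    print(sample["pixels"].shape, sample["action"].shape)
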
#####################
###     utils     ###
#####################
def is_image(x):
    """Check if input is a valid image array.

    Args:
        x: Input to check.

    Returns:
        bool: True if x is a uint8 numpy array with shape (H, W, C) where C is
            1 (grayscale), 3 (RGB), or 4 (RGBA).
    """
    return type(x) is np.ndarray and x.ndim == 3 and x.shape[2] in [1, 3, 4] and x.dtype == np.uint8

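
# Quick, illustrative sanity check of is_image on synthetic arrays (not part
# of the module).
def _example_is_image():
    frame = np.zeros((64, 64, 3), dtype=np.uint8)
    assert is_image(frame)                                    # uint8 HxWxC, C=3
    assert not is_image(frame.astype(np.float32))             # wrong dtype
    assert not is_image(np.zeros((64, 64), dtype=np.uint8))   # missing channel axis
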
#####################
###   CLI Info    ###
#####################
class SpaceInfo(TypedDict, total=False):
    """Type specification for Gymnasium space metadata.

    Attributes:
        shape: Dimensions of the space.
        type: Class name of the space (e.g., 'Box', 'Discrete').
        dtype: Data type of the space elements.
        low: Lower bounds for Box spaces.
        high: Upper bounds for Box spaces.
        n: Number of discrete values for Discrete spaces.
    """

    shape: tuple[int, ...]
    type: str
    dtype: str
    low: Any
    high: Any
    n: int

class VariationInfo(TypedDict):
    """Type specification for environment variation metadata.

    Attributes:
        has_variation: Whether the environment supports variations.
        type: Class name of the variation space if it exists.
        names: List of variation parameter names.
    """

    has_variation: bool
    type: str | None
    names: list[str] | None

class WorldInfo(TypedDict):
    """Type specification for world environment information.

    Attributes:
        name: Name/ID of the world environment.
        observation_space: Metadata about the observation space.
        action_space: Metadata about the action space.
        variation: Information about environment variations.
        config: Additional configuration parameters.
    """

    name: str
    observation_space: SpaceInfo
    action_space: SpaceInfo
    variation: VariationInfo
    config: dict[str, Any]

def get_cache_dir() -> Path:
    """Get the cache directory for stable_worldmodel data.

    The cache directory can be customized via the STABLEWM_HOME environment
    variable. If not set, defaults to ~/.stable_worldmodel.

    Returns:
        Path: Path to the cache directory. Directory is created if it doesn't
            exist.
    """
    cache_dir = os.getenv("STABLEWM_HOME", os.path.expanduser("~/.stable_worldmodel"))
    os.makedirs(cache_dir, exist_ok=True)
    return Path(cache_dir)

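
# Illustrative only: redirecting the cache via STABLEWM_HOME (the path below
# is a placeholder).
def _example_cache_dir():
    os.environ["STABLEWM_HOME"] = "/tmp/swm_cache"
    print(get_cache_dir())  # -> /tmp/swm_cache, created if missing
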
def list_datasets():
    """List all cached datasets.

    Returns:
        list[str]: Names of all dataset directories in the cache.
    """
    with os.scandir(get_cache_dir()) as entries:
        return [e.name for e in entries if e.is_dir()]

def list_models():
    """List all cached model checkpoints.

    Searches for files matching the pattern `<name>_weights*.ckpt` or
    `<name>_object.ckpt` and returns unique model names.

    Returns:
        list[str]: Sorted list of model names found in cache.
    """
    pattern = re.compile(r"^(.*?)(?=_(?:weights(?:-[^.]*)?|object)\.ckpt$)", re.IGNORECASE)
    cache_dir = get_cache_dir()

    models = set()

    for fname in os.listdir(cache_dir):
        m = pattern.match(fname)
        if m:
            models.add(m.group(1))

    return sorted(models)

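
# Illustrative only: how the list_models() pattern maps checkpoint filenames
# to model names. The filenames below are hypothetical.
def _example_model_name_pattern():
    pattern = re.compile(r"^(.*?)(?=_(?:weights(?:-[^.]*)?|object)\.ckpt$)", re.IGNORECASE)
    files = ["dino_wm_weights-epoch3.ckpt", "dino_wm_object.ckpt", "notes.txt"]
    names = {m.group(1) for f in files if (m := pattern.match(f))}
    print(sorted(names))  # -> ['dino_wm']
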
def dataset_info(name):
    """Get metadata about a cached dataset.

    Args:
        name (str): Name of the dataset.

    Returns:
        dict: Dictionary containing dataset metadata including:
            - name: Dataset name
            - num_episodes: Number of unique episodes
            - num_steps: Total number of steps
            - columns: List of column names
            - obs_shape: Shape of observation images
            - action_shape: Shape of action vectors
            - goal_shape: Shape of goal images
            - variation: Dict with variation information

    Raises:
        ValueError: If dataset is not found in cache.
        AssertionError: If required columns are missing.
    """
    # check name exists
    if name not in list_datasets():
        raise ValueError(f"Dataset '{name}' not found. Available: {list_datasets()}")

    dataset = load_from_disk(str(Path(get_cache_dir(), name, "records")))
    dataset.set_format("numpy")

    def assert_msg(col):
        return f"Dataset must have '{col}' column"  # type: ignore

    assert "episode_idx" in dataset.column_names, assert_msg("episode_idx")
    assert "step_idx" in dataset.column_names, assert_msg("step_idx")
    assert "episode_len" in dataset.column_names, assert_msg("episode_len")
    assert "pixels" in dataset.column_names, assert_msg("pixels")
    assert "action" in dataset.column_names, assert_msg("action")
    assert "goal" in dataset.column_names, assert_msg("goal")

    info = {
        "name": name,
        "num_episodes": len(np.unique(dataset["episode_idx"])),
        "num_steps": len(dataset),
        "columns": dataset.column_names,
        "obs_shape": dataset["pixels"][0].shape,
        "action_shape": dataset["action"][0].shape,
        "goal_shape": dataset["goal"][0].shape,
        "variation": {
            "has_variation": any(col.startswith("variation.") for col in dataset.column_names),
            "names": [
                col.removeprefix("variation.")
                for col in dataset.column_names
                if col.startswith("variation.")
            ],
        },
    }

    return info

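
# Illustrative only: inspecting a cached dataset. "pusht_demo" is a
# hypothetical name; substitute any entry returned by list_datasets().
def _example_dataset_info():
    info = dataset_info("pusht_demo")
    print(info["num_episodes"], info["num_steps"])
    print(info["obs_shape"], info["action_shape"], info["goal_shape"])
    print(info["variation"]["names"])
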
def list_worlds():
    """List all registered world environments.

    Returns:
        list[str]: Sorted list of world environment IDs.
    """
    return sorted(swm.envs.WORLDS)

def _space_meta(space) -> SpaceInfo | dict[str, SpaceInfo] | list[SpaceInfo]:
    """Extract metadata from a Gymnasium space.

    Recursively processes Dict, Tuple, and Sequence spaces.

    Args:
        space: A Gymnasium space object.

    Returns:
        SpaceInfo | dict | list: Space metadata. Returns a dict for Dict
            spaces, a list for Tuple/Sequence spaces, or a SpaceInfo dict for
            simple spaces.
    """
    if isinstance(space, gym.spaces.Dict):
        return {k: _space_meta(v) for k, v in space.spaces.items()}

    if isinstance(space, (gym.spaces.Sequence, gym.spaces.Tuple)):
        return [_space_meta(s) for s in space.spaces]

    info: SpaceInfo = {
        "shape": getattr(space, "shape", None),
        "type": type(space).__name__,
    }

    if hasattr(space, "dtype") and getattr(space, "dtype") is not None:
        info["dtype"] = str(space.dtype)
    if hasattr(space, "low"):
        info["low"] = getattr(space, "low", None)
    if hasattr(space, "high"):
        info["high"] = getattr(space, "high", None)
    if hasattr(space, "n"):
        info["n"] = getattr(space, "n")

    return info

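
# Illustrative only: _space_meta on a hand-built Dict space, showing the
# nested SpaceInfo structure it returns.
def _example_space_meta():
    space = gym.spaces.Dict(
        {
            "pixels": gym.spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8),
            "command": gym.spaces.Discrete(4),
        }
    )
    meta = _space_meta(space)
    print(meta["pixels"]["type"], meta["pixels"]["shape"])  # Box (224, 224, 3)
    print(meta["command"]["type"], meta["command"]["n"])    # Discrete 4
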
@lru_cache(maxsize=128)
def world_info(
    name: str,
    *,
    image_shape: tuple[int, int] = (224, 224),
    render_mode: str = "rgb_array",
) -> WorldInfo:
    """Get metadata about a world environment.

    Creates a temporary world instance to extract observation space, action
    space, and variation information. Results are cached for efficiency.

    Args:
        name: ID of the world environment.
        image_shape: Desired image shape for rendering. Defaults to (224, 224).
        render_mode: Rendering mode for the environment. Defaults to "rgb_array".

    Returns:
        WorldInfo: Dictionary containing world metadata including spaces and
            variations.

    Raises:
        ValueError: If world name is not registered.
    """
    if name not in swm.envs.WORLDS:
        raise ValueError(f"World '{name}' not found. Available: {', '.join(list_worlds())}")

    world = None
    try:
        world = swm.World(
            name,
            num_envs=1,
            image_shape=image_shape,
            render_mode=render_mode,
            verbose=0,
        )

        obs_space = getattr(world, "single_observation_space", None)
        act_space = getattr(world, "single_action_space", None)
        var_space = getattr(world, "single_variation_space", None)

        variation: VariationInfo = {
            "has_variation": var_space is not None,
            "type": type(var_space).__name__ if var_space is not None else None,
            "names": var_space.names() if hasattr(var_space, "names") else None,
        }

        return {
            "name": name,
            "observation_space": _space_meta(obs_space) if obs_space else {},
            "action_space": _space_meta(act_space) if act_space else {},
            "variation": variation,
        }
    finally:
        if world is not None and hasattr(world, "close"):
            try:
                world.close()
            except Exception:
                pass

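
# Illustrative only: querying world metadata. "swm/PushT-v0" is a hypothetical
# world ID; use any entry from list_worlds(). Repeated calls with the same
# arguments hit the lru_cache instead of rebuilding the world.
def _example_world_info():
    info = world_info("swm/PushT-v0", image_shape=(128, 128))
    print(info["action_space"])
    print(info["variation"]["has_variation"], info["variation"]["names"])
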
def delete_dataset(name):
    """Delete a cached dataset and its associated files.

    Args:
        name (str): Name of the dataset to delete.

    Note:
        Prints success or error messages to console.
    """
    from datasets import logging as ds_logging

    ds_logging.set_verbosity_error()

    try:
        dataset_path = Path(get_cache_dir(), name)

        if not dataset_path.exists():
            raise ValueError(f"Dataset {name} does not exist at {dataset_path}")

        dataset = load_from_disk(str(Path(dataset_path, "records")))

        # remove cache files
        dataset.cleanup_cache_files()

        # delete dataset directory
        shutil.rmtree(dataset_path, ignore_errors=False)

        print(f"🗑️ Dataset {dataset_path} deleted!")

    except Exception as e:
        print(f"[red]Error cleaning up dataset [cyan]{name}[/cyan]: {e}[/red]")

def delete_model(name):
    """Delete cached model checkpoint files.

    Removes all checkpoint files (weights and object files) matching the given
    model name.

    Args:
        name (str): Name of the model to delete.

    Note:
        Prints success or error messages to console for each file deleted.
    """
    pattern = re.compile(rf"^{re.escape(name)}(?:_[^-].*)?\.ckpt$")
    cache_dir = get_cache_dir()

    for fname in os.listdir(cache_dir):
        if pattern.match(fname):
            filepath = os.path.join(cache_dir, fname)
            try:
                os.remove(filepath)
                print(f"🔮 Model {fname} deleted")
            except Exception as e:
                print(f"[red]Error occurred while deleting model [cyan]{name}[/cyan]: {e}[/red]")

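
# Illustrative only: removing a hypothetical dataset and model from the cache.
def _example_cleanup():
    delete_dataset("pusht_demo")  # deletes <cache>/pusht_demo and its HF cache files
    delete_model("dino_wm")       # removes matching dino_wm checkpoint files
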