StepsDataset

class StepsDataset(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)

Bases: HFDataset

Dataset for loading multi-step trajectory sequences.

This dataset loads sequences of consecutive steps from episodic data, supporting frame skipping and automatic image loading from file paths. Inherits from stable_pretraining’s HFDataset.
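Frame skipping can be pictured as a strided window over a single episode. The sketch below assumes that a sample starting at offset `start` covers the frames `start, start + frameskip, ..., start + (num_steps - 1) * frameskip`. `window_offsets` is a hypothetical helper for illustration, not part of the library.

```python
def window_offsets(start: int, num_steps: int = 2, frameskip: int = 1) -> list[int]:
    """Hypothetical helper: within-episode offsets of the frames in one sample.

    Assumes frameskip is the stride between consecutive sampled frames,
    so frameskip=1 yields consecutive steps.
    """
    return [start + k * frameskip for k in range(num_steps)]


# With the defaults, a sample is simply two consecutive steps.
print(window_offsets(0))                            # [0, 1]
# Four frames, taking every second step, starting at offset 3.
print(window_offsets(3, num_steps=4, frameskip=2))  # [3, 5, 7, 9]
```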

Variables:
  • data_dir (Path) – Directory containing the dataset.

  • num_steps (int) – Number of steps per sequence.

  • frameskip (int) – Stride, in steps, between consecutive sampled frames (1 means no frames are skipped).

  • episodes (np.ndarray) – Array of unique episode indices.

  • episode_slices (dict) – Mapping from episode index to dataset indices.

  • cum_slices (np.ndarray) – Cumulative sum of the number of valid samples per episode.

  • idx_to_ep (np.ndarray) – Mapping from sample index to episode.

  • img_cols (set) – Set of column names containing image paths.
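One plausible way `cum_slices` and `idx_to_ep` work together is to map a flat sample index to an episode and a start offset within it. The sketch below is an assumption, not the library's code: `build_index` and `locate` are hypothetical names, the per-episode counts are made up, and the episode returned is a position in the episode order rather than an original episode ID.

```python
import numpy as np


def build_index(valid_per_episode):
    """Hypothetical: build cumulative counts and a sample -> episode lookup."""
    counts = np.asarray(valid_per_episode)
    cum_slices = np.cumsum(counts)                          # running total of samples
    idx_to_ep = np.repeat(np.arange(len(counts)), counts)   # one entry per sample
    return cum_slices, idx_to_ep


def locate(idx, cum_slices, idx_to_ep):
    """Hypothetical: return (episode position, start offset) for sample idx."""
    if not 0 <= idx < len(idx_to_ep):
        raise IndexError(idx)
    ep = int(idx_to_ep[idx])
    # Samples belonging to earlier episodes precede this episode's first sample.
    first = cum_slices[ep - 1] if ep > 0 else 0
    return ep, int(idx - first)


# Three episodes with 4, 0 and 3 valid samples; the empty one is skipped.
cum, i2e = build_index([4, 0, 3])
print(cum.tolist())         # [4, 4, 7]
print(i2e.tolist())         # [0, 0, 0, 0, 2, 2, 2]
print(locate(5, cum, i2e))  # (2, 1)
```

Precomputing `idx_to_ep` makes each lookup O(1) at the cost of one array entry per sample; `np.searchsorted` over `cum_slices` would be the memory-light alternative.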

__init__(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)

Initialize the StepsDataset.

Parameters:
  • path (str) – Name or path of the dataset within the cache directory.

  • *args – Additional arguments passed to parent class.

  • num_steps (int, optional) – Number of consecutive steps per sample. Defaults to 2.

  • frameskip (int, optional) – Number of steps between sampled frames. Defaults to 1.

  • cache_dir (str, optional) – Cache directory path. Defaults to None, in which case the default cache directory is used.

  • **kwargs – Additional keyword arguments passed to parent class.

Raises:
  • AssertionError – If required columns are missing from the dataset.

  • ValueError – If episodes are too short for the requested num_steps and frameskip.
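A minimal sketch of the length check, assuming a window of `num_steps` frames with stride `frameskip` spans `(num_steps - 1) * frameskip + 1` steps. `valid_start_count` is a hypothetical helper; the library's exact message and counting rule may differ.

```python
def valid_start_count(episode_len: int, num_steps: int = 2, frameskip: int = 1) -> int:
    """Hypothetical: number of valid window starts in one episode.

    Raises ValueError when the episode cannot hold even one window.
    """
    span = (num_steps - 1) * frameskip + 1  # steps covered by one window
    if episode_len < span:
        raise ValueError(
            f"episode of length {episode_len} is too short for "
            f"num_steps={num_steps}, frameskip={frameskip} (needs >= {span})"
        )
    return episode_len - span + 1


print(valid_start_count(10, num_steps=4, frameskip=2))  # 4 (starts 0..3)
print(valid_start_count(2))                             # 1
```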

get_episode_slice(episode_idx, episode_indices)

Get dataset indices for a specific episode.

Parameters:
  • episode_idx (int) – Episode index to retrieve.

  • episode_indices (array-like) – Array of episode indices for all steps.

Returns:

Indices of steps belonging to the specified episode.

Return type:

np.ndarray

Raises:

ValueError – If the episode is too short for the requested num_steps and frameskip.
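The behaviour described above can be approximated as a free function with NumPy. This is a sketch, not the method's source: it takes `num_steps` and `frameskip` as explicit arguments (the real method reads them from the instance) and assumes the minimum episode length is `(num_steps - 1) * frameskip + 1`.

```python
import numpy as np


def get_episode_slice(episode_idx, episode_indices, num_steps=2, frameskip=1):
    """Hypothetical free-function version: dataset rows of one episode."""
    episode_indices = np.asarray(episode_indices)
    # Positions of every step whose episode index matches.
    rows = np.flatnonzero(episode_indices == episode_idx)
    min_len = (num_steps - 1) * frameskip + 1
    if len(rows) < min_len:
        raise ValueError(
            f"episode {episode_idx} has {len(rows)} steps; "
            f"need at least {min_len} for num_steps={num_steps}, frameskip={frameskip}"
        )
    return rows


print(get_episode_slice(1, [0, 0, 1, 1, 1, 2]))  # [2 3 4]
```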

infer_img_path_columns()

Infer which dataset columns contain image file paths.

Inspects the first element of the dataset and selects the string-valued columns whose values end in a common image file extension.

Returns:

Set of column names containing image file paths.

Return type:

set
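The extension-based inference can be sketched on a plain dict standing in for the first dataset element. The extension list `IMG_EXTS` and the case-insensitive match are assumptions; the library may recognise a different set of extensions.

```python
# Assumed set of "common" image extensions; the library's list may differ.
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp")


def infer_img_path_columns(first_row: dict) -> set[str]:
    """Hypothetical: columns of one row whose string values look like image paths."""
    return {
        col
        for col, val in first_row.items()
        if isinstance(val, str) and val.lower().endswith(IMG_EXTS)
    }


row = {"pixels": "ep0/0001.PNG", "action": [0.1, -0.2], "episode_idx": 0, "note": "hello"}
print(infer_img_path_columns(row))  # {'pixels'}
```

Checking only the first element is cheap, but it assumes every row stores the same kind of value in each column.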