StepsDataset
- class StepsDataset(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)
Bases: HFDataset
Dataset for loading multi-step trajectory sequences.
This dataset loads sequences of consecutive steps from episodic data, supporting frame skipping and automatic image loading from file paths. Inherits from stable_pretraining’s HFDataset.
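To illustrate how frame skipping shapes a single sample, here is a minimal sketch. It assumes that a sample starting at step `start` takes `num_steps` frames spaced `frameskip` steps apart, so one sample spans `(num_steps - 1) * frameskip + 1` raw steps. Both helpers (`sample_step_offsets`, `span_length`) are hypothetical and are not part of the library.

```python
def sample_step_offsets(start: int, num_steps: int = 2, frameskip: int = 1) -> list[int]:
    """Hypothetical helper: step indices covered by one sample (assumed layout)."""
    if num_steps < 1 or frameskip < 1:
        raise ValueError("num_steps and frameskip must be >= 1")
    # Assumption: frames are taken every `frameskip` steps, beginning at `start`.
    return [start + k * frameskip for k in range(num_steps)]


def span_length(num_steps: int, frameskip: int) -> int:
    """Number of raw steps one sample spans under the same assumption."""
    return (num_steps - 1) * frameskip + 1


print(sample_step_offsets(10, num_steps=3, frameskip=2))  # [10, 12, 14]
print(span_length(3, 2))  # 5
```

With the defaults (num_steps=2, frameskip=1), a sample is simply two adjacent steps.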
- Variables:
data_dir (Path) – Directory containing the dataset.
num_steps (int) – Number of steps per sequence.
frameskip (int) – Stride, in steps, between consecutive sampled frames (1 samples adjacent steps).
episodes (np.ndarray) – Array of unique episode indices.
episode_slices (dict) – Mapping from episode index to dataset indices.
cum_slices (np.ndarray) – Cumulative sum of valid samples per episode.
idx_to_ep (np.ndarray) – Mapping from sample index to episode.
img_cols (set) – Set of column names containing image paths.
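The attributes cum_slices and idx_to_ep together let a flat sample index be mapped back to an episode. The sketch below assumes that cum_slices starts with a leading 0 and that idx_to_ep stores a position into the episodes array; the per-episode sample counts are made-up illustrative numbers, and `locate` is a hypothetical helper.

```python
import numpy as np

# Illustrative per-episode counts of valid samples (not real data).
valid_per_episode = np.array([4, 2, 3])
episodes = np.array([0, 1, 2])

# Assumed layout: cum_slices[i] = number of samples before episode i; last entry = total.
cum_slices = np.concatenate([[0], np.cumsum(valid_per_episode)])
total = int(cum_slices[-1])

# For each flat sample index, find the position of the episode whose range contains it.
idx_to_ep = np.searchsorted(cum_slices, np.arange(total), side="right") - 1


def locate(idx: int) -> tuple[int, int]:
    """Hypothetical helper: (episode, sample offset within that episode)."""
    pos = idx_to_ep[idx]
    return int(episodes[pos]), int(idx - cum_slices[pos])


print(cum_slices)  # [0 4 6 9]
print(idx_to_ep)   # [0 0 0 0 1 1 2 2 2]
print(locate(5))   # (1, 1)
```

Precomputing idx_to_ep trades a little memory for an O(1) lookup per item; calling np.searchsorted lazily inside item access would avoid storing it.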
- __init__(path, *args, num_steps=2, frameskip=1, cache_dir=None, **kwargs)
Initialize the StepsDataset.
- Parameters:
path (str) – Name or path of the dataset within the cache directory.
*args – Additional arguments passed to parent class.
num_steps (int, optional) – Number of consecutive steps per sample. Defaults to 2.
frameskip (int, optional) – Number of steps between sampled frames. Defaults to 1.
cache_dir (str, optional) – Cache directory path. Defaults to None, in which case the default cache directory is used.
**kwargs – Additional keyword arguments passed to parent class.
- Raises:
AssertionError – If required columns are missing from the dataset.
ValueError – If episodes are too short for the requested num_steps and frameskip.
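The AssertionError above comes from a check on the dataset's columns. The required column names are not listed here, so the sketch below uses hypothetical names ("episode_idx", "step_idx") purely for illustration; `check_required_columns` is likewise a hypothetical helper, not the library's code.

```python
# Hypothetical required column names; the real dataset may require others.
REQUIRED_COLUMNS = ("episode_idx", "step_idx")


def check_required_columns(column_names, required=REQUIRED_COLUMNS) -> None:
    """Raise AssertionError if any required column is missing (illustrative check)."""
    missing = [c for c in required if c not in column_names]
    assert not missing, f"dataset is missing required columns: {missing}"


check_required_columns(["episode_idx", "step_idx", "pixels"])  # passes silently
```

Note that assert statements are stripped when Python runs with -O, so a check like this would silently disappear in optimized mode.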
- get_episode_slice(episode_idx, episode_indices)
Get dataset indices for a specific episode.
- Parameters:
episode_idx (int) – Episode index to retrieve.
episode_indices (array-like) – Array of episode indices for all steps.
- Returns:
Indices of steps belonging to the specified episode.
- Return type:
np.ndarray
- Raises:
ValueError – If the episode is too short for the requested num_steps and frameskip.
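A standalone sketch of the behavior described above: select the rows belonging to one episode and reject episodes that are too short. It assumes an episode needs at least `(num_steps - 1) * frameskip + 1` steps to produce one sample; the library's exact threshold may differ, and this free function (which takes num_steps and frameskip as arguments rather than reading them from the instance) is an illustration, not the method itself.

```python
import numpy as np


def get_episode_slice(episode_idx, episode_indices, num_steps=2, frameskip=1):
    """Sketch: dataset row indices for one episode, with a length check.

    Assumption: an episode must contain at least (num_steps - 1) * frameskip + 1
    steps to yield a single sample.
    """
    # Positions of every step whose episode index matches.
    indices = np.flatnonzero(np.asarray(episode_indices) == episode_idx)
    required = (num_steps - 1) * frameskip + 1
    if len(indices) < required:
        raise ValueError(
            f"episode {episode_idx} has {len(indices)} steps, "
            f"but num_steps={num_steps} and frameskip={frameskip} need {required}"
        )
    return indices


episode_indices = [0, 0, 0, 1, 1, 1, 1, 1]
print(get_episode_slice(1, episode_indices, num_steps=3, frameskip=2))  # [3 4 5 6 7]
```

Here episode 0 (3 steps) would raise ValueError for num_steps=3, frameskip=2, because one sample needs 5 steps.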
- infer_img_path_columns()
Infer which dataset columns contain image file paths.
Checks the first dataset element to identify string columns whose values end with a common image file extension.
- Returns:
Set of column names containing image file paths.
- Return type:
set
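The extension-based inference can be sketched on a plain dict standing in for the first dataset element (indexing a Hugging Face Dataset with an integer returns such a dict). The extension list below is an assumption, since the documentation only says "common image file extensions", and this free function is illustrative rather than the method's actual code.

```python
from pathlib import Path

# Assumed set of "common" image extensions; the real list may differ.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp"}


def infer_img_path_columns(first_row: dict) -> set:
    """Return names of string columns whose value ends with an image extension."""
    return {
        name
        for name, value in first_row.items()
        if isinstance(value, str) and Path(value).suffix.lower() in IMAGE_EXTENSIONS
    }


row = {"pixels": "ep0/step_000.png", "goal": "ep0/goal.JPG", "action": [0.1, 0.2], "task": "push"}
print(sorted(infer_img_path_columns(row)))  # ['goal', 'pixels']
```

Because only the first element is inspected, a column whose first value happens not to be an image path (for example an empty string) would be missed; this is the trade-off for not scanning the whole dataset.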