stable_worldmodel.solver package

Submodules

stable_worldmodel.solver.cem module

class CEMSolver(model: Costable, num_samples, var_scale, n_steps, topk, device='cpu')[source]

Bases: object

Cross Entropy Method Solver.

adapted from https://github.com/gaoyuezhou/dino_wm/blob/main/planning/cem.py
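As a rough illustration of the cross-entropy method itself (not of CEMSolver's actual implementation), the sketch below minimizes a toy quadratic cost with the standard library only. The function names and the cost are hypothetical; the parameter names num_samples, topk, n_steps and var_scale are borrowed from the constructor signature, and treating var_scale as the initial standard deviation is an assumption.

```python
import random
import statistics


def cem_minimize(cost_fn, dim, num_samples=128, topk=16, n_steps=20, var_scale=1.0, seed=0):
    """Toy cross-entropy method over a dim-dimensional vector (illustrative only)."""
    rng = random.Random(seed)
    mu = [0.0] * dim
    sigma = [var_scale] * dim  # assumption: var_scale is the initial std
    for _ in range(n_steps):
        # Sample candidates from the current diagonal Gaussian.
        samples = [
            [rng.gauss(mu[d], sigma[d]) for d in range(dim)]
            for _ in range(num_samples)
        ]
        # Keep the topk lowest-cost candidates (the elites).
        elites = sorted(samples, key=cost_fn)[:topk]
        # Refit the mean and standard deviation to the elites.
        mu = [statistics.fmean([e[d] for e in elites]) for d in range(dim)]
        sigma = [statistics.pstdev([e[d] for e in elites]) for d in range(dim)]
    return mu


def quadratic_cost(x):
    """Hypothetical cost: squared distance to the point (2, 2, ..., 2)."""
    return sum((xi - 2.0) ** 2 for xi in x)


best = cem_minimize(quadratic_cost, dim=3)
```

In a planner, the vector being optimized would be the flattened action sequence and cost_fn would wrap the world model's cost.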

property action_dim: int
configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
init_action_distrib(actions=None)[source]

Initialize the action distribution params (mu, sigma) given the initial condition.

Parameters:

actions (n_envs, T, action_dim) – initial actions, T <= horizon

property n_envs: int
solve(info_dict, init_action=None)[source]

stable_worldmodel.solver.gd module

class GDSolver(model: Costable, n_steps: int, action_noise=0.0, device='cpu')[source]

Bases: Module

Gradient Descent Solver.

property action_dim: int
configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
init_action(actions=None)[source]

Initialize the action tensor for the solver.

Sets self.init to the initial action sequences, with shape (n_envs, horizon, action_dim).

property n_envs: int
set_seed(seed: int) None[source]

Set random seed for deterministic behavior.

Parameters:

seed – Random seed to use for numpy and torch

solve(info_dict, init_action=None) Tensor[source]

Solve the planning optimization problem using gradient descent.
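To make the idea concrete, here is a minimal gradient-descent loop over a flat list of actions. It is a sketch under stated assumptions, not GDSolver's code: the gradient is supplied analytically by a hypothetical quadratic_grad function, whereas the real solver would presumably obtain gradients from the world model's differentiable cost. The lr parameter is also an assumption; only n_steps appears in the constructor signature.

```python
def gd_minimize(grad_fn, init, n_steps=100, lr=0.1):
    """Plain gradient descent on a list of scalar actions (illustrative only)."""
    actions = list(init)
    for _ in range(n_steps):
        grads = grad_fn(actions)
        # Move each action against its gradient.
        actions = [a - lr * g for a, g in zip(actions, grads)]
    return actions


def quadratic_grad(actions, target=1.5):
    """Gradient of sum((a - target)^2) with respect to each action."""
    return [2.0 * (a - target) for a in actions]


planned = gd_minimize(quadratic_grad, init=[0.0, 0.0, 0.0])
```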

stable_worldmodel.solver.mppi module

class MPPISolver(model: Costable, num_samples, num_elites, var_scale, n_steps, use_elites=True, temperature=0.5, device='cpu')[source]

Bases: object

Model Predictive Path Integral Solver.

Proposed in https://arxiv.org/abs/1509.01149; algorithm from https://acdslab.github.io/mppi-generic-website/docs/mppi.html

Note

The original MPPI computes the cost as a sum of costs along the trajectory. This implementation uses only the final cost, which is expected to change in a future update.
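The difference between the two cost conventions mentioned in the note can be sketched on a hypothetical 1-D trajectory. Both functions and the squared-distance cost are illustrative assumptions, not part of the library.

```python
def running_cost(states, goal):
    """Sum of squared distances to the goal over every step (original MPPI style)."""
    return sum((s - goal) ** 2 for s in states)


def terminal_cost(states, goal):
    """Squared distance to the goal at the last step only (final-cost variant)."""
    return (states[-1] - goal) ** 2


trajectory = [0.0, 0.5, 0.9, 1.0]  # hypothetical 1-D predicted states
summed = running_cost(trajectory, goal=1.0)
final = terminal_cost(trajectory, goal=1.0)
```

A trajectory that reaches the goal slowly is penalized by the summed cost but not by the final cost.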

property action_dim: int
compute_trajectory_weights(costs: Tensor) Tensor[source]

Compute trajectory weights from costs using softmin with temperature.

Parameters:

costs (num_samples,) – Tensor of trajectory costs.

Returns:

Tensor of trajectory weights.
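A softmin with temperature can be sketched as follows. This is a standard-library illustration, not the library's code; in particular, dividing the costs by the temperature (rather than multiplying) is an assumption about the convention used.

```python
import math


def softmin_weights(costs, temperature=0.5):
    """Softmin over a list of costs: lower cost gives a larger weight."""
    # Subtracting the minimum cost improves numerical stability
    # and does not change the normalized result.
    c_min = min(costs)
    unnormalized = [math.exp(-(c - c_min) / temperature) for c in costs]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


weights = softmin_weights([1.0, 2.0, 4.0], temperature=0.5)
```

Under this convention, a smaller temperature concentrates the weight on the lowest-cost trajectory, while a larger one flattens the weights toward uniform.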

configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
init_action_distrib(actions=None)[source]

Initialize the action distribution params (mu, sigma) given the initial condition.

Parameters:

actions (n_envs, T, action_dim) – initial actions, T <= horizon

property n_envs: int
solve(info_dict, init_action=None)[source]

stable_worldmodel.solver.nevergrad module

class NevergradSolver(model: Costable, optimizer, n_steps: int, device='cpu')[source]

Bases: object

Nevergrad Solver.

supporting https://github.com/facebookresearch/nevergrad

Attention

  • CPU based optimizer (no GPU support)

  • It’s your duty to ensure num_workers == n_envs for parallelization

property action_dim: int
configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
property n_envs: int
solve(info_dict, init_action=None)[source]

stable_worldmodel.solver.old_mppi module

stable_worldmodel.solver.random module

Random action sampling solver for planning problems.

This module provides a baseline solver that samples random actions from the action space without any optimization. It serves as a simple baseline for comparison with more sophisticated planning algorithms like CEM, gradient descent, or MPPI.

The RandomSolver is useful for:
  • Establishing performance baselines in model-based planning experiments

  • Testing environment and policy infrastructure without complex optimization

  • Quick debugging of planning pipelines

  • Ablation studies comparing random vs. optimized action selection

Classes:

RandomSolver: Samples random actions uniformly from the action space.

Typical usage example:

Basic usage with stable-worldmodel:

import stable_worldmodel as swm

# Create world and solver
world = swm.World("swm/SimplePointMaze-v0", num_envs=4)
solver = swm.solver.RandomSolver()

# Configure planning
config = swm.PlanConfig(horizon=10, receding_horizon=5, action_block=1)

# Create policy and evaluate
policy = swm.policy.WorldModelPolicy(solver=solver, config=config)
world.set_policy(policy)
results = world.evaluate(episodes=5)

Direct solver usage:

import gymnasium as gym
import stable_worldmodel as swm
from stable_worldmodel.solver import RandomSolver

# Setup
env = gym.make("Pendulum-v1")
solver = RandomSolver()

config = swm.PlanConfig(horizon=10, receding_horizon=5, action_block=1)
solver.configure(action_space=env.action_space, n_envs=1, config=config)

# Generate random actions
result = solver.solve({})
actions = result["actions"]  # Shape: (1, 10, action_dim)
class RandomSolver[source]

Bases: object

Random action sampling solver for model-based planning.

This solver generates action sequences by uniformly sampling from the action space without any optimization or cost evaluation. Unlike optimization-based solvers (CEM, GD, MPPI), it does not require a world model or cost function, making it extremely fast and simple to use.

The solver is primarily intended as a baseline for evaluating the performance gains of model-based planning. Random action selection typically performs poorly on complex tasks but can be surprisingly effective in simple or stochastic environments.

Key features:
  • Zero computation cost: No forward passes through world models

  • Parallel sampling: Generates actions for multiple environments simultaneously

  • Action blocking: Supports repeating actions for temporal abstraction

  • Warm-starting: Can extend partial action sequences

  • API compatible: Works with WorldModelPolicy and other solver-based policies

Variables:
  • n_envs (int) – Number of parallel environments being planned for.

  • action_dim (int) – Total action dimensionality (base_dim × action_block).

  • horizon (int) – Number of planning steps in the action sequence.

Example

Using with stable-worldmodel’s World and Policy classes:

import stable_worldmodel as swm

# Create environment
world = swm.World("swm/SimplePointMaze-v0", num_envs=8)

# Setup random solver policy
config = swm.PlanConfig(
    horizon=15,  # Plan 15 steps ahead
    receding_horizon=5,  # Replan every 5 steps
    action_block=1,  # No action repetition
)
solver = swm.solver.RandomSolver()
policy = swm.policy.WorldModelPolicy(solver=solver, config=config)

# Evaluate
world.set_policy(policy)
results = world.evaluate(episodes=10, seed=42)
print(f"Baseline reward: {results['mean_reward']:.2f}")

Standalone usage for custom planning loops:

import gymnasium as gym
import stable_worldmodel as swm
from stable_worldmodel.solver import RandomSolver

env = gym.make("Hopper-v4", render_mode="rgb_array")
solver = RandomSolver()

# Configure
config = swm.PlanConfig(horizon=20, receding_horizon=10, action_block=2)
solver.configure(action_space=env.action_space, n_envs=1, config=config)

# Generate and execute actions
obs, info = env.reset()
result = solver.solve(info_dict={})
actions = result["actions"][0]  # Get first env's actions

for i in range(config.receding_horizon):
    action = actions[i].numpy()
    obs, reward, done, truncated, info = env.step(action)
    if done or truncated:
        break

Note

This solver ignores the info_dict parameter in solve() since it doesn’t use state information or world models. The parameter is kept for API consistency with optimization-based solvers.

__init__()[source]

Initialize an unconfigured RandomSolver.

Creates a solver instance that must be configured via configure() before calling solve(). This two-step initialization allows the policy framework to instantiate solvers before environment details are available.

Example

Typical initialization pattern:

solver = RandomSolver()  # Create
solver.configure(...)  # Configure with env specs
result = solver.solve({})  # Use
property action_dim: int

Total action dimensionality including action blocking.

Equals base_action_dim × action_block. For example, if the environment has 3D continuous actions and action_block=5, this returns 15.

Type:

int
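The arithmetic behind action_dim can be sketched with two small helpers. Both function names are hypothetical, and the block-major layout assumed when splitting a flattened action (all dimensions of the first sub-action, then the second, and so on) is an assumption rather than something this page states.

```python
def flat_action_dim(base_action_dim, action_block):
    """Total per-step action size when action_block actions are grouped."""
    return base_action_dim * action_block


def split_blocked_action(flat_action, action_block):
    """Split one flattened planning-step action into action_block sub-actions.

    Assumes a block-major layout (hypothetical).
    """
    base_dim = len(flat_action) // action_block
    return [flat_action[i * base_dim:(i + 1) * base_dim] for i in range(action_block)]


dim = flat_action_dim(3, 5)
parts = split_blocked_action([1, 2, 3, 4, 5, 6], action_block=2)
```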

configure(*, action_space, n_envs: int, config) None[source]

Configure the solver with environment and planning specifications.

Must be called before solve() to set up the action space dimensions, number of parallel environments, and planning configuration (horizon, action blocking, etc.).

Parameters:
  • action_space – Gymnasium action space defining valid actions. Must have a sample() method and shape attribute. Typically env.action_space or env.single_action_space.

  • n_envs (int) – Number of parallel environments to plan for. Action sequences will be generated for each environment independently. Must be ≥ 1.

  • config

    Planning configuration object (typically swm.PlanConfig) with required attributes:

    • horizon (int): Number of planning steps. Each step corresponds to one action selection point.

    • action_block (int): Number of environment steps per planning step. Actions are repeated this many times.

Raises:
  • AttributeError – If config is missing required attributes.

  • ValueError – If n_envs < 1 or horizon < 1.

Example

Configure for vectorized environment:

import stable_worldmodel as swm

world = swm.World("swm/SimplePointMaze-v0", num_envs=8)
solver = swm.solver.RandomSolver()

config = swm.PlanConfig(horizon=10, receding_horizon=5, action_block=1)
solver.configure(
    action_space=world.envs.single_action_space,
    n_envs=world.num_envs,
    config=config,
)

Note

The solver extracts action_space.shape[1:] as the base action dimensionality, assuming the first dimension is the batch/environment dimension in vectorized action spaces.
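The shape-slicing convention in this note can be illustrated on plain tuples. The helpers below are hypothetical, and taking the product of the remaining dimensions as the base dimensionality is an assumption.

```python
import math


def base_action_shape(vectorized_shape):
    """Drop the leading batch/environment dimension from an action-space shape."""
    return tuple(vectorized_shape[1:])


def base_action_dim(vectorized_shape):
    """Number of scalar entries in one environment's action."""
    return math.prod(base_action_shape(vectorized_shape))


vectorized = (8, 2)  # 8 environments, 2-D actions
single = (3,)        # non-vectorized space: no batch dimension to drop
```

Under this convention a non-vectorized shape such as (3,) loses its only dimension and yields an empty base shape, so passing a space whose first dimension really is the environment dimension appears to be what the slicing expects.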

property horizon: int

Planning horizon in steps (number of action selections).

Type:

int

property n_envs: int

Number of parallel environments the solver plans for.

Type:

int

solve(info_dict, init_action=None) dict[source]

Generate random action sequences for the planning horizon.

Samples random actions uniformly from the action space to create action sequences for each environment. If partial action sequences are provided via init_action, only the remaining steps are sampled and concatenated.

This method does not use info_dict since random sampling doesn’t require state information, but the parameter is kept for API consistency with optimization-based solvers that do use environment state.

Parameters:
  • info_dict (dict) – Environment state information dictionary. Not used by RandomSolver but required for solver API consistency. Other solvers may use fields like ‘state’, ‘observation’, ‘latent’, etc.

  • init_action (torch.Tensor, optional) – Partial action sequence to warm-start planning. Shape: (n_envs, k, action_dim) where k < horizon. The solver samples actions for the remaining (horizon - k) steps and concatenates them. Useful for receding horizon planning where previous plans are reused. Defaults to None (sample full horizon).

Returns:

Dictionary with a single key:
  • 'actions' (torch.Tensor): Random action sequences with shape (n_envs, horizon, action_dim). Values are sampled uniformly from the action space bounds.

Return type:

dict

Example

Generate full random action sequence:

solver.configure(action_space=env.action_space, n_envs=4, config=config)
result = solver.solve(info_dict={})
actions = result["actions"]  # Shape: (4, horizon, action_dim)

Warm-start with partial sequence (receding horizon planning):

# First planning step: full horizon
result1 = solver.solve({})
actions1 = result1["actions"]  # (4, 10, action_dim)

# Execute first 5 actions, then replan
executed = actions1[:, :5, :]
remaining = actions1[:, 5:, :]  # Use as warm-start

# Second planning step: extend remaining actions
result2 = solver.solve({}, init_action=remaining)
actions2 = result2["actions"]  # (4, 10, action_dim) - new last 5 steps

Note

The sampling uses action_space.sample() which respects the space’s bounds (e.g., Box low/high limits). For continuous spaces, this typically produces uniform distributions. For discrete spaces, it samples uniformly over valid discrete values.
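The warm-start behaviour described for solve() (keep the k provided steps, sample the remaining horizon - k) can be sketched with nested lists and the standard library. This is an assumption-laden illustration: the function name is hypothetical, and sampling from a single scalar range [low, high] stands in for action_space.sample().

```python
import random


def extend_with_random(init_action, horizon, low, high, action_dim, seed=None):
    """Pad each environment's partial sequence to horizon steps with uniform samples."""
    rng = random.Random(seed)
    extended = []
    for env_actions in init_action:  # one list of steps per environment
        missing = horizon - len(env_actions)
        tail = [
            [rng.uniform(low, high) for _ in range(action_dim)]
            for _ in range(missing)
        ]
        extended.append(list(env_actions) + tail)
    return extended


# Hypothetical partial plan with shape (4, 5, 2): 4 envs, 5 steps, 2-D actions.
partial = [[[0.1, 0.2] for _ in range(5)] for _ in range(4)]
full = extend_with_random(partial, horizon=10, low=-1.0, high=1.0, action_dim=2, seed=0)
```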

stable_worldmodel.solver.solver module

class Costable(*args, **kwargs)[source]

Bases: Protocol

Protocol for world model cost functions.

This protocol defines the interface for models that can compute costs for planning and optimization. Models implementing this protocol can evaluate the quality of action sequences in a given environment state.

Example

>>> class MyWorldModel(Costable):
...     def get_cost(self, info_dict, action_candidates):
...         # Compute cost based on predicted trajectories from action candidates
...         return costs
get_cost(action_candidates: Tensor) Tensor[source]

Compute cost for given action candidates based on info dictionary.

Parameters:
  • info_dict – Dictionary containing environment state information. Typically includes keys like ‘pixels’, ‘goal’, ‘proprio’, ‘predicted_states’, etc.

  • action_candidates – Tensor of shape (B, horizon, action_dim) containing action sequences to evaluate.

Returns:

Tensor of shape (n_envs,) containing the cost of each environment's action sequence. Lower costs indicate better action sequences.

Note

The cost computation should be differentiable (requires_grad=True) with respect to action_candidates to enable gradient-based planning methods.
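A structural (duck-typed) implementation of such a protocol might look like the sketch below. It uses plain lists instead of tensors so that it runs with the standard library alone, follows the (info_dict, action_candidates) argument order of the docstring example, and every class name and the toy cost are hypothetical.

```python
from typing import Protocol, Sequence


class CostableSketch(Protocol):
    """Stand-in for the Costable protocol (hypothetical, list-based)."""

    def get_cost(self, info_dict: dict, action_candidates: Sequence) -> list: ...


class GoalDistanceModel:
    """Toy model: cost is the squared distance between summed actions and a goal."""

    def get_cost(self, info_dict, action_candidates):
        goal = info_dict["goal"]
        costs = []
        for seq in action_candidates:  # seq: list of 1-D actions over the horizon
            final_pos = sum(step[0] for step in seq)
            costs.append((final_pos - goal) ** 2)
        return costs


# No inheritance is needed: the class matches the protocol structurally.
model: CostableSketch = GoalDistanceModel()
costs = model.get_cost({"goal": 1.0}, [[[0.5], [0.5]], [[0.0], [0.0]]])
```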

class Solver(*args, **kwargs)[source]

Bases: Protocol

Protocol for model-based planning solvers.

This protocol defines the interface for optimization algorithms that plan action sequences by minimizing a cost function computed by a world model. Solvers receive the current environment state (observations, goals, proprioception) and output optimal action sequences that achieve desired behaviors.

Planning Process:
  1. Receive current state via info_dict (pixels, goal, proprio, etc.)

  2. Initialize or warm-start action sequences

  3. Optimize actions using the world model’s get_cost function

  4. Return optimized action sequences for execution

The protocol supports various optimization methods including:
  • Gradient-based: GDSolver (gradient descent)

  • Sampling-based: CEMSolver (cross-entropy method), MPPISolver

  • Random: RandomSolver (baseline)

Key Concepts:
  • Horizon: Number of timesteps to plan ahead

  • Action Block: Number of actions grouped together due to frame skip.

  • Receding Horizon: Number of actions actually executed before replanning

  • Warm Start: Using the leftover portion of a previous solution to initialize a new optimization
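How horizon, receding horizon and warm start interact can be sketched with simple bookkeeping. The function below is hypothetical and plans nothing: integer labels stand in for actions, so that it is easy to see which steps are executed and which are carried over.

```python
def receding_horizon_plans(horizon, receding_horizon, n_replans):
    """Track which planned steps are executed and which are reused as a warm start."""
    plan = list(range(horizon))  # first plan: steps labelled 0..horizon-1
    executed = []
    next_label = horizon
    for _ in range(n_replans):
        # Execute only the first receding_horizon steps of the current plan.
        executed.extend(plan[:receding_horizon])
        # The leftover steps warm-start the next plan.
        leftover = plan[receding_horizon:]
        # A real solver would optimize here; we just append fresh labels.
        new_tail = list(range(next_label, next_label + receding_horizon))
        next_label += receding_horizon
        plan = leftover + new_tail
    return executed, plan


executed, plan = receding_horizon_plans(horizon=10, receding_horizon=5, n_replans=2)
```

With horizon=10 and receding_horizon=5, each replanning round executes 5 steps and reuses the other 5 as the start of the next plan.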

Variables:
  • action_dim (int) – Flattened action dimension including action_block grouping. Formula: base_action_dim * action_block

  • n_envs (int) – Number of parallel environments being optimized simultaneously.

  • horizon (int) – Planning horizon length in timesteps.

Example

Basic usage with a world model:

>>> # Setup world model and planning config
>>> world_model = DINOWM(encoder, predictor, ...)
>>> plan_config = PlanConfig(horizon=10, receding_horizon=5, action_block=2)
>>>
>>> # Create and configure solver
>>> solver = GDSolver(world_model, n_steps=10, device="cuda")
>>> solver.configure(
...     action_space=env.action_space,
...     n_envs=4,
...     config=plan_config
... )
>>>
>>> # Solve for optimal actions given current state
>>> info_dict = {'pixels': pixels, 'goal': goal, 'proprio': proprio}
>>> outputs = solver.solve(info_dict, init_action=None)
>>> actions = outputs["actions"]  # Shape: (4, 10, action_dim)
>>>
>>> # Warm-start next optimization with remaining actions
>>> next_outputs = solver.solve(info_dict, init_action=outputs["actions"][:, 5:])

See also

  • Costable: Protocol defining the world model cost interface

  • PlanConfig: Configuration dataclass for planning parameters

  • GDSolver, CEMSolver, MPPISolver: Concrete solver implementations

property action_dim: int

Flattened action dimension including action_block grouping.

This is the total size of actions per timestep, computed as: base_action_dim * action_block

The action_block groups multiple actions together for frame skipping. For example, if the environment has 2D actions and action_block=5, then action_dim=10 (the 2 action dimensions grouped 5 times).

Returns:

Total flattened action dimension used in optimization.

Return type:

int

Type:

int

configure(*, action_space: Space, n_envs: int, config) None[source]

Configure the solver with environment and planning specifications.

This method initializes the solver’s internal state based on the environment’s action space and planning configuration. Must be called once after solver creation and before any solve() calls.

Parameters:
  • action_space (gym.Space) – Environment’s action space. For continuous control, this should be a Box space. The shape is typically (n_envs, action_dim) for vectorized environments.

  • n_envs (int) – Number of parallel environments to optimize for. The solver will produce n_envs independent action sequences.

  • config (PlanConfig) –

    Planning configuration containing:

    • horizon: Number of future timesteps to plan

    • receding_horizon: Number of planned actions to execute

    • action_block: Number of actions grouped together due to frame skip

Note

This method should only be called once during initialization. The solver caches the configuration internally for use in solve().

Raises:

Warning – If action_space is not a Box (some solvers only support continuous actions).

property horizon: int

Planning horizon length in timesteps.

This is the number of future timesteps the solver plans ahead. Note that this may differ from receding_horizon (the number of actions actually executed before replanning).

Returns:

Number of timesteps in the planning horizon.

Return type:

int

Type:

int

property n_envs: int

Number of parallel environments being planned for.

Returns:

Number of independent action sequences the solver optimizes.

Return type:

int

Type:

int

solve(info_dict, init_action=None) dict[source]

Solve the planning optimization problem to find optimal actions.

This is the main method that performs trajectory optimization. It uses the world model to evaluate action sequences and finds actions that minimize the cost function. The optimization strategy is solver-specific (gradient descent, sampling, etc.).

Typical workflow:
  1. Initialize action sequences (from init_action or zeros)

  2. Iteratively evaluate cost and update actions

  3. Return optimized actions and optimization statistics

Parameters:
  • info_dict (dict) –

    Current environment state containing:

    • ‘pixels’ (np.ndarray): Current observation images, shape (n_envs, H, W, 3)

    • ‘goal’ (np.ndarray): Goal observation images, shape (n_envs, H, W, 3)

    • ‘proprio’ (np.ndarray, optional): Proprioceptive state, shape (n_envs, proprio_dim)

    • ‘action’ (np.ndarray, optional): Previous actions for history

    • Additional task-specific keys as needed

  • init_action (torch.Tensor, optional) –

    Warm-start action sequences with shape (n_envs, init_horizon, action_dim). Common use cases:

    • None: Initialize all actions to zero (cold start)

    • Partial sequence: Pad remaining horizon with zeros

    • Previous solution shifted: Warm-start from last optimization

Returns:

Optimization results containing:
  • 'actions' (torch.Tensor): Optimized action sequences with shape (n_envs, horizon, action_dim). These are the planned actions.

  • 'cost' (list[float]): Cost values during optimization. Format and length depend on the solver implementation.

  • 'trajectory' (list[torch.Tensor]): Intermediate action sequences during optimization (solver-dependent).

  • Additional solver-specific keys (e.g., 'elite_actions' for CEM)

Return type:

dict

Note

The returned actions are typically in the solver’s internal representation and may require denormalization or reshaping before execution in the environment. The WorldModelPolicy handles this transformation.

Example

Cold start (zero initialization):

>>> outputs = solver.solve(info_dict)

Warm start with previous solution:

>>> outputs1 = solver.solve(info_dict)
>>> # Execute first 5 actions, keep rest for warm start
>>> outputs2 = solver.solve(new_info_dict, init_action=outputs1["actions"][:, 5:])

Module contents

class CEMSolver(model: Costable, num_samples, var_scale, n_steps, topk, device='cpu')[source]

Bases: object

Cross Entropy Method Solver.

adapted from https://github.com/gaoyuezhou/dino_wm/blob/main/planning/cem.py

property action_dim: int
configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
init_action_distrib(actions=None)[source]

Initialize the action distribution params (mu, sigma) given the initial condition.

Parameters:

actions (n_envs, T, action_dim) – initial actions, T <= horizon

property n_envs: int
solve(info_dict, init_action=None)[source]
class GDSolver(model: Costable, n_steps: int, action_noise=0.0, device='cpu')[source]

Bases: Module

Gradient Descent Solver.

property action_dim: int
configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
init_action(actions=None)[source]

Initialize the action tensor for the solver.

Sets self.init to the initial action sequences, with shape (n_envs, horizon, action_dim).

property n_envs: int
set_seed(seed: int) None[source]

Set random seed for deterministic behavior.

Parameters:

seed – Random seed to use for numpy and torch

solve(info_dict, init_action=None) Tensor[source]

Solve the planning optimization problem using gradient descent.

class MPPISolver(model: Costable, num_samples, num_elites, var_scale, n_steps, use_elites=True, temperature=0.5, device='cpu')[source]

Bases: object

Model Predictive Path Integral Solver.

Proposed in https://arxiv.org/abs/1509.01149; algorithm from https://acdslab.github.io/mppi-generic-website/docs/mppi.html

Note

The original MPPI computes the cost as a sum of costs along the trajectory. This implementation uses only the final cost, which is expected to change in a future update.

property action_dim: int
compute_trajectory_weights(costs: Tensor) Tensor[source]

Compute trajectory weights from costs using softmin with temperature.

Parameters:

costs (num_samples,) – Tensor of trajectory costs.

Returns:

Tensor of trajectory weights.

configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
init_action_distrib(actions=None)[source]

Initialize the action distribution params (mu, sigma) given the initial condition.

Parameters:

actions (n_envs, T, action_dim) – initial actions, T <= horizon

property n_envs: int
solve(info_dict, init_action=None)[source]
class NevergradSolver(model: Costable, optimizer, n_steps: int, device='cpu')[source]

Bases: object

Nevergrad Solver.

supporting https://github.com/facebookresearch/nevergrad

Attention

  • CPU based optimizer (no GPU support)

  • It’s your duty to ensure num_workers == n_envs for parallelization

property action_dim: int
configure(*, action_space, n_envs: int, config) None[source]
property horizon: int
property n_envs: int
solve(info_dict, init_action=None)[source]
class RandomSolver[source]

Bases: object

Random action sampling solver for model-based planning.

This solver generates action sequences by uniformly sampling from the action space without any optimization or cost evaluation. Unlike optimization-based solvers (CEM, GD, MPPI), it does not require a world model or cost function, making it extremely fast and simple to use.

The solver is primarily intended as a baseline for evaluating the performance gains of model-based planning. Random action selection typically performs poorly on complex tasks but can be surprisingly effective in simple or stochastic environments.

Key features:
  • Zero computation cost: No forward passes through world models

  • Parallel sampling: Generates actions for multiple environments simultaneously

  • Action blocking: Supports repeating actions for temporal abstraction

  • Warm-starting: Can extend partial action sequences

  • API compatible: Works with WorldModelPolicy and other solver-based policies

Variables:
  • n_envs (int) – Number of parallel environments being planned for.

  • action_dim (int) – Total action dimensionality (base_dim × action_block).

  • horizon (int) – Number of planning steps in the action sequence.

Example

Using with stable-worldmodel’s World and Policy classes:

import stable_worldmodel as swm

# Create environment
world = swm.World("swm/SimplePointMaze-v0", num_envs=8)

# Setup random solver policy
config = swm.PlanConfig(
    horizon=15,  # Plan 15 steps ahead
    receding_horizon=5,  # Replan every 5 steps
    action_block=1,  # No action repetition
)
solver = swm.solver.RandomSolver()
policy = swm.policy.WorldModelPolicy(solver=solver, config=config)

# Evaluate
world.set_policy(policy)
results = world.evaluate(episodes=10, seed=42)
print(f"Baseline reward: {results['mean_reward']:.2f}")

Standalone usage for custom planning loops:

import gymnasium as gym
import stable_worldmodel as swm
from stable_worldmodel.solver import RandomSolver

env = gym.make("Hopper-v4", render_mode="rgb_array")
solver = RandomSolver()

# Configure
config = swm.PlanConfig(horizon=20, receding_horizon=10, action_block=2)
solver.configure(action_space=env.action_space, n_envs=1, config=config)

# Generate and execute actions
obs, info = env.reset()
result = solver.solve(info_dict={})
actions = result["actions"][0]  # Get first env's actions

for i in range(config.receding_horizon):
    action = actions[i].numpy()
    obs, reward, done, truncated, info = env.step(action)
    if done or truncated:
        break

Note

This solver ignores the info_dict parameter in solve() since it doesn’t use state information or world models. The parameter is kept for API consistency with optimization-based solvers.

__init__()[source]

Initialize an unconfigured RandomSolver.

Creates a solver instance that must be configured via configure() before calling solve(). This two-step initialization allows the policy framework to instantiate solvers before environment details are available.

Example

Typical initialization pattern:

solver = RandomSolver()  # Create
solver.configure(...)  # Configure with env specs
result = solver.solve({})  # Use
property action_dim: int

Total action dimensionality including action blocking.

Equals base_action_dim × action_block. For example, if the environment has 3D continuous actions and action_block=5, this returns 15.

Type:

int

configure(*, action_space, n_envs: int, config) None[source]

Configure the solver with environment and planning specifications.

Must be called before solve() to set up the action space dimensions, number of parallel environments, and planning configuration (horizon, action blocking, etc.).

Parameters:
  • action_space – Gymnasium action space defining valid actions. Must have a sample() method and shape attribute. Typically env.action_space or env.single_action_space.

  • n_envs (int) – Number of parallel environments to plan for. Action sequences will be generated for each environment independently. Must be ≥ 1.

  • config

    Planning configuration object (typically swm.PlanConfig) with required attributes:

    • horizon (int): Number of planning steps. Each step corresponds to one action selection point.

    • action_block (int): Number of environment steps per planning step. Actions are repeated this many times.

Raises:
  • AttributeError – If config is missing required attributes.

  • ValueError – If n_envs < 1 or horizon < 1.

Example

Configure for vectorized environment:

import stable_worldmodel as swm

world = swm.World("swm/SimplePointMaze-v0", num_envs=8)
solver = swm.solver.RandomSolver()

config = swm.PlanConfig(horizon=10, receding_horizon=5, action_block=1)
solver.configure(
    action_space=world.envs.single_action_space,
    n_envs=world.num_envs,
    config=config,
)

Note

The solver extracts action_space.shape[1:] as the base action dimensionality, assuming the first dimension is the batch/environment dimension in vectorized action spaces.

property horizon: int

Planning horizon in steps (number of action selections).

Type:

int

property n_envs: int

Number of parallel environments the solver plans for.

Type:

int

solve(info_dict, init_action=None) dict[source]

Generate random action sequences for the planning horizon.

Samples random actions uniformly from the action space to create action sequences for each environment. If partial action sequences are provided via init_action, only the remaining steps are sampled and concatenated.

This method does not use info_dict since random sampling doesn’t require state information, but the parameter is kept for API consistency with optimization-based solvers that do use environment state.

Parameters:
  • info_dict (dict) – Environment state information dictionary. Not used by RandomSolver but required for solver API consistency. Other solvers may use fields like ‘state’, ‘observation’, ‘latent’, etc.

  • init_action (torch.Tensor, optional) – Partial action sequence to warm-start planning. Shape: (n_envs, k, action_dim) where k < horizon. The solver samples actions for the remaining (horizon - k) steps and concatenates them. Useful for receding horizon planning where previous plans are reused. Defaults to None (sample full horizon).

Returns:

Dictionary with a single key:
  • 'actions' (torch.Tensor): Random action sequences with shape (n_envs, horizon, action_dim). Values are sampled uniformly from the action space bounds.

Return type:

dict

Example

Generate full random action sequence:

solver.configure(action_space=env.action_space, n_envs=4, config=config)
result = solver.solve(info_dict={})
actions = result["actions"]  # Shape: (4, horizon, action_dim)

Warm-start with partial sequence (receding horizon planning):

# First planning step: full horizon
result1 = solver.solve({})
actions1 = result1["actions"]  # (4, 10, action_dim)

# Execute first 5 actions, then replan
executed = actions1[:, :5, :]
remaining = actions1[:, 5:, :]  # Use as warm-start

# Second planning step: extend remaining actions
result2 = solver.solve({}, init_action=remaining)
actions2 = result2["actions"]  # (4, 10, action_dim) - new last 5 steps

Note

The sampling uses action_space.sample() which respects the space’s bounds (e.g., Box low/high limits). For continuous spaces, this typically produces uniform distributions. For discrete spaces, it samples uniformly over valid discrete values.

class Solver(*args, **kwargs)[source]

Bases: Protocol

Protocol for model-based planning solvers.

This protocol defines the interface for optimization algorithms that plan action sequences by minimizing a cost function computed by a world model. Solvers receive the current environment state (observations, goals, proprioception) and output optimal action sequences that achieve desired behaviors.

Planning Process:
  1. Receive current state via info_dict (pixels, goal, proprio, etc.)

  2. Initialize or warm-start action sequences

  3. Optimize actions using the world model’s get_cost function

  4. Return optimized action sequences for execution

The protocol supports various optimization methods including:
  • Gradient-based: GDSolver (gradient descent)

  • Sampling-based: CEMSolver (cross-entropy method), MPPISolver

  • Random: RandomSolver (baseline)

Key Concepts:
  • Horizon: Number of timesteps to plan ahead

  • Action Block: Number of actions grouped together due to frame skip.

  • Receding Horizon: Number of actions actually executed before replanning

  • Warm Start: Using the leftover portion of a previous solution to initialize a new optimization

Variables:
  • action_dim (int) – Flattened action dimension including action_block grouping. Formula: base_action_dim * action_block

  • n_envs (int) – Number of parallel environments being optimized simultaneously.

  • horizon (int) – Planning horizon length in timesteps.

Example

Basic usage with a world model:

>>> # Setup world model and planning config
>>> world_model = DINOWM(encoder, predictor, ...)
>>> plan_config = PlanConfig(horizon=10, receding_horizon=5, action_block=2)
>>>
>>> # Create and configure solver
>>> solver = GDSolver(world_model, n_steps=10, device="cuda")
>>> solver.configure(
...     action_space=env.action_space,
...     n_envs=4,
...     config=plan_config
... )
>>>
>>> # Solve for optimal actions given current state
>>> info_dict = {'pixels': pixels, 'goal': goal, 'proprio': proprio}
>>> outputs = solver.solve(info_dict, init_action=None)
>>> actions = outputs["actions"]  # Shape: (4, 10, action_dim)
>>>
>>> # Warm-start next optimization with remaining actions
>>> next_outputs = solver.solve(info_dict, init_action=outputs["actions"][:, 5:])
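
For illustration, here is a minimal sketch of a class that satisfies this interface structurally. Everything in it is a stand-in: `SolverLike` is a local mirror of the documented members rather than the library's Solver class, `ToyPlanConfig` carries only the three documented PlanConfig fields, the action space is faked with an object exposing `.shape`, and `ZeroSolver` returns all-zero plans as nested lists instead of tensors.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Protocol, runtime_checkable


@dataclass
class ToyPlanConfig:
    """Hypothetical stand-in for PlanConfig (documented fields only)."""
    horizon: int
    receding_horizon: int
    action_block: int


@runtime_checkable
class SolverLike(Protocol):
    """Local mirror of the documented interface, for illustration only."""

    @property
    def action_dim(self) -> int: ...

    @property
    def n_envs(self) -> int: ...

    @property
    def horizon(self) -> int: ...

    def configure(self, *, action_space, n_envs: int, config) -> None: ...

    def solve(self, info_dict, init_action=None) -> dict: ...


class ZeroSolver:
    """Toy solver that 'plans' all-zero actions; no world model is used."""

    def configure(self, *, action_space, n_envs, config):
        # Assumption: the last axis of the space's shape is the base action dim.
        self._base_action_dim = action_space.shape[-1]
        self._n_envs = n_envs
        self._config = config

    @property
    def action_dim(self):
        return self._base_action_dim * self._config.action_block

    @property
    def n_envs(self):
        return self._n_envs

    @property
    def horizon(self):
        return self._config.horizon

    def solve(self, info_dict, init_action=None):
        # Nested lists stand in for a (n_envs, horizon, action_dim) tensor.
        actions = [
            [[0.0] * self.action_dim for _ in range(self.horizon)]
            for _ in range(self.n_envs)
        ]
        return {"actions": actions, "cost": [], "trajectory": []}


solver = ZeroSolver()
solver.configure(
    action_space=SimpleNamespace(shape=(4, 2)),  # fake space with a .shape
    n_envs=4,
    config=ToyPlanConfig(horizon=10, receding_horizon=5, action_block=2),
)
outputs = solver.solve({})
print(isinstance(solver, SolverLike), solver.action_dim)
```

Because the check is structural, `ZeroSolver` does not need to inherit from anything; it only has to expose the same properties and methods.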

See also

  • Costable: Protocol defining the world model cost interface

  • PlanConfig: Configuration dataclass for planning parameters

  • GDSolver, CEMSolver, MPPISolver: Concrete solver implementations

property action_dim: int

Flattened action dimension including action_block grouping.

This is the total size of actions per timestep, computed as: base_action_dim * action_block

The action_block groups multiple actions together for frame skipping. For example, if the environment has 2D actions and action_block=5, then action_dim=10 (the 2 action dimensions grouped 5 times).

Returns:

Total flattened action dimension used in optimization.

Return type:

int
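
As a small worked example of the formula above, the sketch below computes action_dim for a 2D action with action_block=5 and splits one flattened action back into per-step environment actions. The memory layout (the block's actions stored back to back) is an assumption made for illustration; the source does not state the ordering.

```python
base_action_dim = 2  # e.g. a 2D action
action_block = 5     # environment steps grouped into one planned action
action_dim = base_action_dim * action_block

# One flattened planned action, filled with placeholder values 0.0 .. 9.0.
flat_action = [float(i) for i in range(action_dim)]

# Assumed layout: consecutive chunks of base_action_dim values,
# one chunk per environment step in the block.
env_actions = [
    flat_action[i * base_action_dim:(i + 1) * base_action_dim]
    for i in range(action_block)
]

print(action_dim)
print(env_actions)
```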


configure(*, action_space: Space, n_envs: int, config) None[source]

Configure the solver with environment and planning specifications.

This method initializes the solver’s internal state based on the environment’s action space and planning configuration. Must be called once after solver creation and before any solve() calls.

Parameters:
  • action_space (gym.Space) – Environment’s action space. For continuous control, this should be a Box space. The shape is typically (n_envs, action_dim) for vectorized environments.

  • n_envs (int) – Number of parallel environments to optimize for. The solver will produce n_envs independent action sequences.

  • config (PlanConfig) – Planning configuration containing:
      - horizon: Number of future timesteps to plan
      - receding_horizon: Number of planned actions to execute
      - action_block: Number of actions grouped together due to frame skip

Note

This method should only be called once during initialization. The solver caches the configuration internally for use in solve().

Raises:

Warning – If action_space is not a Box (some solvers only support continuous actions).

property horizon: int

Planning horizon length in timesteps.

This is the number of future timesteps the solver plans ahead. Note that this may differ from receding_horizon (the number of actions actually executed before replanning).

Returns:

Number of timesteps in the planning horizon.

Return type:

int


property n_envs: int

Number of parallel environments being planned for.

Returns:

Number of independent action sequences the solver optimizes.

Return type:

int


solve(info_dict, init_action=None) dict[source]

Solve the planning optimization problem to find optimal actions.

This is the main method that performs trajectory optimization. It uses the world model to evaluate action sequences and finds actions that minimize the cost function. The optimization strategy is solver-specific (gradient descent, sampling, etc.).

Typical workflow:
  1. Initialize action sequences (from init_action or zeros)

  2. Iteratively evaluate cost and update actions

  3. Return optimized actions and optimization statistics

Parameters:
  • info_dict (dict) – Current environment state containing:
      - 'pixels' (np.ndarray): Current observation images, shape (n_envs, H, W, 3)
      - 'goal' (np.ndarray): Goal observation images, shape (n_envs, H, W, 3)
      - 'proprio' (np.ndarray, optional): Proprioceptive state, shape (n_envs, proprio_dim)
      - 'action' (np.ndarray, optional): Previous actions for history
      - Additional task-specific keys as needed

  • init_action (torch.Tensor, optional) – Warm-start action sequences with shape (n_envs, init_horizon, action_dim). Common use cases:
      - None: Initialize all actions to zero (cold start)
      - Partial sequence: Pad remaining horizon with zeros
      - Previous solution shifted: Warm-start from last optimization

Returns:

Optimization results containing:
  • 'actions' (torch.Tensor): Optimized action sequences with shape (n_envs, horizon, action_dim). These are the planned actions.

  • 'cost' (list[float]): Cost values during optimization. Format and length depend on the solver implementation.

  • 'trajectory' (list[torch.Tensor]): Intermediate action sequences during optimization (solver-dependent).

  • Additional solver-specific keys (e.g., 'elite_actions' for CEM)

Return type:

dict

Note

The returned actions are typically in the solver’s internal representation and may require denormalization or reshaping before execution in the environment. The WorldModelPolicy handles this transformation.

Example

Cold start (zero initialization):

>>> outputs = solver.solve(info_dict)

Warm start with previous solution:

>>> outputs1 = solver.solve(info_dict)
>>> # Execute first 5 actions, keep rest for warm start
>>> outputs2 = solver.solve(new_info_dict, init_action=outputs1["actions"][:, 5:])
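
The partial-sequence case described for init_action (pad the remaining horizon with zeros) can be pictured as follows. This is only a sketch: plain nested lists stand in for tensors, `pad_warm_start` is a hypothetical helper rather than a library function, and concrete solvers may initialize the missing steps differently.

```python
def pad_warm_start(init_action, horizon, action_dim):
    """Pad partial plans of shape (n_envs, T, action_dim) with zero actions.

    Hypothetical helper for illustration; requires T <= horizon.
    """
    padded = []
    for env_plan in init_action:
        if len(env_plan) > horizon:
            raise ValueError("init_action is longer than the planning horizon")
        missing = horizon - len(env_plan)
        # Keep the provided steps, then append zero actions for the rest.
        padded.append(
            [list(step) for step in env_plan]
            + [[0.0] * action_dim for _ in range(missing)]
        )
    return padded


horizon, receding_horizon, action_dim = 10, 5, 4

# Fake previous solution for 2 environments: step t is filled with the value t.
previous = [[[float(t)] * action_dim for t in range(horizon)] for _ in range(2)]

# Drop the executed steps, then pad the leftover back up to the full horizon.
leftover = [plan[receding_horizon:] for plan in previous]
warm = pad_warm_start(leftover, horizon, action_dim)

print(len(warm), len(warm[0]), warm[0][0], warm[0][5])
```

The first five steps of the padded plan are the unexecuted tail of the previous solution, and the last five are zeros for the optimizer to fill in.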