Source code for stable_worldmodel.solver.cem

import time

import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging

from .solver import Costable


[docs] class CEMSolver: """Cross Entropy Method Solver. adapted from https://github.com/gaoyuezhou/dino_wm/blob/main/planning/cem.py """ def __init__( self, model: Costable, batch_size: int = 1, num_samples: int = 300, var_scale: float = 1, n_steps: int = 30, topk: int = 30, device="cpu", seed: int = 1234, ): self.model = model self.batch_size = batch_size self.var_scale = var_scale self.num_samples = num_samples self.n_steps = n_steps self.topk = topk self.device = device self.torch_gen = torch.Generator(device=device).manual_seed(seed)
[docs] def configure(self, *, action_space, n_envs: int, config) -> None: self._action_space = action_space self._n_envs = n_envs self._config = config self._action_dim = int(np.prod(action_space.shape[1:])) self._configured = True # warning if action space is discrete if not isinstance(action_space, Box): logging.warning(f"Action space is discrete, got {type(action_space)}. GDSolver may not work as expected.")
@property def n_envs(self) -> int: return self._n_envs @property def action_dim(self) -> int: return self._action_dim * self._config.action_block @property def horizon(self) -> int: return self._config.horizon def __call__(self, *args, **kwargs) -> torch.Tensor: return self.solve(*args, **kwargs)
[docs] def init_action_distrib(self, actions=None): """Initialize the action distribution params (mu, sigma) given the initial condition. Args: actions (n_envs, T, action_dim): initial actions, T <= horizon """ var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim]) mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions # -- fill remaining actions with random sample remaining = self.horizon - mean.shape[1] if remaining > 0: device = mean.device new_mean = torch.zeros([self.n_envs, remaining, self.action_dim]) mean = torch.cat([mean, new_mean], dim=1).to(device) return mean, var
[docs] @torch.inference_mode() def solve(self, info_dict, init_action=None) -> dict: start_time = time.time() outputs = { "costs": [], "mean": [], # History of means "var": [], # History of vars } # -- initialize the action distribution globally mean, var = self.init_action_distrib(init_action) mean = mean.to(self.device) var = var.to(self.device) total_envs = self.n_envs # --- Iterate over batches --- for start_idx in range(0, total_envs, self.batch_size): end_idx = min(start_idx + self.batch_size, total_envs) current_bs = end_idx - start_idx # Slice Distribution Parameters for current batch batch_mean = mean[start_idx:end_idx] batch_var = var[start_idx:end_idx] # Expand Info Dict for current batch expanded_infos = {} for k, v in info_dict.items(): # v is shape (n_envs, ...) # Slice batch v_batch = v[start_idx:end_idx] if torch.is_tensor(v): # Add sample dim: (batch, 1, ...) v_batch = v_batch.unsqueeze(1) # Expand: (batch, num_samples, ...) v_batch = v_batch.expand(current_bs, self.num_samples, *v_batch.shape[2:]) elif isinstance(v, np.ndarray): v_batch = np.repeat(v_batch[:, None, ...], self.num_samples, axis=1) expanded_infos[k] = v_batch # Optimization Loop final_batch_cost = None for step in range(self.n_steps): # Sample action sequences: (Batch, Num_Samples, Horizon, Dim) candidates = torch.randn( current_bs, self.num_samples, self.horizon, self.action_dim, generator=self.torch_gen, device=self.device, ) # Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D) candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1) # Force the first sample to be the current mean candidates[:, 0] = batch_mean current_info = expanded_infos.copy() # Evaluate candidates costs = self.model.get_cost(current_info, candidates) assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}" assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, ( f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}" ) # Select Top-K # topk_vals: (Batch, K), topk_inds: (Batch, K) topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False) # Gather Top-K Candidates # We need to select the specific candidates corresponding to topk_inds batch_indices = torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk) # Indexing: candidates[batch_idx, sample_idx] # Result shape: (Batch, K, Horizon, Dim) topk_candidates = candidates[batch_indices, topk_inds] # Update Mean and Variance based on Top-K batch_mean = topk_candidates.mean(dim=1) batch_var = topk_candidates.std(dim=1) # Update final cost for logging # We average the cost of the top elites final_batch_cost = topk_vals.mean(dim=1).cpu().tolist() # Write results back to global storage mean[start_idx:end_idx] = batch_mean var[start_idx:end_idx] = batch_var # Store history/metadata outputs["costs"].extend(final_batch_cost) outputs["actions"] = mean.detach().cpu() outputs["mean"] = [mean.detach().cpu()] outputs["var"] = [var.detach().cpu()] print(f"CEM solve time: {time.time() - start_time:.4f} seconds") return outputs