import time
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
[docs]
class CEMSolver:
    """Cross Entropy Method Solver.

    adapted from https://github.com/gaoyuezhou/dino_wm/blob/main/planning/cem.py

    The solver keeps a per-environment Gaussian over action sequences,
    repeatedly samples candidates from it, scores them with
    ``model.get_cost`` and refits the Gaussian to the ``topk`` cheapest
    candidates.

    NOTE(review): the ``n_envs``, ``action_dim`` and ``horizon`` properties
    read ``self._n_envs``, ``self._action_dim`` and ``self._config``, none of
    which is assigned in this class as shown here. They must be set on the
    instance before ``init_action_distrib`` or ``solve`` is called --
    presumably by a configuration step defined elsewhere; confirm.
    """

    def __init__(
        self,
        model: "Costable",
        batch_size: int = 1,
        num_samples: int = 300,
        var_scale: float = 1,
        n_steps: int = 30,
        topk: int = 30,
        device="cpu",
        seed: int = 1234,
    ):
        """Store the solver hyper-parameters and seed the sampling generator.

        Args:
            model: object exposing ``get_cost(info_dict, candidates)``.
            batch_size: number of environments optimised together.
            num_samples: candidate action sequences drawn per environment
                at every iteration.
            var_scale: initial scale of the sampling distribution.
            n_steps: number of sample / refit iterations.
            topk: number of lowest-cost candidates used to refit the
                distribution.
            device: device on which candidates are sampled.
            seed: seed of the ``torch.Generator`` used for sampling.
        """
        self.model = model
        self.batch_size = batch_size
        self.var_scale = var_scale
        self.num_samples = num_samples
        self.n_steps = n_steps
        self.topk = topk
        self.device = device
        self.torch_gen = torch.Generator(device=device).manual_seed(seed)

    @property
    def n_envs(self) -> int:
        """Number of environments planned for (reads ``self._n_envs``)."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Flattened action size: ``_action_dim * _config.action_block``."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon (reads ``self._config.horizon``)."""
        return self._config.horizon

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        """Forward every argument to :meth:`solve`."""
        return self.solve(*args, **kwargs)

    def init_action_distrib(self, actions=None):
        """Initialize the action distribution params (mu, sigma) given the initial condition.

        Args:
            actions (n_envs, T, action_dim): initial actions, T <= horizon.
                ``None`` means no warm start (all-zero mean).

        Returns:
            ``(mean, var)``, both of shape ``(n_envs, horizon, action_dim)``.
            ``var`` is filled with ``var_scale``. When ``actions`` already
            covers the full horizon, ``mean`` is the ``actions`` tensor
            itself (no copy is made here).
        """
        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
        mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions

        # -- pad the missing trailing steps of the plan with zeros
        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            # Allocate the padding directly on the device of the provided
            # actions: concatenating tensors living on different devices
            # raises an error.
            padding = torch.zeros(
                [self.n_envs, remaining, self.action_dim], device=mean.device
            )
            mean = torch.cat([mean, padding], dim=1)
        return mean, var

    def _expand_infos(self, info_dict, start_idx, end_idx):
        """Slice ``info_dict`` to ``[start_idx:end_idx]`` and add a sample axis.

        Tensors and numpy arrays get a new axis 1 of length ``num_samples``
        (a broadcast view for tensors, a materialised repeat for arrays).
        Values of any other type are only sliced.
        """
        current_bs = end_idx - start_idx
        expanded = {}
        for key, value in info_dict.items():
            chunk = value[start_idx:end_idx]
            if torch.is_tensor(value):
                # (batch, ...) -> (batch, 1, ...) -> (batch, num_samples, ...)
                chunk = chunk.unsqueeze(1)
                chunk = chunk.expand(current_bs, self.num_samples, *chunk.shape[2:])
            elif isinstance(value, np.ndarray):
                chunk = np.repeat(chunk[:, None, ...], self.num_samples, axis=1)
            expanded[key] = chunk
        return expanded

    @torch.inference_mode()
    def solve(self, info_dict, init_action=None) -> dict:
        """Run the cross-entropy optimisation for every environment.

        Environments are processed in chunks of ``batch_size``. For each
        chunk, ``n_steps`` iterations of sample -> evaluate -> refit are run.

        Args:
            info_dict: mapping whose values are indexed along their first
                axis by environment; every value is sliced per chunk and
                tensors / numpy arrays are expanded with a sample axis
                before being handed to ``model.get_cost``.
            init_action: optional warm start, see
                :meth:`init_action_distrib`. It is never modified.

        Returns:
            dict with keys ``"costs"`` (per environment, mean cost of the
            elites at the last iteration), ``"mean"`` and ``"var"``
            (single-element lists holding the final distribution
            parameters on CPU) and ``"actions"`` (the final mean on CPU).
        """
        start_time = time.perf_counter()

        # -- initialize the action distribution globally
        mean, var = self.init_action_distrib(init_action)
        # ``mean`` can be the caller's ``init_action`` tensor itself; it is
        # written in place below, so work on a private copy.
        mean = mean.to(self.device).clone()
        var = var.to(self.device)

        total_envs = self.n_envs
        all_costs = []

        # --- Iterate over batches ---
        for start_idx in range(0, total_envs, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_envs)
            current_bs = end_idx - start_idx

            # Distribution parameters of the current batch. ``batch_var`` is
            # used as the scale (standard deviation) of the sampling noise.
            batch_mean = mean[start_idx:end_idx]
            batch_var = var[start_idx:end_idx]

            expanded_infos = self._expand_infos(info_dict, start_idx, end_idx)

            # Row index of every elite, used to gather the top-k candidates.
            batch_indices = (
                torch.arange(current_bs, device=self.device).unsqueeze(1).expand(-1, self.topk)
            )

            final_batch_cost = None
            for _ in range(self.n_steps):
                # Sample action sequences: (Batch, Num_Samples, Horizon, Dim)
                candidates = torch.randn(
                    current_bs,
                    self.num_samples,
                    self.horizon,
                    self.action_dim,
                    generator=self.torch_gen,
                    device=self.device,
                )
                # Scale and shift: (Batch, N, H, D) * (Batch, 1, H, D) + (Batch, 1, H, D)
                candidates = candidates * batch_var.unsqueeze(1) + batch_mean.unsqueeze(1)
                # Force the first sample to be the current mean
                candidates[:, 0] = batch_mean

                # Shallow copy so the model cannot add/remove keys of the
                # dict reused across iterations.
                current_info = expanded_infos.copy()

                # Evaluate candidates
                costs = self.model.get_cost(current_info, candidates)
                assert isinstance(costs, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(costs)}"
                assert costs.ndim == 2 and costs.shape[0] == current_bs and costs.shape[1] == self.num_samples, (
                    f"Expected cost to be of shape ({current_bs}, {self.num_samples}), got {costs.shape}"
                )

                # Select the k cheapest candidates; both outputs are (Batch, K)
                topk_vals, topk_inds = torch.topk(costs, k=self.topk, dim=1, largest=False)
                # candidates[batch_idx, sample_idx] -> (Batch, K, Horizon, Dim)
                topk_candidates = candidates[batch_indices, topk_inds]

                # Refit the distribution on the elites
                batch_mean = topk_candidates.mean(dim=1)
                batch_var = topk_candidates.std(dim=1)

                # Cost reported for logging: average over the elites
                final_batch_cost = topk_vals.mean(dim=1).cpu().tolist()

            # Write results back to global storage
            mean[start_idx:end_idx] = batch_mean
            var[start_idx:end_idx] = batch_var

            # No cost is available when n_steps == 0
            if final_batch_cost is not None:
                all_costs.extend(final_batch_cost)

        outputs = {
            "costs": all_costs,
            "mean": [mean.detach().cpu()],
            "var": [var.detach().cpu()],
            "actions": mean.detach().cpu(),
        }

        print(f"CEM solve time: {time.perf_counter() - start_time:.4f} seconds")
        return outputs