import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
class MPPISolver:
    """Model Predictive Path Integral (MPPI) solver.

    Proposed in https://arxiv.org/abs/1509.01149
    Algorithm reference: https://acdslab.github.io/mppi-generic-website/docs/mppi.html

    Note:
        The original MPPI computes the cost as a summation of costs along the
        trajectory. Here, we use the final cost only, which should be updated
        in future updates.
    """

    def __init__(
        self,
        model: "Costable",
        num_samples: int,
        num_elites: int,
        var_scale: float,
        n_steps: int,
        use_elites: bool = True,
        temperature: float = 0.5,
        device: str = "cpu",
    ):
        """Store the solver configuration.

        Args:
            model: Cost model exposing ``get_cost(info_dict, actions)``.
            num_samples: Number of noisy action sequences sampled per step.
            num_elites: Number of lowest-cost samples kept when ``use_elites``.
            var_scale: Scale of the (fixed) sampling variance.
            n_steps: Number of optimization iterations.
            use_elites: If True, restrict the weighted update to the elite set.
            temperature: Softmin temperature for trajectory weighting.
            device: Torch device used for sampling and cost evaluation.
        """
        self.model = model
        self.var_scale = var_scale
        self.use_elites = use_elites
        self.num_samples = num_samples
        self.num_elites = num_elites
        self.temperature = temperature
        self.n_steps = n_steps
        self.device = device

    # NOTE(review): `_n_envs`, `_action_dim`, and `_config` are never set in
    # __init__ — they must be provided by a subclass or assigned externally
    # before calling `solve`. TODO: confirm against the callers of this class.
    @property
    def n_envs(self) -> int:
        return self._n_envs

    @property
    def action_dim(self) -> int:
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        return self._config.horizon

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        return self.solve(*args, **kwargs)

    def init_action_distrib(self, actions=None):
        """Initialize the action distribution params (mean, var).

        Args:
            actions (n_envs, T, action_dim): optional warm-start actions,
                with T <= horizon.

        Returns:
            Tuple ``(mean, var)``, each of shape (n_envs, horizon, action_dim).
        """
        var = self.var_scale * torch.ones([self.n_envs, self.horizon, self.action_dim])
        mean = torch.zeros([self.n_envs, 0, self.action_dim]) if actions is None else actions
        # -- pad the remaining steps with zeros.
        # BUG FIX: the padding must be created on `mean`'s device *before* the
        # cat — the original built it on CPU and `torch.cat` raises on mixed
        # devices whenever `actions` lives on GPU (the `.to(device)` afterwards
        # could never run).
        remaining = self.horizon - mean.shape[1]
        if remaining > 0:
            pad = torch.zeros([self.n_envs, remaining, self.action_dim], device=mean.device)
            mean = torch.cat([mean, pad], dim=1)
        return mean, var

    def compute_trajectory_weights(self, costs: torch.Tensor) -> torch.Tensor:
        """Compute trajectory weights from costs using softmin with temperature.

        Args:
            costs (num_samples,): Tensor of trajectory costs.

        Returns:
            Tensor of trajectory weights (same shape as ``costs``, sums to 1).
        """
        # Shift by the min for numerical stability (softmax is shift-invariant).
        stable_costs = costs - costs.min()
        return torch.softmax(-stable_costs / self.temperature, dim=0)

    @torch.inference_mode()
    def solve(self, info_dict, init_action=None):
        """Run the MPPI optimization loop.

        Args:
            info_dict: Mapping of per-env context (torch tensors or numpy
                arrays indexed by env along dim 0) forwarded to the cost model.
            init_action (n_envs, T, action_dim): optional warm-start actions.

        Returns:
            dict with keys:
                "actions": final mean sequences, (n_envs, horizon, action_dim) on CPU;
                "costs": per-step mean (over envs) of the logged cost;
                "nominal_action": per-step CPU snapshots of the mean sequences;
                "mean", "var": kept for backward compatibility (never filled).
        """
        outputs = {
            "costs": [],
            "mean": [],
            "var": [],
            # BUG FIX: this key was appended to below but never initialized,
            # which raised KeyError on the first optimization step.
            "nominal_action": [],
        }
        # -- initialize the action distribution
        mean, var = self.init_action_distrib(init_action)
        mean = mean.to(self.device)
        var = var.to(self.device)
        n_envs = mean.shape[0]
        # -- optimization loop
        for step in range(self.n_steps):
            costs = []
            # TODO: could flatten the batch dimension and process all samples together
            # rem: need many memory and split before computing top k
            for traj in range(n_envs):
                # broadcast this env's context to all samples
                expanded_infos = {}
                for k, v in info_dict.items():
                    v_traj = v[traj]
                    if torch.is_tensor(v):
                        v_traj = v_traj.unsqueeze(0).repeat_interleave(self.num_samples, dim=0)
                    elif isinstance(v, np.ndarray):
                        v_traj = np.repeat(v_traj[None, ...], self.num_samples, axis=0)
                    expanded_infos[k] = v_traj
                # sample noisy actions sequences
                candidates = torch.randn(self.num_samples, self.horizon, self.action_dim, device=self.device)
                candidates = candidates * var[traj] + mean[traj]
                # make the first action seq being mean (noise-free candidate)
                candidates[0] = mean[traj]
                # get the costs
                cost = self.model.get_cost(expanded_infos, candidates)
                assert isinstance(cost, torch.Tensor), f"Expected cost to be a torch.Tensor, got {type(cost)}"
                assert cost.ndim == 1 and len(cost) == self.num_samples, (
                    f"Expected cost to be of shape num_samples ({self.num_samples},), got {cost.shape}"
                )
                if self.use_elites:
                    elite_idx = torch.argsort(cost)[: self.num_elites]
                    candidates = candidates[elite_idx]
                    cost = cost[elite_idx]
                # NOTE(review): this logs the *worst* retained cost; min or
                # mean may be the intended statistic — confirm.
                costs.append(cost.max().item())
                # update nominal sequence with weighted candidates.
                # BUG FIX: weights is (K,) and candidates is (K, H, A); the
                # original `weights * candidates` fails to broadcast (trailing
                # dims K vs A). Expand weights to (K, 1, 1) for the weighted
                # average over samples.
                weights = self.compute_trajectory_weights(cost)
                mean[traj] = (weights[:, None, None] * candidates).sum(dim=0)
            outputs["costs"].append(np.mean(costs))
            outputs["nominal_action"].append(mean.detach().cpu().clone())
            print(f"MPPI step {step + 1}/{self.n_steps}, cost: {outputs['costs'][-1]:.4f}")
        outputs["actions"] = mean.detach().cpu()
        return outputs