from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol
import numpy as np
import torch
import stable_worldmodel as swm
from stable_worldmodel.solver import Solver
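

# NOTE: reconstructed definition (an assumption): the source references
# ``Transformable`` below without defining or importing it here. This Protocol
# captures the interface actually used, i.e. scikit-learn style processors
# with ``transform`` / ``inverse_transform``.
class Transformable(Protocol):
    def transform(self, x): ...

    def inverse_transform(self, x): ...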
@dataclass(frozen=True)
class PlanConfig:
    """Configuration for the planning process."""

    horizon: int
    receding_horizon: int
    history_len: int = 1
    action_block: int = 1  # frameskip: env steps executed per planned action
    warm_start: bool = True  # seed the next plan with the unused tail of the previous one

    @property
    def plan_len(self):
        """Total number of environment steps covered by one plan."""
        return self.horizon * self.action_block
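

# Illustrative example (values are arbitrary): with a frameskip of 2, a plan of
# 8 macro-actions unrolls to plan_len == 16 environment steps, and only the
# first receding_horizon block is executed before replanning.
#
#   cfg = PlanConfig(horizon=8, receding_horizon=2, action_block=2)
#   assert cfg.plan_len == 16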
class BasePolicy:
    """Base class for agent policies.

    A policy is bound to an environment via ``set_env``; concrete policies may
    additionally wrap a planner (see ``WorldModelPolicy``).
    """

    def __init__(self, **kwargs):
        self.env = None
        self.type = "base"
        for arg, value in kwargs.items():
            setattr(self, arg, value)
    def get_action(self, obs, **kwargs):
        """Return an action for the given observation."""
        raise NotImplementedError
    def set_env(self, env):
        """Bind the policy to an environment."""
        self.env = env
class RandomPolicy(BasePolicy):
    """Policy that samples actions uniformly from the environment's action space."""

    def __init__(self, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.type = "random"
        self.seed = seed
    def get_action(self, obs, **kwargs):
        return self.env.action_space.sample()
    def set_seed(self, seed):
        self.seed = seed
        if self.env is not None:
            self.env.action_space.seed(seed)
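

# Minimal usage sketch (assumes a gymnasium-style environment; the env name is
# illustrative):
#
#   import gymnasium as gym
#
#   env = gym.make("CartPole-v1")
#   policy = RandomPolicy()
#   policy.set_env(env)
#   policy.set_seed(0)
#   obs, _ = env.reset(seed=0)
#   action = policy.get_action(obs)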
class ExpertPolicy(BasePolicy):
    """Expert policy; subclass and implement the task-specific logic."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.type = "expert"
    def get_action(self, obs, goal_obs, **kwargs):
        # Subclasses implement the task-specific expert logic here.
        raise NotImplementedError
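

# Hypothetical subclass (illustrative only, not part of the library): a
# proportional controller that steps toward the goal state, clipped to a Box
# action space.
#
#   class PointMassExpert(ExpertPolicy):
#       def get_action(self, obs, goal_obs, **kwargs):
#           delta = goal_obs - obs
#           space = self.env.action_space
#           return np.clip(delta, space.low, space.high)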
class WorldModelPolicy(BasePolicy):
    """World model policy that plans with a solver and acts in a receding horizon."""

    def __init__(
        self,
        solver: Solver,
        config: PlanConfig,
        process: dict[str, Transformable] | None = None,
        transform: dict[str, callable] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.type = "world_model"
        self.cfg = config
        self.solver = solver
        self.process = process or {}
        self.transform = transform or {}
        # Per-timestep actions of the current plan; refilled on each replan.
        self._action_buffer = deque(maxlen=self.flatten_receding_horizon)
        # Unused tail of the previous plan, used to warm-start the next solve.
        self._next_init = None

    @property
    def flatten_receding_horizon(self):
        """Receding horizon expressed in environment steps (after frameskip)."""
        return self.cfg.receding_horizon * self.cfg.action_block
    def set_env(self, env):
        assert isinstance(self.solver, Solver), "Solver must implement the Solver protocol"
        self.env = env
        n_envs = getattr(env, "num_envs", 1)
        self.solver.configure(action_space=env.action_space, n_envs=n_envs, config=self.cfg)
        self._action_buffer = deque(maxlen=self.flatten_receding_horizon)
    def get_action(self, info_dict, **kwargs):
        assert self.env is not None, "Environment not set for the policy"
        assert "pixels" in info_dict, "'pixels' must be provided in info_dict"
        assert "goal" in info_dict, "'goal' must be provided in info_dict"

        # Pre-process and transform observations.
        for k, v in info_dict.items():
            v = self.process[k].transform(v) if k in self.process else v
            v = torch.stack([self.transform[k](x) for x in v]) if k in self.transform else v
            info_dict[k] = torch.from_numpy(v) if isinstance(v, (np.ndarray, np.generic)) else v

        # Replan only when the action buffer is exhausted.
        if len(self._action_buffer) == 0:
            outputs = self.solver(info_dict, init_action=self._next_init)
            actions = outputs["actions"]  # (num_envs, horizon, action_dim)
            keep_horizon = self.cfg.receding_horizon
            plan = actions[:, :keep_horizon]
            rest = actions[:, keep_horizon:]
            self._next_init = rest if self.cfg.warm_start else None
            # Unroll frameskip: reshape macro-actions back to per-timestep actions.
            plan = plan.reshape(self.env.num_envs, self.flatten_receding_horizon, -1)
            self._action_buffer.extend(plan.transpose(0, 1))

        action = self._action_buffer.popleft()
        action = action.reshape(*self.env.action_space.shape)
        action = action.numpy()

        # Post-process the action back to the environment's native scale.
        if "action" in self.process:
            action = self.process["action"].inverse_transform(action)
        return action  # (num_envs, action_dim)
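

# Usage sketch (illustrative; the solver, env, and step budget are placeholders,
# and ``info`` is assumed to carry "pixels" and "goal" entries):
#
#   cfg = PlanConfig(horizon=8, receding_horizon=2, action_block=2)
#   policy = WorldModelPolicy(solver=my_solver, config=cfg)
#   policy.set_env(env)
#   obs, info = env.reset()
#   for _ in range(max_steps):
#       action = policy.get_action(info)  # replans once the buffer empties
#       obs, reward, terminated, truncated, info = env.step(action)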
def AutoCostModel(model_name, cache_dir=None):
    """Load a pretrained world model checkpoint and return its cost model."""
    cache_dir = Path(cache_dir or swm.data.get_cache_dir())
    path = cache_dir / f"{model_name}_object.ckpt"
    assert path.exists(), f"World model named {model_name} not found. Run pretraining first."
    spt_module = torch.load(path, weights_only=False)

    def scan_module(module):
        # Depth-first search for the first submodule that exposes ``get_cost``.
        if hasattr(module, "get_cost"):
            return module
        for child in module.children():
            result = scan_module(child)
            if result is not None:
                return result
        return None

    result = scan_module(spt_module)
    if result is not None:
        return result
    raise RuntimeError("No cost model found in the loaded world model.")
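

# Usage sketch (the model name is a placeholder; the checkpoint must exist in
# the cache from a prior pretraining run):
#
#   cost_model = AutoCostModel("my_world_model")
#   # ``get_cost``'s signature depends on the model that was pretrained.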