# Source code for stable_worldmodel.solver.gd
import numpy as np
import torch
from gymnasium.spaces import Box
from loguru import logger as logging
from .solver import Costable
# [docs]
class GDSolver(torch.nn.Module):
    """Gradient-descent solver that optimises an action sequence against a cost model.

    The action sequence is stored as the learnable parameter ``self.init`` and
    is updated with plain SGD (learning rate 1.0) on the cost returned by
    ``model.get_cost``.

    The solver reads ``self._n_envs``, ``self._action_dim`` and ``self._config``
    (which must expose ``horizon`` and ``action_block``). They start as ``None``
    and must be populated before :meth:`init_action` or :meth:`solve` is called.
    NOTE(review): nothing in this class sets them -- confirm where the solver
    is configured.
    """

    def __init__(
        self,
        model: "Costable",
        n_steps: int,
        action_noise=0.0,
        device="cpu",
    ):
        """Create the solver.

        Args:
            model: Object exposing ``eval()`` and ``get_cost(info_dict, actions)``.
            n_steps: Number of gradient-descent steps performed per ``solve`` call.
            action_noise: Standard deviation of the Gaussian noise added to the
                actions after every step; ``0`` disables the noise.
            device: Device on which the action parameter is stored.
        """
        super().__init__()
        self.model = model
        self.n_steps = n_steps
        self.action_noise = action_noise
        self.device = device

        # Configuration state; filled in externally before planning.
        self._configured = False
        self._n_envs = None
        self._action_dim = None
        self._config = None

    def set_seed(self, seed: int) -> None:
        """Set random seed for deterministic behavior.

        Args:
            seed: Random seed to use for numpy and torch
        """
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    @property
    def n_envs(self) -> int:
        """Number of environments planned for in parallel."""
        return self._n_envs

    @property
    def action_dim(self) -> int:
        """Size of one planned action: base action dim times ``config.action_block``."""
        return self._action_dim * self._config.action_block

    @property
    def horizon(self) -> int:
        """Planning horizon taken from the configuration."""
        return self._config.horizon

    def __call__(self, *args, **kwargs) -> dict:
        """Alias for :meth:`solve`."""
        return self.solve(*args, **kwargs)

    def init_action(self, actions=None):
        """Initialize the action tensor for the solver.

        Sets ``self.init``, the initial action sequences of shape
        ``(n_envs, horizon, action_dim)``.

        Args:
            actions: Optional tensor of shape ``(n_envs, k, action_dim)`` used as
                a warm start. When ``k < horizon`` the missing steps are filled
                with zeros. ``None`` starts from all zeros. The tensor is
                copied, never aliased, so the caller's data is left untouched.
        """
        # The parameter is (re)initialised in place, which must not be tracked
        # by autograd; guarding here makes the method safe to call directly.
        with torch.no_grad():
            if actions is None:
                actions = torch.zeros((self._n_envs, 0, self.action_dim))

            # fill remaining action
            remaining = self.horizon - actions.shape[1]
            if remaining > 0:
                # Allocate the padding on the same device as the provided
                # actions so the concatenation cannot mix devices.
                padding = torch.zeros(
                    self._n_envs,
                    remaining,
                    self.action_dim,
                    device=actions.device,
                )
                actions = torch.cat([actions, padding], dim=1)

            actions = actions.detach().to(self.device)

            current = getattr(self, "init", None)
            reusable = (
                current is not None
                and current.shape == actions.shape
                and current.device == actions.device
            )
            if reusable:
                # reset actions
                current.copy_(actions)
            else:
                # First call, or the shape/device changed since the last call:
                # (re-)register a fresh parameter. Clone so optimizer updates
                # never write into the caller's tensor.
                self.register_parameter("init", torch.nn.Parameter(actions.clone()))

    def solve(self, info_dict, init_action=None) -> dict:
        """Solve the planning optimization problem using gradient descent.

        Args:
            info_dict: Mapping passed to ``model.get_cost``; a shallow copy is
                handed over at every step.
            init_action: Optional warm-start actions, see :meth:`init_action`.

        Returns:
            A dict with the keys

            * ``"cost"``: list with the summed cost evaluated at each step
              (before that step's update),
            * ``"trajectory"``: list with a CPU copy of the actions after each step,
            * ``"actions"``: CPU copy of the final actions.

        Raises:
            AssertionError: If the cost is not a 1-D tensor of length ``n_envs``
                that requires grad.
        """
        outputs = {
            "cost": [],
            "trajectory": [],
        }

        # Set model to eval mode to ensure deterministic behavior
        self.model.eval()

        self.init_action(init_action)
        optim = torch.optim.SGD([self.init], lr=1.0)

        # perform gradient descent
        for step in range(self.n_steps):
            # copy info dict to avoid in-place modification
            cost = self.model.get_cost(dict(info_dict), self.init)

            assert isinstance(cost, torch.Tensor), f"Got {type(cost)} cost, expect torch.Tensor"
            assert cost.ndim == 1 and len(cost) == self.n_envs, (
                f"Cost should be of shape (n_envs,), got {cost.shape}"
            )
            assert cost.requires_grad, "Cost must requires_grad for GD solver."

            total_cost = cost.sum()  # independent cost for each env
            total_cost.backward()
            optim.step()
            optim.zero_grad(set_to_none=True)

            if self.action_noise > 0:
                with torch.no_grad():
                    self.init.add_(torch.randn_like(self.init) * self.action_noise)

            outputs["cost"].append(total_cost.item())
            outputs["trajectory"].append(self.init.detach().cpu().clone())

            # Pre-formatted message: works with the module's logger without
            # relying on a particular lazy-formatting style.
            logging.info(f"GD step {step + 1}/{self.n_steps}, cost: {outputs['cost'][-1]:.4f}")

        # TODO break solving if finished self.eval? done break

        # get the actions to return; clone so a later solve() call, which
        # overwrites self.init in place, cannot change an earlier result
        outputs["actions"] = self.init.detach().cpu().clone()
        return outputs