Source code for stable_worldmodel.envs.voidrun

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from gymnasium import spaces

import stable_worldmodel as swm


DEFAULT_VARIATIONS = ("board.prob_gravel", "agent.position", "goal.position")


@dataclass(frozen=True)
class Action:
    """Integer codes for the four discrete moves of the grid agent.

    The fields are read as class-level constants (``Action.LEFT`` etc.);
    their defaults are the action ids themselves.
    """

    # Horizontal moves.
    LEFT: int = 0
    RIGHT: int = 1
    # Vertical moves.
    DOWN: int = 2
    UP: int = 3
class VoidRunEnv(gym.Env):
    """Discrete grid environment with a 1x1 agent cell.

    Board cells hold small integer codes: ``0`` void, ``1`` sand, ``2`` gravel
    and ``3`` goal marker (see ``generate_board`` and the colour table in
    ``render_board``). The agent moves one cell per step and attempts to void
    the cell it started the step on; each newly voided cell yields a reward
    of 1.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(
        self,
        seed: int | None = None,
        render_mode: str = "human",
    ) -> None:
        """Build the action, observation and variation spaces.

        Args:
            seed: Seed for the environment's own NumPy generator.
            render_mode: Default mode used by ``render`` when none is passed.
        """
        super().__init__()
        self.render_mode = render_mode
        self.max_size = 50
        self.step_size = 1  # fixed for 1x1 agent
        self._rng = np.random.default_rng(seed)

        # Matplotlib handles are created lazily on the first render() call.
        self._fig = None
        self._ax = None
        self._goal = None

        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict(
            {
                "board": spaces.Box(
                    low=0,
                    high=3,
                    shape=(self.max_size, self.max_size),
                    dtype=np.int8,
                ),
                "perception": spaces.MultiDiscrete([self.max_size, self.max_size]),
            }
        )

        # Variation space without radius; agent is always 1x1
        self.variation_space = swm.spaces.Dict(
            {
                "agent": swm.spaces.Dict(
                    {
                        "color": swm.spaces.RGBBox(
                            init_value=np.array([255, 0, 0], dtype=np.uint8),
                        ),
                        "position": swm.spaces.MultiDiscrete(
                            [self.max_size, self.max_size],
                            init_value=np.array([10, 10], dtype=np.int32),
                            constrain_fn=self.check_location,
                        ),
                        "prob_break": swm.spaces.Box(
                            low=np.array(0.5, dtype=np.float32),
                            high=np.array(1.0, dtype=np.float32),
                            init_value=np.array(1.0, dtype=np.float32),
                            dtype=np.float32,
                        ),
                    },
                    sampling_order=["color", "position", "prob_break"],
                ),
                "goal": swm.spaces.Dict(
                    {
                        "color": swm.spaces.RGBBox(
                            init_value=np.array([52, 235, 201], dtype=np.uint8),
                        ),
                        "position": swm.spaces.MultiDiscrete(
                            [self.max_size, self.max_size],
                            init_value=[5, 5],
                            constrain_fn=self.check_location,
                        ),
                    },
                    sampling_order=["color", "position"],
                ),
                "board": swm.spaces.Dict(
                    {
                        "size": swm.spaces.Discrete(self.max_size - 10, start=10, init_value=20),
                        "prob_gravel": swm.spaces.Box(
                            low=np.array(0.0, dtype=np.float32),
                            high=np.array(1.0, dtype=np.float32),
                            init_value=np.array(0.45, dtype=np.float32),
                            dtype=np.float32,
                        ),
                        "prob_break": swm.spaces.Box(
                            low=np.array(0.5, dtype=np.float32),
                            high=np.array(1.0, dtype=np.float32),
                            init_value=np.array(1.0, dtype=np.float32),
                            dtype=np.float32,
                        ),
                        "sand_color": swm.spaces.RGBBox(
                            init_value=np.array([242, 218, 130], dtype=np.uint8),
                        ),
                        "gravel_color": swm.spaces.RGBBox(
                            init_value=np.array([128, 128, 128], dtype=np.uint8),
                        ),
                        "void_color": swm.spaces.RGBBox(
                            init_value=np.array([0, 0, 0], dtype=np.uint8),
                        ),
                    },
                    # NOTE(review): "prob_break" is declared above but is not
                    # listed here, and this class only reads
                    # agent.prob_break -- confirm the omission is intended.
                    sampling_order=[
                        "size",
                        "prob_gravel",
                        "sand_color",
                        "gravel_color",
                        "void_color",
                    ],
                ),
            },
            # The board comes first: check_location reads the board size when
            # it is used to constrain agent/goal positions.
            sampling_order=["board", "agent", "goal"],
        )

    # -------------------- Core API --------------------

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        """Start a new episode.

        Args:
            seed: If given, re-creates the environment generator; it is also
                forwarded to the variation space.
            options: Optional dict. ``options["variation"]`` is a sequence of
                variation names to update; defaults to ``DEFAULT_VARIATIONS``.

        Returns:
            ``(obs, info)`` where ``info`` carries the rendered goal image,
            the step counter and the goal position.

        Raises:
            ValueError: If the ``variation`` option is not a ``Sequence``.
        """
        if seed is not None:
            self._rng = np.random.default_rng(seed)

        if hasattr(self, "variation_space"):
            self.variation_space.seed(seed)

        options = options or {}

        # Go back to the initial values, then update only the requested names.
        self.variation_space.reset()

        variations = options.get("variation", DEFAULT_VARIATIONS)
        if not isinstance(variations, Sequence):
            raise ValueError("variation option must be a Sequence containing variations names to sample")

        self.variation_space.update(variations)
        assert self.variation_space.check(debug=True), "Variation values must be within variation space!"

        self._reset_state()

        obs = self._get_obs()
        info = {
            "newly_voided": 0,
            "in_void": False,
            "goal": self._goal,
            "steps": self.steps,
            "goal_pos": self.goal_pos,
        }
        return obs, info

    def step(self, action: int):
        """Move the agent one cell and try to void the cell it started on.

        Args:
            action: One of the ``Action`` codes.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``reward`` is 1.0
            when a cell was newly voided, else 0.0. ``truncated`` is True when
            the agent ends the step on a void cell.

        Raises:
            ValueError: If ``action`` is not contained in ``action_space``.
        """
        if not self.action_space.contains(action):
            raise ValueError("Invalid action")

        old_r, old_c = self.player_row, self.player_col

        dr, dc = 0, 0
        if action == Action.LEFT:
            dc = -self.step_size
        elif action == Action.RIGHT:
            dc = +self.step_size
        elif action == Action.DOWN:
            dr = +self.step_size
        elif action == Action.UP:
            dr = -self.step_size

        size = self.variation_space["board"]["size"].value

        # 1x1 agent can occupy any cell within [0, size-1]
        new_r = int(np.clip(old_r + dr, 0, size - 1))
        new_c = int(np.clip(old_c + dc, 0, size - 1))

        # Move the agent
        self.player_row, self.player_col = new_r, new_c
        self.player_y = self.player_row + 0.5
        self.player_x = self.player_col + 0.5

        # Void the cell the agent just left (like scraping behind)
        # NOTE(review): when the move is clipped at the border the agent has
        # not left its cell, so its own cell may be voided and the episode
        # truncated -- confirm this is the intended rule.
        newly_voided = self._void_cell(old_r, old_c)

        # Evaluated after voiding, so it reflects the updated board.
        in_void = self.board[new_r, new_c] == 0

        reward = float(newly_voided)
        self.steps += 1

        terminated = self.check_termination()
        truncated = bool(in_void)

        obs = self._get_obs()
        info = {
            "newly_voided": newly_voided,
            "in_void": bool(in_void),
            "goal": self._goal,
            "steps": self.steps,
            "goal_pos": self.goal_pos,
        }

        if self.render_mode == "human":
            self.render()

        return obs, reward, bool(terminated), bool(truncated), info

    def render(self, mode: str | None = None):
        """Draw the board and the agent.

        Args:
            mode: ``"human"`` or ``"rgb_array"``. Falls back to
                ``self.render_mode`` and then to ``"human"``.

        Returns:
            ``None`` in human mode; an ``(H, W, 3)`` uint8 array copied from
            the canvas in ``rgb_array`` mode.

        Raises:
            NotImplementedError: For any other mode (raised after drawing).
        """
        mode = mode or self.render_mode or "human"
        size = self.variation_space["board"]["size"].value

        # The figure is created once and reused by later calls.
        if self._fig is None or self._ax is None:
            self._fig, self._ax = plt.subplots(figsize=(size * 0.4, size * 0.4))

        ax = self._ax
        ax.clear()

        self.render_board(ax=ax)

        # Draw 1x1 agent as a square (only when it stands on a non-void cell)
        if self.board[self.player_row, self.player_col] > 0:
            rect = plt.Rectangle(
                (self.player_x - 0.5, self.player_y - 0.5),
                1.0,
                1.0,
                fill=True,
                facecolor=self.variation_space["agent"]["color"].value / 255.0,
                edgecolor=None,
                zorder=3,
                antialiased=False,
            )
            ax.add_patch(rect)

        ax.set_xlim(0, size)
        ax.set_ylim(0, size)
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor((1.0, 1.0, 1.0))
        self._fig.tight_layout(pad=0)

        if mode == "human":
            plt.pause(0.001)
            plt.draw()
            return None

        if mode == "rgb_array":
            self._fig.canvas.draw()
            # buffer_rgba() already carries its (height, width, 4) shape.
            # Reshaping a flat buffer from get_width_height() is fragile: that
            # call returns (width, height) and reports logical pixels, not
            # necessarily the buffer's real size.
            rgba = np.asarray(self._fig.canvas.buffer_rgba())
            return rgba[..., :3].copy()

        raise NotImplementedError(f"Render mode {mode} not supported.")

    def close(self) -> None:
        """Close the Matplotlib figure, if one was created."""
        if self._fig is not None:
            plt.close(self._fig)
        self._fig, self._ax = None, None

    # -------------------- Helpers --------------------

    def _reset_state(self) -> None:
        """Regenerate the board, place the agent and rebuild the goal image."""
        self.board = self.generate_board().astype(np.int8)
        self.player_row, self.player_col = self.variation_space["agent"]["position"].value
        # Cell centre, used as the drawing anchor in render().
        self.player_y = self.player_row + 0.5
        self.player_x = self.player_col + 0.5
        self.steps = 0
        self.generate_goal()

    def _get_obs(self) -> dict[str, Any]:
        """Return a copy of the board plus the agent's ``(row, col)``."""
        return {
            "board": self.board.copy(),
            "perception": np.array([self.player_row, self.player_col], dtype=np.int32),
        }

    def generate_board(self) -> np.ndarray:
        """Sample a fresh board of sand (1) and gravel (2) cells.

        Returns:
            A ``(max_size, max_size)`` int8 array. Only the top-left
            ``size x size`` region is populated; the rest stays 0.

        Raises:
            ValueError: If the cell probabilities do not sum to 1.
        """
        prob_gravel = self.variation_space["board"]["prob_gravel"].value.item()

        # Void has probability 0, so a new board contains no void cells.
        probs = [0.0, 1 - prob_gravel, prob_gravel]
        if not np.isclose(sum(probs), 1.0):
            raise ValueError("Probabilities must sum to 1")

        size = self.variation_space["board"]["size"].value
        board = np.zeros((self.max_size, self.max_size), dtype=np.int8)
        board[:size, :size] = self._rng.choice([0, 1, 2], size=(size, size), p=probs)
        return board

    def render_board(self, ax: plt.Axes | None = None) -> None:
        """Draw the active board region as a colour image.

        Args:
            ax: Axes to draw on; a new figure is created when omitted.
        """
        void_color = self.variation_space["board"]["void_color"].value
        sand_color = self.variation_space["board"]["sand_color"].value
        gravel_color = self.variation_space["board"]["gravel_color"].value
        goal_color = self.variation_space["goal"]["color"].value

        # Row index == cell code: 0 void, 1 sand, 2 gravel, 3 goal.
        lut = np.array([void_color, sand_color, gravel_color, goal_color], dtype=float) / 255.0

        size = self.variation_space["board"]["size"].value
        board = self.board[:size, :size]
        # NOTE(review): `board` is a view, so this also writes the goal marker
        # into self.board (and therefore into later observations). Confirm
        # that rendering is meant to change environment state.
        board[self.goal_pos[0], self.goal_pos[1]] = 3

        img = lut[board]

        if ax is None:
            _, ax = plt.subplots(figsize=(board.shape[1] * 0.2, board.shape[0] * 0.2))

        ax.imshow(img, interpolation="nearest", origin="lower", extent=[0, size, 0, size])
        ax.set_xticks([])
        ax.set_yticks([])

    def _void_cell(self, r: int, c: int) -> int:
        """Void the single cell at (r, c) and return 1 if it was newly voided, else 0."""
        prob_break = self.variation_space["agent"]["prob_break"].value.item()
        # The random draw happens on every call, even if the cell is already void.
        should_void = self._rng.random() < prob_break
        if should_void and self.board[r, c] != 0:
            self.board[r, c] = 0
            return 1
        return 0

    def check_termination(self) -> bool:
        """Return True when the episode is solved.

        Success = all blocks are void except under the agent, AND the agent is
        at the goal position. For 1x1 agent, 'under the agent' is just its
        current cell.
        """
        size = self.variation_space["board"]["size"].value
        r, c = self.player_row, self.player_col

        board_copy = self.board[:size, :size].copy()
        board_copy[r, c] = 0  # ignore agent cell
        all_voided = np.count_nonzero(board_copy) == 0

        # goal_pos comes straight from the variation space and may be a list
        # or an ndarray. `(r, c) == goal_pos` is always False for a list and
        # element-wise (ambiguous truth value) for an array, so compare by
        # value instead.
        at_goal = np.array_equal(np.asarray(self.goal_pos), (r, c))
        return bool(all_voided and at_goal)

    def set_state(
        self,
        board: np.ndarray,
        player_pos: tuple[int, int],
        *,
        validate: bool = True,
        render: bool = False,
    ) -> dict[str, Any]:
        """Overwrite the board and agent position and reset the step counter.

        Args:
            board: New board array.
            player_pos: ``(row, col)`` of the agent.
            validate: Check the board shape and the position against the
                current board size.
            render: Call ``render()`` after updating the state.

        Returns:
            The observation for the new state.

        Raises:
            ValueError: If validation is on and the board shape is not
                ``(size, size)`` or ``player_pos`` is out of bounds.
        """
        if validate:
            size = self.variation_space["board"]["size"].value
            if board.shape != (size, size):
                raise ValueError("Invalid board shape")
            r, c = player_pos
            if not (0 <= r < size and 0 <= c < size):
                raise ValueError("player_pos out of bounds")

        # NOTE(review): the stored board is (size, size) here, while
        # generate_board and observation_space use (max_size, max_size);
        # copy=False may also alias the caller's array. Confirm callers
        # expect both.
        self.board = board.astype(np.int8, copy=False)
        self.player_row, self.player_col = map(int, player_pos)
        self.player_y, self.player_x = self.player_row + 0.5, self.player_col + 0.5
        self.steps = 0

        if render:
            self.render()
        return self._get_obs()

    def generate_goal(self, *, cell_value: int = 3) -> None:
        """Render the goal image and store it in ``self._goal``.

        The environment is temporarily switched to an all-void board with the
        agent standing on the goal cell, rendered with ``self.render_mode``,
        and then restored to the state it had on entry.

        Args:
            cell_value: Code written into the goal cell of the temporary board.
        """
        prev_board, prev_row, prev_col = (
            self.board.copy(),
            self.player_row,
            self.player_col,
        )
        size = self.variation_space["board"]["size"].value
        prev_y, prev_x, prev_steps = self.player_y, self.player_x, self.steps

        try:
            self.goal_pos = self.variation_space["goal"]["position"].value
            r, c = self.goal_pos[0], self.goal_pos[1]

            board = np.zeros((size, size), dtype=np.int8)
            board[r, c] = cell_value

            _ = self.set_state(board, (r, c), validate=True, render=False)
            self._goal = self.render(mode=self.render_mode)
        finally:
            # Always restore the real episode state, even if rendering failed.
            self.board, self.player_row, self.player_col = (
                prev_board,
                prev_row,
                prev_col,
            )
            self.player_y, self.player_x, self.steps = prev_y, prev_x, prev_steps

    def check_location(self, x):
        """Return whether position ``x`` lies inside the current board.

        Passed as ``constrain_fn`` for the agent and goal position spaces.
        """
        size = int(self.variation_space.value["board"]["size"])
        # 1x1 agent can exist anywhere inside the board
        return (0 <= x[0] < size) and (0 <= x[1] < size)

    def __del__(self):
        """Best-effort cleanup of the figure when the object is collected."""
        try:
            self.close()
        except Exception:
            # Deliberately silent: nothing useful can be done with an error
            # raised while the object is being destroyed.
            pass