from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from gymnasium import spaces
import stable_worldmodel as swm
# Variation names that ``VoidRunEnv.reset`` hands to ``variation_space.update()``
# when the caller does not supply ``options["variation"]``.
DEFAULT_VARIATIONS = ("board.prob_gravel", "agent.position", "goal.position")
@dataclass(frozen=True)
class Action:
    """Integer codes for the four moves understood by ``VoidRunEnv.step``.

    ``step`` reads these as class attributes (``Action.LEFT`` ...) and compares
    them against the raw integer action it receives.
    """

    LEFT: int = 0  # column decreases by the step size
    RIGHT: int = 1  # column increases by the step size
    DOWN: int = 2  # row increases by the step size
    UP: int = 3  # row decreases by the step size
class VoidRunEnv(gym.Env):
    """Discrete grid environment with a 1x1 agent cell."""
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}
    def __init__(
        self,
        seed: int | None = None,
        render_mode: str = "human",
    ) -> None:
        """Create the environment and declare its action, observation and variation spaces.

        Args:
            seed: Seed for the environment's own NumPy generator (``self._rng``),
                which drives board generation and the per-step break roll.
            render_mode: Default mode used by ``render`` when none is passed.

        Board cells use the codes 0 (void), 1 (sand), 2 (gravel) and 3 (goal);
        the board observation is always ``max_size`` x ``max_size``.
        """
        super().__init__()
        self.render_mode = render_mode
        self.max_size = 50
        self.step_size = 1  # fixed for 1x1 agent
        self._rng = np.random.default_rng(seed)
        # Lazily created matplotlib handles and the cached goal rendering.
        self._fig = None
        self._ax = None
        self._goal = None

        def rgb(red: int, green: int, blue: int):
            """Colour variation whose initial value is the given RGB triple."""
            return swm.spaces.RGBBox(
                init_value=np.array([red, green, blue], dtype=np.uint8),
            )

        def probability(low: float, init: float):
            """Scalar float32 variation on ``[low, 1.0]`` starting at ``init``."""
            return swm.spaces.Box(
                low=np.array(low, dtype=np.float32),
                high=np.array(1.0, dtype=np.float32),
                init_value=np.array(init, dtype=np.float32),
                dtype=np.float32,
            )

        grid_extent = [self.max_size, self.max_size]

        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Dict(
            {
                "board": spaces.Box(low=0, high=3, shape=(self.max_size, self.max_size), dtype=np.int8),
                "perception": spaces.MultiDiscrete([self.max_size, self.max_size]),
            }
        )

        # Variation space without radius; agent is always 1x1
        agent_space = swm.spaces.Dict(
            {
                "color": rgb(255, 0, 0),
                "position": swm.spaces.MultiDiscrete(
                    list(grid_extent),
                    init_value=np.array([10, 10], dtype=np.int32),
                    constrain_fn=self.check_location,
                ),
                "prob_break": probability(0.5, 1.0),
            },
            sampling_order=["color", "position", "prob_break"],
        )
        goal_space = swm.spaces.Dict(
            {
                "color": rgb(52, 235, 201),
                "position": swm.spaces.MultiDiscrete(
                    list(grid_extent),
                    init_value=[5, 5],
                    constrain_fn=self.check_location,
                ),
            },
            sampling_order=["color", "position"],
        )
        # NOTE(review): "prob_break" is declared below but absent from the board
        # sampling_order (the agent dict lists every key) - confirm this is intended.
        board_space = swm.spaces.Dict(
            {
                "size": swm.spaces.Discrete(self.max_size - 10, start=10, init_value=20),
                "prob_gravel": probability(0.0, 0.45),
                "prob_break": probability(0.5, 1.0),
                "sand_color": rgb(242, 218, 130),
                "gravel_color": rgb(128, 128, 128),
                "void_color": rgb(0, 0, 0),
            },
            sampling_order=[
                "size",
                "prob_gravel",
                "sand_color",
                "gravel_color",
                "void_color",
            ],
        )
        # The board is listed first in the sampling order; the position
        # constraint (check_location) reads the board size.
        self.variation_space = swm.spaces.Dict(
            {
                "agent": agent_space,
                "goal": goal_space,
                "board": board_space,
            },
            sampling_order=["board", "agent", "goal"],
        )
    # -------------------- Core API ----------
    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
        """Start a new episode and return ``(obs, info)``.

        Args:
            seed: When given, re-seeds gymnasium's generator, the environment's
                own ``_rng`` and the variation space. When ``None`` the existing
                random streams are left untouched, so a seeded run stays
                reproducible across subsequent unseeded resets.
            options: Optional dict; ``options["variation"]`` is a sequence of
                variation names passed to ``variation_space.update()``.
                Defaults to ``DEFAULT_VARIATIONS``.

        Returns:
            The observation dict from ``_get_obs`` and an info dict with the
            keys ``newly_voided``, ``in_void``, ``goal``, ``steps`` and
            ``goal_pos``.

        Raises:
            ValueError: If ``options["variation"]`` is not a ``Sequence``.
        """
        # Follow the gymnasium contract: the base class seeds ``np_random``.
        super().reset(seed=seed)
        if seed is not None:
            # Seeding with None would draw fresh entropy and discard the
            # current stream, so only re-seed on explicit request.
            self._rng = np.random.default_rng(seed)
            self.variation_space.seed(seed)

        options = options or {}
        self.variation_space.reset()
        variations = options.get("variation", DEFAULT_VARIATIONS)
        if not isinstance(variations, Sequence):
            raise ValueError("variation option must be a Sequence containing variations names to sample")
        self.variation_space.update(variations)
        assert self.variation_space.check(debug=True), "Variation values must be within variation space!"

        # Build the board, place the agent and render the goal frame.
        self._reset_state()
        obs = self._get_obs()
        info = {
            "newly_voided": 0,
            "in_void": False,
            "goal": self._goal,
            "steps": self.steps,
            "goal_pos": self.goal_pos,
        }
        return obs, info
    def step(self, action: int):
        """Apply one move and return ``(obs, reward, terminated, truncated, info)``.

        The agent moves one ``step_size`` in the chosen direction, clamped to
        the playable ``size`` x ``size`` region. The cell it occupied before
        the move is then passed to ``_void_cell``; the reward is 1.0 when that
        cell was newly voided and 0.0 otherwise. ``truncated`` is set when the
        cell the agent ends on is void, ``terminated`` comes from
        ``check_termination``.

        Raises:
            ValueError: If ``action`` is not contained in ``action_space``.
        """
        if not self.action_space.contains(action):
            raise ValueError("Invalid action")

        stride = self.step_size
        # (row delta, column delta) per action code.
        offsets = {
            Action.LEFT: (0, -stride),
            Action.RIGHT: (0, stride),
            Action.DOWN: (stride, 0),
            Action.UP: (-stride, 0),
        }
        d_row, d_col = offsets.get(int(action), (0, 0))

        prev_row, prev_col = self.player_row, self.player_col
        # 1x1 agent can occupy any cell within [0, size-1]
        last_index = self.variation_space["board"]["size"].value - 1
        next_row = int(np.clip(prev_row + d_row, 0, last_index))
        next_col = int(np.clip(prev_col + d_col, 0, last_index))

        # Move the agent; the float coordinates are the cell centre.
        self.player_row, self.player_col = next_row, next_col
        self.player_y = next_row + 0.5
        self.player_x = next_col + 0.5

        # Void the cell the agent just left (like scraping behind)
        newly_voided = self._void_cell(prev_row, prev_col)
        in_void = bool(self.board[next_row, next_col] == 0)

        self.steps += 1
        terminated = bool(self.check_termination())
        observation = self._get_obs()
        info = {
            "newly_voided": newly_voided,
            "in_void": in_void,
            "goal": self._goal,
            "steps": self.steps,
            "goal_pos": self.goal_pos,
        }

        if self.render_mode == "human":
            self.render()
        return observation, float(newly_voided), terminated, in_void, info
    def render(self, mode: str | None = None):
        """Draw the board and the agent, then show or return the frame.

        Args:
            mode: ``"human"`` or ``"rgb_array"``. Falls back to
                ``self.render_mode`` and finally to ``"human"``.

        Returns:
            ``None`` in ``"human"`` mode; in ``"rgb_array"`` mode a uint8 array
            of shape ``(height, width, 3)`` copied from the canvas.

        Raises:
            NotImplementedError: For any other mode (raised after drawing).
        """
        mode = mode or self.render_mode or "human"
        size = self.variation_space["board"]["size"].value

        # The figure is created once and reused for every later frame.
        if self._fig is None or self._ax is None:
            self._fig, self._ax = plt.subplots(figsize=(size * 0.4, size * 0.4))
        ax = self._ax
        ax.clear()
        self.render_board(ax=ax)

        # Draw 1x1 agent as a square, but only while it stands on a non-void cell.
        if self.board[self.player_row, self.player_col] > 0:
            rect = plt.Rectangle(
                (self.player_x - 0.5, self.player_y - 0.5),
                1.0,
                1.0,
                fill=True,
                facecolor=self.variation_space["agent"]["color"].value / 255.0,
                edgecolor=None,
                zorder=3,
                antialiased=False,
            )
            ax.add_patch(rect)

        ax.set_xlim(0, size)
        ax.set_ylim(0, size)
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor((1.0, 1.0, 1.0))
        self._fig.tight_layout(pad=0)

        if mode == "human":
            plt.pause(0.001)
            plt.draw()
            return None
        if mode == "rgb_array":
            canvas = self._fig.canvas
            canvas.draw()
            # buffer_rgba() is already laid out as (height, width, 4). Reading
            # it directly avoids unpacking get_width_height(), which returns
            # (width, height), and stays correct when the pixel buffer is
            # larger than the logical canvas size.
            rgba = np.asarray(canvas.buffer_rgba())
            return rgba[..., :3].copy()
        raise NotImplementedError(f"Render mode {mode} not supported.")
    def close(self) -> None:
        """Close the matplotlib figure, if one was created, and drop its handles."""
        figure = self._fig
        if figure is None:
            return
        plt.close(figure)
        self._fig = None
        self._ax = None
    # -------------------- Helpers --------------------
    def _reset_state(self) -> None:
        """Rebuild the episode state from the current variation values.

        Generates a fresh board, places the agent at the ``agent.position``
        variation, zeroes the step counter and renders the goal frame.
        """
        self.board = self.generate_board().astype(np.int8)

        start = self.variation_space["agent"]["position"].value
        self.player_row, self.player_col = start
        # Continuous coordinates point at the centre of the agent's cell.
        self.player_x = self.player_col + 0.5
        self.player_y = self.player_row + 0.5

        self.steps = 0
        # Must run last: generate_goal snapshots and restores the state set above.
        self.generate_goal()
    def _get_obs(self) -> dict[str, Any]:
        """Return the observation: a copy of the board and the agent's ``[row, col]``."""
        board_snapshot = self.board.copy()
        agent_cell = np.array([self.player_row, self.player_col], dtype=np.int32)
        return {"board": board_snapshot, "perception": agent_cell}
    def generate_board(self) -> np.ndarray:
        """Sample a new board from the current ``board`` variations.

        The top-left ``size`` x ``size`` region is filled with sand (1) or
        gravel (2), gravel being drawn with probability ``prob_gravel``; void
        (0) has probability zero. Everything outside that region stays 0.

        Returns:
            An int8 array of shape ``(max_size, max_size)``.

        Raises:
            ValueError: If the three cell probabilities do not sum to 1.
        """
        board_vars = self.variation_space["board"]
        gravel_p = board_vars["prob_gravel"].value.item()
        # Weights are indexed by cell code: void, sand, gravel.
        weights = [0.0, 1 - gravel_p, gravel_p]
        if not np.isclose(sum(weights), 1.0):
            raise ValueError("Probabilities must sum to 1")

        side = board_vars["size"].value
        grid = np.zeros((self.max_size, self.max_size), dtype=np.int8)
        grid[:side, :side] = self._rng.choice([0, 1, 2], size=(side, side), p=weights)
        return grid
    def render_board(self, ax: plt.Axes | None = None) -> None:
        """Paint the playable region of the board onto ``ax``.

        Each cell code (0 void, 1 sand, 2 gravel, 3 goal) is mapped to the
        matching colour variation and the result is drawn with ``imshow``.

        Args:
            ax: Target axes. When omitted a new figure is created for the call.
        """
        board_vars = self.variation_space["board"]
        # Palette rows are indexed by cell code.
        palette = [
            board_vars["void_color"].value,
            board_vars["sand_color"].value,
            board_vars["gravel_color"].value,
            self.variation_space["goal"]["color"].value,
        ]
        lut = np.array(palette, dtype=float) / 255.0

        size = board_vars["size"].value
        cells = self.board[:size, :size]
        # NOTE(review): ``cells`` is a view, so stamping the goal code here also
        # writes 3 into ``self.board`` - confirm that rendering is meant to
        # mutate the environment state.
        goal_row, goal_col = self.goal_pos[0], self.goal_pos[1]
        cells[goal_row, goal_col] = 3
        img = lut[cells]

        if ax is None:
            n_rows, n_cols = cells.shape[0], cells.shape[1]
            _, ax = plt.subplots(figsize=(n_cols * 0.2, n_rows * 0.2))
        ax.imshow(img, interpolation="nearest", origin="lower", extent=[0, size, 0, size])
        ax.set_xticks([])
        ax.set_yticks([])
    def _void_cell(self, r: int, c: int) -> int:
        """Try to turn cell ``(r, c)`` into void.

        Returns 1 when the cell was non-void and the break roll succeeded
        (the cell is then set to 0), otherwise 0.
        """
        threshold = self.variation_space["agent"]["prob_break"].value.item()
        # The roll is drawn unconditionally, so the RNG stream advances once per
        # call regardless of the cell's current content.
        roll = self._rng.random()
        if not (roll < threshold) or self.board[r, c] == 0:
            return 0
        self.board[r, c] = 0
        return 1
    def check_termination(self) -> bool:
        """Return True when the episode has been solved.

        Success = all blocks are void except under the agent,
        AND the agent is at the goal position.
        For 1x1 agent, 'under the agent' is just its current cell.
        """
        r, c = self.player_row, self.player_col

        # ``goal_pos`` may be a list or a NumPy array. Comparing it to a tuple
        # is always False for a list and elementwise for an array, so compare
        # the two coordinates explicitly as plain ints.
        goal_r, goal_c = int(self.goal_pos[0]), int(self.goal_pos[1])
        if int(r) != goal_r or int(c) != goal_c:
            return False

        size = self.variation_space["board"]["size"].value
        remaining = self.board[:size, :size].copy()
        remaining[r, c] = 0  # ignore agent cell
        return np.count_nonzero(remaining) == 0
    def set_state(
        self,
        board: np.ndarray,
        player_pos: tuple[int, int],
        *,
        validate: bool = True,
        render: bool = False,
    ) -> dict[str, Any]:
        """Overwrite the board and agent position and reset the step counter.

        Args:
            board: New board; stored as int8 without copying when it already
                has that dtype.
            player_pos: ``(row, col)`` of the agent.
            validate: When True, require ``board`` to be ``size`` x ``size`` and
                ``player_pos`` to lie inside it.
            render: When True, call ``render()`` after the state is applied.

        Returns:
            The observation produced by ``_get_obs`` for the new state.

        Raises:
            ValueError: If validation is enabled and the board shape or the
                player position is invalid.
        """
        if validate:
            side = self.variation_space["board"]["size"].value
            if board.shape != (side, side):
                raise ValueError("Invalid board shape")
            row, col = player_pos
            inside = 0 <= row < side and 0 <= col < side
            if not inside:
                raise ValueError("player_pos out of bounds")

        self.board = board.astype(np.int8, copy=False)
        self.player_row, self.player_col = (int(v) for v in player_pos)
        # Continuous coordinates point at the centre of the agent's cell.
        self.player_y = self.player_row + 0.5
        self.player_x = self.player_col + 0.5
        self.steps = 0

        if render:
            self.render()
        return self._get_obs()
    def generate_goal(self, *, cell_value: int = 3) -> None:
        """Record the goal position and render the goal frame into ``self._goal``.

        The current state is snapshotted, replaced by an otherwise-void
        ``size`` x ``size`` board holding ``cell_value`` at the goal cell with
        the agent placed on it, rendered with ``self.render_mode``, and then
        restored even if rendering fails. In ``"human"`` mode ``render``
        returns ``None``, so ``self._goal`` is ``None`` in that case.

        Args:
            cell_value: Cell code written at the goal position of the
                temporary board.
        """
        # Insertion order doubles as the restore order.
        snapshot = {
            "board": self.board.copy(),
            "player_row": self.player_row,
            "player_col": self.player_col,
            "player_y": self.player_y,
            "player_x": self.player_x,
            "steps": self.steps,
        }
        size = self.variation_space["board"]["size"].value
        try:
            self.goal_pos = self.variation_space["goal"]["position"].value
            goal_row, goal_col = self.goal_pos[0], self.goal_pos[1]

            goal_board = np.zeros((size, size), dtype=np.int8)
            goal_board[goal_row, goal_col] = cell_value

            self.set_state(goal_board, (goal_row, goal_col), validate=True, render=False)
            self._goal = self.render(mode=self.render_mode)
        finally:
            for attr_name, saved_value in snapshot.items():
                setattr(self, attr_name, saved_value)
    def check_location(self, x):
        """Constraint used by the position variations: is ``x`` on the board?

        ``x`` is indexed as ``(x[0], x[1])``; both must lie in ``[0, size)``
        for the current ``board.size`` value.
        """
        limit = int(self.variation_space.value["board"]["size"])
        # 1x1 agent can exist anywhere inside the board
        first_ok = 0 <= x[0] < limit
        return first_ok and (0 <= x[1] < limit)
    def __del__(self):
        """Best-effort cleanup: close the figure when the environment is collected."""
        try:
            self.close()
        except Exception:
            # Deliberately silent: a finalizer must not raise, and close() may
            # fail if the object was only partly initialised.
            return