import gymnasium as gym
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pygame
from PIL import Image, ImageOps
# Get the default color cycle from Matplotlib's rcParams
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors_hex = prop_cycle.by_key()["color"]
# Convert hex colors to RGBA tuples
COLORS = np.asarray([mcolors.to_rgba(hex_color) for hex_color in colors_hex])
COLORS = (COLORS * 255).astype(int)
COLORS = [tuple(u) for u in COLORS]
[docs]
class ImagePositioning(gym.Env):
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
    def __init__(
        self,
        resolution: int,
        images: list[Image],
        render_mode: str | None = None,
        background_power_decay: float | None = 1.0,
    ):
        self.resolution = resolution
        self.background_power_decay = background_power_decay
        # Define what the agent can observe
        # Dict space gives us structured, human-readable observations
        self.observation_space = gym.spaces.Dict(
            {
                "current_background": gym.spaces.Box(0, 0.9, shape=(2, 1), dtype=float),
                "current_locations": gym.spaces.Box(0, 0.9, shape=(len(images), 2), dtype=float),
                "current_rotations": gym.spaces.Box(0, 1, shape=(len(images), 1), dtype=float),
                "target_background": gym.spaces.Box(0, 1, shape=(2, 1), dtype=float),
                "target_locations": gym.spaces.Box(0, 0.9, shape=(len(images), 2), dtype=float),
                "target_rotations": gym.spaces.Box(0, 1, shape=(len(images), 1), dtype=float),
            }
        )
        # Initialize positions - will be set randomly in reset()
        # Using -1,-1 as "uninitialized" state
        self._current_locations = np.empty(self.observation_space["current_locations"].shape, dtype=float)
        self._target_locations = np.array(self.observation_space["target_locations"].shape, dtype=float)
        self._current_rotations = np.empty(self.observation_space["current_rotations"].shape, dtype=float)
        self._target_rotations = np.array(self.observation_space["target_rotations"].shape, dtype=float)
        self._current_background = np.empty(self.observation_space["current_background"].shape, dtype=float)
        self._target_background = np.array(self.observation_space["target_background"].shape, dtype=float)
        # Define what actions are available (4 directions)
        self.action_space = gym.spaces.Dict(
            {
                "delta_background": gym.spaces.Box(low=-0.1, high=0.1, shape=(2, 1)),
                "delta_locations": gym.spaces.Box(low=-0.1, high=0.1, shape=(len(images), 2)),
                "delta_rotations": gym.spaces.Box(low=-0.1, high=0.1, shape=(len(images), 1)),
            }
        )
        self.images = [ImageOps.expand(img, border=5, fill=c).convert("RGBA") for img, c in zip(images, COLORS)]
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None
    def _get_obs(self):
        """Convert internal state to observation format.
        Returns:
            dict: Observation with agent and target positions
        """
        return {
            "current_background": self._current_background,
            "current_locations": self._current_locations,
            "current_rotations": self._current_rotations,
            "target_background": self._target_background,
            "target_locations": self._target_locations,
            "target_rotations": self._target_rotations,
        }
    def _get_info(self):
        """Compute auxiliary information for debugging.
        Returns:
            dict: Info with distance between agent and target
        """
        return {
            "location_distance": np.linalg.norm(self._current_locations - self._target_locations, ord=1),
            "rotation_distance": np.linalg.norm(self._current_rotations - self._target_rotations, ord=1),
            "background_distance": np.linalg.norm(self._current_background - self._target_background, ord=1),
        }
[docs]
    def reset(self, seed: int | None = None, options: dict | None = None):
        """Start a new episode.
        Args:
            seed: Random seed for reproducible episodes
            options: Additional configuration (unused in this example)
        Returns:
            tuple: (observation, info) for the initial state
        """
        # IMPORTANT: Must call this first to seed the random number generator
        super().reset(seed=seed)
        # Randomly place the agent anywhere on the grid
        self._current_background = self.np_random.random(size=(2, 1), dtype=float)
        self._current_locations = self.np_random.random(size=(len(self.images), 2), dtype=float)
        self._current_rotations = self.np_random.random(size=(len(self.images), 2), dtype=float)
        self._target_background = self.np_random.random(size=(2, 1), dtype=float)
        self._target_locations = self.np_random.random(size=(len(self.images), 2), dtype=float)
        self._target_rotations = self.np_random.random(size=(len(self.images), 1), dtype=float)
        white_noise = np.random.randn(self.resolution * 2, self.resolution * 2)
        rows, cols = white_noise.shape
        fft_white_noise = np.fft.fft2(white_noise)
        # Create frequency coordinates
        fy = np.fft.fftfreq(rows)
        fx = np.fft.fftfreq(cols)
        # Create 2D frequency grid
        fx_grid, fy_grid = np.meshgrid(fx, fy)
        # Calculate radial frequency magnitude
        f_magnitude = np.sqrt(fx_grid**2 + fy_grid**2)
        # Avoid division by zero at the DC component (f=0)
        f_magnitude[0, 0] = 1  # Or a small epsilon to prevent singularity
        # Apply the 1/f filter to the frequency magnitudes
        # For power spectral density 1/f, amplitude is 1/sqrt(f)
        pink_filter = (1 / f_magnitude) ** self.background_power_decay
        fft_pink_noise = fft_white_noise * pink_filter
        pink_noise = np.fft.ifft2(fft_pink_noise).real
        pink_noise -= pink_noise.min()
        pink_noise /= pink_noise.max()
        pink_noise = (pink_noise * 255).astype(np.uint8)
        self.pink_noise = np.tile(np.expand_dims(pink_noise, 2), (1, 1, 3))
        observation = self._get_obs()
        info = self._get_info()
        return observation, info 
[docs]
    def step(self, action):
        """Execute one timestep within the environment.
        Args:
            action: The action to take (0-3 for directions)
        Returns:
            tuple: (observation, reward, terminated, truncated, info)
        """
        action["delta_background"] = np.clip(
            action["delta_background"],
            self.action_space["delta_background"].low,
            self.action_space["delta_background"].high,
        )
        action["delta_locations"] = np.clip(
            action["delta_locations"],
            self.action_space["delta_locations"].low,
            self.action_space["delta_locations"].high,
        )
        action["delta_rotations"] = np.clip(
            action["delta_rotations"],
            self.action_space["delta_rotations"].low,
            self.action_space["delta_rotations"].high,
        )
        self._current_background = np.clip(
            self._current_background + action["delta_background"],
            self.observation_space["current_background"].low,
            self.observation_space["current_background"].high,
        )
        self._current_locations = np.clip(
            self._current_locations + action["delta_locations"],
            self.observation_space["current_locations"].low,
            self.observation_space["current_locations"].high,
        )
        self._current_rotations = np.clip(
            self._current_rotations + action["delta_rotations"],
            self.observation_space["current_rotations"].low,
            self.observation_space["current_rotations"].high,
        )
        observation = self._get_obs()
        info = self._get_info()
        # Check if agent reached the target
        terminated = (
            info["location_distance"] < 1e-2
            and info["rotation_distance"] < 1e-2
            and info["background_distance"] < 1e-2
        )
        # We don't use truncation in this simple environment
        # (could add a step limit here if desired)
        truncated = False
        # Simple reward structure: +1 for reaching target, 0 otherwise
        # Alternative: could give small negative rewards for each step to encourage efficiency
        reward = 1 if terminated else 0
        return observation, reward, terminated, truncated, info 
    def _get_optimal_action(self):
        rotations = self._current_rotations - self._target_rotations
        locations = self._current_locations - self._target_locations
        background = self._current_background - self._target_background
        return {
            "delta_background": -background,
            "delta_locations": -locations,
            "delta_rotations": -rotations,
        }
[docs]
    def render(self, mode="current"):
        if self.render_mode == "rgb_array":
            return self._render_frame(mode=mode) 
    def _render_frame(self, mode):
        if self.window is None and self.render_mode in ["human", "rgb_array"]:
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.resolution, self.resolution))
        if self.clock is None and self.render_mode in ["human", "rgb_array"]:
            self.clock = pygame.time.Clock()
        # canvas = pygame.Surface((self.resolution, self.resolution))
        # canvas.fill((255, 255, 255))
        # get the image
        if mode == "current":
            x = int(self.resolution * self._current_background[0, 0])
            y = int(self.resolution * self._current_background[1, 0])
            new_background = Image.fromarray(self.pink_noise[x : x + self.resolution, y : y + self.resolution :])
        else:
            x = int(self.resolution * self._target_background[0, 0])
            y = int(self.resolution * self._target_background[1, 0])
            new_background = Image.fromarray(self.pink_noise[x : x + self.resolution, y : y + self.resolution :])
        for i, img in enumerate(self.images):
            if mode == "current":
                box = [
                    int(self._current_locations[i, 0] * self.resolution),
                    int(self._current_locations[i, 1] * self.resolution),
                    int(self._current_locations[i, 0] * self.resolution + img.height),
                    int(self._current_locations[i, 1] * self.resolution + img.width),
                ]
                new_background.paste(img.rotate(self._current_rotations[i, 0] * 360), box)
            else:
                box = [
                    int(self._target_locations[i, 0] * self.resolution),
                    int(self._target_locations[i, 1] * self.resolution),
                    int(self._target_locations[i, 0] * self.resolution + img.height),
                    int(self._target_locations[i, 1] * self.resolution + img.width),
                ]
                new_background.paste(img.rotate(self._target_rotations[i, 0] * 360), box)
        # get the surface
        # Get image data, size, and mode from PIL Image
        image_bytes = new_background.tobytes()
        image_size = new_background.size
        image_mode = new_background.mode
        # Create a Pygame Surface from the PIL image data
        pygame_surface = pygame.image.frombytes(image_bytes, image_size, image_mode)
        self.window.blit(pygame_surface, (0, 0))  # Blit at position (0,0)
        # Update the display
        pygame.display.flip()
        if self.render_mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            # self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            # We need to ensure that human-rendering occurs at the predefined framerate.
            # The following line will automatically add a delay to keep the framerate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(np.array(pygame.surfarray.pixels3d(pygame_surface)), axes=(1, 0, 2))
[docs]
    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit() 
 
if __name__ == "__main__":
    import gymnasium as gym
    import matplotlib.pyplot as plt
    import numpy as np
    from gymnasium.wrappers import RecordVideo
    import stable_worldmodel as swm
    # 1. Setup Environment
    # Create a CartPole environment with "rgb_array" render mode to get image data
    images = [
        swm.utils.create_pil_image_from_url(
            "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQK5OnlnP3_GHXI2y1LoIHbMROdN8_DYyLEGg&s"
        ).resize((64, 64)),
        swm.utils.create_pil_image_from_url(
            "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQjrFGrhOLwgYP0cdjTIBEWMpy9MHBcya4c5Q&s"
        ).resize((32, 32)),
    ]
    env = gym.make(
        "swm/ImagePositioning",
        render_mode="rgb_array",
        resolution=224,
        images=images,
        background_power_decay=2,
        max_episode_steps=20,
    )  #
    env = gym.wrappers.AddRenderObservation(env, render_only=False)
    swm.collect.random_action(env, num_episodes=1)
    env = RecordVideo(
        env,
        video_folder="cartpole-agent",  # Folder to save videos
        name_prefix="eval",  # Prefix for video filenames
        episode_trigger=lambda x: True,  # Record every episode
    )
    # 2. Reset the environment to get an initial observation
    observation, info = env.reset()  #
    print(observation)
    print(info)
    # 3. Render the environment to get the image array
    # The render method returns an RGB array when render_mode is "rgb_array"
    # 4. Save the figure
    # Use Matplotlib to display and save the image
    fig, axs = plt.subplots(1, 2)
    rgb_array = env.unwrapped.render()  #
    axs[0].imshow(rgb_array)
    axs[0].set_xticks([])
    axs[0].set_yticks([])
    axs[0].set_title("Init.")
    rgb_array = env.unwrapped.render(mode="target")
    axs[1].imshow(rgb_array)
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    axs[1].set_title("Target")
    plt.savefig("cartpole_observation.png")
    plt.close()  # Close the plot to free up memory
    for i in range(5):
        action = env.unwrapped._get_optimal_action()
        env.step(action)
    print("Saved CartPole observation as cartpole_observation.png")
    # 5. Close the environment
    env.close()