import gymnasium as gym
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pygame
from PIL import Image, ImageOps
# Get the default color cycle from Matplotlib's rcParams
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors_hex = prop_cycle.by_key()["color"]
# Convert hex colors to RGBA tuples
COLORS = np.asarray([mcolors.to_rgba(hex_color) for hex_color in colors_hex])
COLORS = (COLORS * 255).astype(int)
COLORS = [tuple(u) for u in COLORS]
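# Each image passed to the environment is framed with one of these colors (see __init__)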
class ImagePositioning(gym.Env):
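    """Goal-conditioned environment where images must be moved and rotated, and the
    background shifted, until the current scene matches a target scene.

    Observations expose both the current and the target layout (background offset,
    per-image locations, and rotations); actions are bounded deltas on each component.
    """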
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
def __init__(
self,
resolution: int,
        images: list[Image.Image],
render_mode: str | None = None,
        background_power_decay: float = 1.0,
):
self.resolution = resolution
self.background_power_decay = background_power_decay
# Define what the agent can observe
# Dict space gives us structured, human-readable observations
self.observation_space = gym.spaces.Dict(
{
"current_background": gym.spaces.Box(0, 0.9, shape=(2, 1), dtype=float),
"current_locations": gym.spaces.Box(0, 0.9, shape=(len(images), 2), dtype=float),
"current_rotations": gym.spaces.Box(0, 1, shape=(len(images), 1), dtype=float),
"target_background": gym.spaces.Box(0, 1, shape=(2, 1), dtype=float),
"target_locations": gym.spaces.Box(0, 0.9, shape=(len(images), 2), dtype=float),
"target_rotations": gym.spaces.Box(0, 1, shape=(len(images), 1), dtype=float),
}
)
        # Placeholder state arrays; actual values are sampled in reset()
        self._current_locations = np.empty(self.observation_space["current_locations"].shape, dtype=float)
        self._target_locations = np.empty(self.observation_space["target_locations"].shape, dtype=float)
        self._current_rotations = np.empty(self.observation_space["current_rotations"].shape, dtype=float)
        self._target_rotations = np.empty(self.observation_space["target_rotations"].shape, dtype=float)
        self._current_background = np.empty(self.observation_space["current_background"].shape, dtype=float)
        self._target_background = np.empty(self.observation_space["target_background"].shape, dtype=float)
        # Define the available actions: bounded deltas on background, locations, and rotations
self.action_space = gym.spaces.Dict(
{
"delta_background": gym.spaces.Box(low=-0.1, high=0.1, shape=(2, 1)),
"delta_locations": gym.spaces.Box(low=-0.1, high=0.1, shape=(len(images), 2)),
"delta_rotations": gym.spaces.Box(low=-0.1, high=0.1, shape=(len(images), 1)),
}
)
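        # Frame each image with a distinct color from the Matplotlib cycle so the
        # instances are easy to tell apart in the rendered scene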
self.images = [ImageOps.expand(img, border=5, fill=c).convert("RGBA") for img, c in zip(images, COLORS)]
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
"""
If human-rendering is used, `self.window` will be a reference
to the window that we draw to. `self.clock` will be a clock that is used
to ensure that the environment is rendered at the correct framerate in
human-mode. They will remain `None` until human-mode is used for the
first time.
"""
self.window = None
self.clock = None
def _get_obs(self):
"""Convert internal state to observation format.
Returns:
            dict: Current and target background offsets, image locations, and rotations
"""
return {
"current_background": self._current_background,
"current_locations": self._current_locations,
"current_rotations": self._current_rotations,
"target_background": self._target_background,
"target_locations": self._target_locations,
"target_rotations": self._target_rotations,
}
def _get_info(self):
"""Compute auxiliary information for debugging.
Returns:
dict: Info with distance between agent and target
"""
return {
"location_distance": np.linalg.norm(self._current_locations - self._target_locations, ord=1),
"rotation_distance": np.linalg.norm(self._current_rotations - self._target_rotations, ord=1),
"background_distance": np.linalg.norm(self._current_background - self._target_background, ord=1),
}
def reset(self, seed: int | None = None, options: dict | None = None):
"""Start a new episode.
Args:
seed: Random seed for reproducible episodes
options: Additional configuration (unused in this example)
Returns:
tuple: (observation, info) for the initial state
"""
# IMPORTANT: Must call this first to seed the random number generator
super().reset(seed=seed)
        # Randomly sample the current and target scene parameters.
        # The 0.9 scale keeps the samples inside the Box bounds declared above.
        self._current_background = 0.9 * self.np_random.random(size=(2, 1), dtype=float)
        self._current_locations = 0.9 * self.np_random.random(size=(len(self.images), 2), dtype=float)
        self._current_rotations = self.np_random.random(size=(len(self.images), 1), dtype=float)
        self._target_background = 0.9 * self.np_random.random(size=(2, 1), dtype=float)
        self._target_locations = 0.9 * self.np_random.random(size=(len(self.images), 2), dtype=float)
        self._target_rotations = self.np_random.random(size=(len(self.images), 1), dtype=float)
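        # Generate a 1/f ("pink") noise canvas at twice the render resolution; the
        # background offset later selects a resolution-sized crop from it (see _render_frame)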
        # Use the seeded generator so the background is reproducible for a given seed
        white_noise = self.np_random.standard_normal((self.resolution * 2, self.resolution * 2))
rows, cols = white_noise.shape
fft_white_noise = np.fft.fft2(white_noise)
# Create frequency coordinates
fy = np.fft.fftfreq(rows)
fx = np.fft.fftfreq(cols)
# Create 2D frequency grid
fx_grid, fy_grid = np.meshgrid(fx, fy)
# Calculate radial frequency magnitude
f_magnitude = np.sqrt(fx_grid**2 + fy_grid**2)
# Avoid division by zero at the DC component (f=0)
f_magnitude[0, 0] = 1 # Or a small epsilon to prevent singularity
        # Shape the amplitude spectrum with (1/f)^decay; an exponent of 0.5 yields a
        # 1/f power spectrum (pink noise), larger exponents give smoother backgrounds
        pink_filter = (1 / f_magnitude) ** self.background_power_decay
fft_pink_noise = fft_white_noise * pink_filter
pink_noise = np.fft.ifft2(fft_pink_noise).real
pink_noise -= pink_noise.min()
pink_noise /= pink_noise.max()
pink_noise = (pink_noise * 255).astype(np.uint8)
self.pink_noise = np.tile(np.expand_dims(pink_noise, 2), (1, 1, 3))
observation = self._get_obs()
info = self._get_info()
return observation, info
def step(self, action):
"""Execute one timestep within the environment.
Args:
            action: Dict of bounded deltas with keys "delta_background",
                "delta_locations", and "delta_rotations"
Returns:
tuple: (observation, reward, terminated, truncated, info)
"""
action["delta_background"] = np.clip(
action["delta_background"],
self.action_space["delta_background"].low,
self.action_space["delta_background"].high,
)
action["delta_locations"] = np.clip(
action["delta_locations"],
self.action_space["delta_locations"].low,
self.action_space["delta_locations"].high,
)
action["delta_rotations"] = np.clip(
action["delta_rotations"],
self.action_space["delta_rotations"].low,
self.action_space["delta_rotations"].high,
)
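        # Apply the clipped deltas and keep the new state inside the observation-space bounds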
self._current_background = np.clip(
self._current_background + action["delta_background"],
self.observation_space["current_background"].low,
self.observation_space["current_background"].high,
)
self._current_locations = np.clip(
self._current_locations + action["delta_locations"],
self.observation_space["current_locations"].low,
self.observation_space["current_locations"].high,
)
self._current_rotations = np.clip(
self._current_rotations + action["delta_rotations"],
self.observation_space["current_rotations"].low,
self.observation_space["current_rotations"].high,
)
observation = self._get_obs()
info = self._get_info()
# Check if agent reached the target
terminated = (
info["location_distance"] < 1e-2
and info["rotation_distance"] < 1e-2
and info["background_distance"] < 1e-2
)
# We don't use truncation in this simple environment
# (could add a step limit here if desired)
truncated = False
# Simple reward structure: +1 for reaching target, 0 otherwise
# Alternative: could give small negative rewards for each step to encourage efficiency
reward = 1 if terminated else 0
return observation, reward, terminated, truncated, info
def _get_optimal_action(self):
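        """Return the greedy action that moves each component straight toward its target.

        ``step`` clips the deltas to the action-space bounds, so several steps may be
        needed before the targets are reached.
        """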
rotations = self._current_rotations - self._target_rotations
locations = self._current_locations - self._target_locations
background = self._current_background - self._target_background
return {
"delta_background": -background,
"delta_locations": -locations,
"delta_rotations": -rotations,
}
def render(self, mode="current"):
if self.render_mode == "rgb_array":
return self._render_frame(mode=mode)
def _render_frame(self, mode):
        # The pygame window and clock are only needed for interactive human rendering;
        # rgb_array mode renders straight to an off-screen surface
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.resolution, self.resolution))
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()
        # Crop a resolution-sized window from the pink-noise canvas at the
        # (current or target) background offset
        if mode == "current":
            x = int(self.resolution * self._current_background[0, 0])
            y = int(self.resolution * self._current_background[1, 0])
        else:
            x = int(self.resolution * self._target_background[0, 0])
            y = int(self.resolution * self._target_background[1, 0])
        new_background = Image.fromarray(self.pink_noise[x : x + self.resolution, y : y + self.resolution])
        for i, img in enumerate(self.images):
            if mode == "current":
                locations, rotations = self._current_locations, self._current_rotations
            else:
                locations, rotations = self._target_locations, self._target_rotations
            # Box is (left, upper, right, lower); right - left must equal the image
            # width and lower - upper its height for a 4-tuple paste box
            box = (
                int(locations[i, 0] * self.resolution),
                int(locations[i, 1] * self.resolution),
                int(locations[i, 0] * self.resolution + img.width),
                int(locations[i, 1] * self.resolution + img.height),
            )
            # Use the rotated RGBA image as its own mask so the corners exposed by
            # the rotation stay transparent instead of pasting as black
            rotated = img.rotate(rotations[i, 0] * 360)
            new_background.paste(rotated, box, rotated)
# get the surface
# Get image data, size, and mode from PIL Image
image_bytes = new_background.tobytes()
image_size = new_background.size
image_mode = new_background.mode
# Create a Pygame Surface from the PIL image data
pygame_surface = pygame.image.frombytes(image_bytes, image_size, image_mode)
        if self.render_mode == "human":
            # Copy the rendered frame to the visible window and update the display
            self.window.blit(pygame_surface, (0, 0))
            pygame.event.pump()
            pygame.display.update()
            # Tick the clock so human rendering runs at the predefined framerate
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(np.array(pygame.surfarray.pixels3d(pygame_surface)), axes=(1, 0, 2))
def close(self):
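        """Close the pygame window and shut down pygame if human rendering was used."""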
if self.window is not None:
pygame.display.quit()
pygame.quit()
if __name__ == "__main__":
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from gymnasium.wrappers import RecordVideo
import stable_worldmodel as swm
# 1. Setup Environment
    # Create an ImagePositioning environment with "rgb_array" render mode to get image data
images = [
swm.utils.create_pil_image_from_url(
"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQK5OnlnP3_GHXI2y1LoIHbMROdN8_DYyLEGg&s"
).resize((64, 64)),
swm.utils.create_pil_image_from_url(
"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQjrFGrhOLwgYP0cdjTIBEWMpy9MHBcya4c5Q&s"
).resize((32, 32)),
]
env = gym.make(
"swm/ImagePositioning",
render_mode="rgb_array",
resolution=224,
images=images,
background_power_decay=2,
max_episode_steps=20,
    )
env = gym.wrappers.AddRenderObservation(env, render_only=False)
swm.collect.random_action(env, num_episodes=1)
env = RecordVideo(
env,
        video_folder="image-positioning-agent",  # Folder to save videos
name_prefix="eval", # Prefix for video filenames
episode_trigger=lambda x: True, # Record every episode
)
# 2. Reset the environment to get an initial observation
    observation, info = env.reset()
print(observation)
print(info)
# 3. Render the environment to get the image array
# The render method returns an RGB array when render_mode is "rgb_array"
# 4. Save the figure
# Use Matplotlib to display and save the image
fig, axs = plt.subplots(1, 2)
    rgb_array = env.unwrapped.render()
axs[0].imshow(rgb_array)
axs[0].set_xticks([])
axs[0].set_yticks([])
axs[0].set_title("Init.")
rgb_array = env.unwrapped.render(mode="target")
axs[1].imshow(rgb_array)
axs[1].set_xticks([])
axs[1].set_yticks([])
axs[1].set_title("Target")
plt.savefig("cartpole_observation.png")
plt.close() # Close the plot to free up memory
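    # Take a few greedy steps toward the target so the recorded video shows motion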
for i in range(5):
action = env.unwrapped._get_optimal_action()
env.step(action)
print("Saved CartPole observation as cartpole_observation.png")
# 5. Close the environment
env.close()