from contextlib import contextmanager
from itertools import islice
from random import getstate, setstate
from random import seed as rseed
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import PIL.Image
import torch
import torchvision
from PIL import ImageFilter
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.functional import InterpolationMode
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2._utils import query_chw
from PIL import Image
from stable_pretraining.data.masking import multi_block_mask
[docs]
@torch.jit.unused
def to_image(
    input: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
) -> tv_tensors.Image:
    """Convert a tensor, numpy array, or PIL image to a :class:`tv_tensors.Image`.

    See :class:`~torchvision.transforms.v2.ToImage` for details.

    Args:
        input: A ``torch.Tensor``, an HWC ``np.ndarray``, or a PIL image.

    Returns:
        tv_tensors.Image: Channels-first (C, H, W) image tensor.

    Raises:
        TypeError: If ``input`` is none of the supported types.
    """
    if isinstance(input, np.ndarray):
        # Bug fix: HWC -> CHW requires permute((2, 0, 1)). The previous
        # transpose(-3, -1) produced (C, W, H), i.e. a transposed image
        # (this matches torchvision's own to_image implementation).
        output = torch.from_numpy(np.atleast_3d(input)).permute((2, 0, 1)).contiguous()
    elif isinstance(input, PIL.Image.Image):
        output = torchvision.transforms.functional.pil_to_tensor(input)
    elif isinstance(input, torch.Tensor):
        output = input
    else:
        raise TypeError(
            f"Input can either be a pure Tensor, a numpy array, or a PIL image, but got {type(input)} instead."
        )
    return tv_tensors.Image(output)
[docs]
class ToImage(Transform):
    """Convert the value at ``source`` to a (optionally normalized) image tensor.

    Internally builds a small v2 pipeline: ``to_image`` followed by
    ``ToDtype(dtype, scale)`` and, when both ``mean`` and ``std`` are given,
    a trailing ``Normalize``. The converted image is written to ``target``.
    """

    def __init__(
        self,
        dtype=torch.float32,
        scale=True,
        mean=None,
        std=None,
        source: str = "image",
        target: str = "image",
    ):
        super().__init__()
        steps = [to_image, v2.ToDtype(dtype, scale=scale)]
        if mean is not None and std is not None:
            steps.append(v2.Normalize(mean=mean, std=std))
        self.t = v2.Compose(steps)
        self.source = source
        self.target = target

    def __call__(self, x):
        converted = self.t(self.nested_get(x, self.source))
        self.nested_set(x, converted, self.target)
        return x
[docs]
class RandomGrayscale(Transform, v2.RandomGrayscale):
    """Convert the ``source`` image to grayscale with probability ``p``.

    The (possibly converted) image is written to ``target``; a boolean under
    this transform's name key records whether the conversion ran. The channel
    count is preserved (an RGB input stays 3-channel).
    """

    def __init__(self, p=0.1, source: str = "image", target: str = "image"):
        super().__init__(p)
        self.source = source
        self.target = target

    def _get_params(self, inp: List[Any]) -> Dict[str, Any]:
        num_input_channels, *_ = query_chw([inp])
        return dict(num_input_channels=num_input_channels)

    def __call__(self, x) -> Any:
        if self.p < 1 and torch.rand(1) >= self.p:
            # Skipped: record the flag, then pass the image through unchanged.
            x[self.get_name(x)] = False
            self.nested_set(x, self.nested_get(x, self.source), self.target)
            return x
        image = self.nested_get(x, self.source)
        channels, *_ = query_chw([image])
        gray = F.rgb_to_grayscale(image, num_output_channels=channels)
        self.nested_set(x, gray, self.target)
        x[self.get_name(x)] = True
        return x
[docs]
class Lambda(Transform):
    """Apply a user-supplied callable to the sample dict and store the result.

    NOTE(review): ``lambd`` receives the *entire* sample dict ``x`` (not the
    value at ``source``), and its return value is written under ``target``.
    ``source`` is stored but never read here — confirm whether ``lambd`` was
    meant to receive ``x[source]`` instead.
    """
    def __init__(self, lambd, source: str = "image", target: str = "image"):
        super().__init__()
        # `source` is kept for signature symmetry with sibling transforms,
        # but __call__ below does not use it.
        self.source = source
        self.target = target
        self.lambd = lambd  # callable taking the whole sample dict
    def __call__(self, x) -> Any:
        # Store lambd(x) under `target`; the mutated dict is returned.
        self.nested_set(x, self.lambd(x), self.target)
        return x
[docs]
class RandomSolarize(Transform, v2.RandomSolarize):
    """Randomly solarize image by inverting pixel values above threshold.

    With probability ``p`` the image at ``source`` is solarized and written to
    ``target``; a boolean under this transform's name key records whether the
    operation was applied.
    """

    def __init__(self, threshold, p=0.5, source: str = "image", target: str = "image"):
        super().__init__(threshold, p)
        self.source = source
        self.target = target

    def __call__(self, x) -> Any:
        if self.p < 1 and torch.rand(1) >= self.p:
            x[self.get_name(x)] = False
            # Fix: pass the image through source -> target even when skipped
            # (matches RandomGrayscale/ColorJitter); previously `target` was
            # never populated when source != target.
            self.nested_set(x, self.nested_get(x, self.source), self.target)
            return x
        self.nested_set(
            x, F.solarize(self.nested_get(x, self.source), self.threshold), self.target
        )
        x[self.get_name(x)] = True
        return x
[docs]
class GaussianBlur(Transform, v2.GaussianBlur):
    """Apply Gaussian blur to the ``source`` image with random sigma values.

    With probability ``p`` the blur is applied and the sampled
    (sigma_x, sigma_y) pair is logged under the transform's name key;
    otherwise zeros are logged and the image passes through unchanged.
    """

    _NAMES = ["sigma_x", "sigma_y"]

    def __init__(
        self,
        kernel_size,
        sigma=(0.1, 2.0),
        p=1,
        source: str = "image",
        target: str = "image",
    ):
        super().__init__(kernel_size, sigma)
        self.p = p
        self.source = source
        self.target = target

    def __call__(self, x) -> Any:
        if self.p < 1 and torch.rand(1) >= self.p:
            x[self.get_name(x)] = torch.zeros((2,))
            # Fix: pass the image through source -> target even when skipped
            # (matches RandomGrayscale/ColorJitter); previously `target` was
            # never populated when source != target.
            self.nested_set(x, self.nested_get(x, self.source), self.target)
            return x
        params = self.make_params([])
        self.nested_set(
            x, self.transform(self.nested_get(x, self.source), params), self.target
        )
        x[self.get_name(x)] = torch.Tensor(params["sigma"])
        return x
[docs]
class PILGaussianBlur(Transform):
    """PIL-based Gaussian blur with a uniformly sampled radius."""

    # NOTE(review): two names are declared but only a single scalar sigma is
    # logged below — confirm whether sigma_y was ever intended to differ.
    _NAMES = ["sigma_x", "sigma_y"]

    def __init__(self, sigma=None, p=1, source: str = "image", target: str = "image"):
        """Gaussian blur as a callable object.

        Args:
            sigma (Sequence[float]): range to sample the radius of the gaussian
                blur filter. Defaults to [0.1, 2.0].
            p (float): probability of applying the transform.
            source (str): source key in the data dictionary.
            target (str): target key in the data dictionary.
        """
        super().__init__()  # fix: was missing, unlike every sibling transform
        if sigma is None:
            sigma = [0.1, 2.0]
        self.sigma = sigma
        self.p = p
        self.source = source
        self.target = target

    def __call__(self, x):
        """Apply gaussian blur to the image stored under ``source``.

        Args:
            x (dict): Data dictionary containing the image to transform.

        Returns:
            dict: Data dictionary with blurred image.
        """
        if self.p < 1 and torch.rand(1) >= self.p:
            x[self.get_name(x)] = torch.zeros((1,))
            # Fix: pass the image through source -> target even when skipped,
            # so `target` is populated when source != target.
            self.nested_set(x, self.nested_get(x, self.source), self.target)
            return x
        sigma = torch.rand((1,)) * (self.sigma[1] - self.sigma[0]) + self.sigma[0]
        x[self.get_name(x)] = sigma
        blurred = self.nested_get(x, self.source).filter(
            ImageFilter.GaussianBlur(radius=sigma.item())
        )
        self.nested_set(x, blurred, self.target)
        return x
[docs]
class RandomContiguousTemporalSampler(Transform):
    """Randomly sample a contiguous run of frames from a video reader.

    Reads a video reader object from ``source``, draws a random start frame,
    collects ``num_frames`` frames spaced ``frame_subsampling`` apart, stacks
    them into a (num_frames, ...) tensor at ``target``, and records the start
    frame index under the transform's name key.
    """

    def __init__(self, source, target, num_frames, frame_subsampling: int = 1):
        super().__init__()  # fix: was missing, unlike every sibling transform
        self.source = source
        self.target = target
        self.num_frames = num_frames
        self.frame_subsampling = frame_subsampling

    def __call__(self, x):
        reader = self.nested_get(x, self.source)
        metadata = reader.get_metadata()
        fps = metadata["video"]["fps"][0]
        # Total frame count estimated from duration * fps.
        T = int(metadata["video"]["duration"][0] * fps)
        covering = self.num_frames * self.frame_subsampling
        if T <= covering:
            # Fix: previously this surfaced as a cryptic torch.randint error.
            raise ValueError(
                f"Video too short: {T} frames available but more than {covering} "
                "(num_frames * frame_subsampling) are required."
            )
        start = torch.randint(low=0, high=T - covering, size=(1,)).item()
        video_frames = []  # video frame buffer
        # Seek to the start timestamp and keep every frame_subsampling-th frame.
        for count, frame in enumerate(islice(reader.seek(start / fps), covering)):
            if count % self.frame_subsampling == 0:
                video_frames.append(frame["data"])
        self.nested_set(x, torch.stack(video_frames, 0), self.target)
        x[self.get_name(x)] = start
        return x
[docs]
class RGB(Transform, v2.RGB):
    """Ensure the ``source`` image is 3-channel RGB and store it at ``target``."""

    def __init__(self, source: str = "image", target: str = "image"):
        super().__init__()
        self.source = source
        self.target = target

    def __call__(self, x):
        rgb = F.grayscale_to_rgb(self.nested_get(x, self.source))
        self.nested_set(x, rgb, self.target)
        return x
[docs]
class Resize(Transform, v2.Resize):
    """Resize the ``source`` image to ``size`` and write it to ``target``."""

    def __init__(
        self,
        size,
        interpolation=2,
        max_size=None,
        antialias=True,
        source="image",
        target="image",
    ) -> None:
        super().__init__(size, interpolation, max_size, antialias)
        self.source = source
        self.target = target

    def __call__(self, x):
        resized = self.transform(self.nested_get(x, self.source), [])
        self.nested_set(x, resized, self.target)
        return x
[docs]
class ColorJitter(Transform, v2.ColorJitter):
    """Randomly jitter brightness/contrast/saturation/hue with probability ``p``.

    On application, an 8-element tensor is logged under the transform's name
    key: the four factors followed by the 4-permutation giving the order in
    which the ops were applied. On skip, the image passes through untouched
    and zeros are logged.
    """

    def __init__(
        self,
        brightness=None,
        contrast=None,
        saturation=None,
        hue=None,
        p=1,
        source: str = "image",
        target: str = "image",
    ):
        super().__init__(brightness, contrast, saturation, hue)
        self.p = p
        self.source = source
        self.target = target

    def __call__(self, x) -> Any:
        if self.p < 1 and torch.rand(1) > self.p:
            self.nested_set(x, self.nested_get(x, self.source), self.target)
            x[self.get_name(x)] = torch.zeros(8)
            return x
        params = self.make_params([])
        jittered = self.transform(self.nested_get(x, self.source), params)
        self.nested_set(x, jittered, self.target)
        factors = [
            params["brightness_factor"],
            params["contrast_factor"],
            params["saturation_factor"],
            params["hue_factor"],
        ]
        x[self.get_name(x)] = torch.Tensor(factors + params["fn_idx"].tolist())
        return x
[docs]
class RandomRotation(Transform, v2.RandomRotation):
    """Rotate the ``source`` image by a random angle within ``degrees``.

    The sampled parameter dict returned by ``make_params`` is stored as-is
    under the transform's name key.
    """

    def __init__(
        self,
        degrees,
        interpolation=InterpolationMode.NEAREST,
        expand=False,
        center=None,
        fill=0,
        source: str = "image",
        target: str = "image",
    ):
        super().__init__(degrees, interpolation, expand, center, fill)
        self.source = source
        self.target = target

    def __call__(self, x):
        params = self.make_params([])
        rotated = self.transform(self.nested_get(x, self.source), params)
        self.nested_set(x, rotated, self.target)
        x[self.get_name(x)] = params
        return x
[docs]
class RandomChannelPermutation(Transform, v2.RandomChannelPermutation):
    """Shuffle the channels of the ``source`` image with a random permutation.

    The permutation tensor is stored under the transform's name key.
    """

    def __init__(self, source: str = "image", target: str = "image"):
        super().__init__()
        self.source = source
        self.target = target

    def __call__(self, x) -> Any:
        image = self.nested_get(x, self.source)
        num_channels, *_ = query_chw([image])
        perm = torch.randperm(num_channels)
        self.nested_set(x, F.permute_channels(image, perm), self.target)
        x[self.get_name(x)] = perm
        return x
[docs]
class RandomCrop(Transform, v2.RandomCrop):
    """Crop a random region of the ``source`` image (padding first if needed).

    The sampled crop/pad parameters are logged as a flat tensor under the
    transform's name key, in ``_NAMES`` order.
    """

    _NAMES = ["needs_crop", "top", "left", "height", "width", "needs_pad", "padding"]

    def __init__(
        self,
        size,
        padding=None,
        pad_if_needed=False,
        fill=0,
        padding_mode="constant",
        source: str = "image",
        target: str = "image",
    ):
        super().__init__(size, padding, pad_if_needed, fill, padding_mode)
        self.source = source
        self.target = target

    def __call__(self, x):
        image = self.nested_get(x, self.source)
        params = self.make_params([image])
        self.nested_set(x, self.transform(image, params), self.target)
        logged = [
            params["needs_crop"],
            params["top"],
            params["left"],
            params["height"],
            params["width"],
            params["needs_pad"],
            *params["padding"],
        ]
        x[self.get_name(x)] = torch.Tensor(logged)
        return x
[docs]
class RandomHorizontalFlip(Transform, v2.RandomHorizontalFlip):
    """Horizontally flip the ``source`` image(s) with probability ``p``.

    Accepts a single image or a list/tuple of images (all flipped together).
    A boolean under the transform's name key records whether the flip ran.
    """

    def __init__(self, p=0.5, source: str = "image", target: str = "image"):
        super().__init__(p)
        self.source = source
        self.target = target

    def __call__(self, x) -> Any:
        data = self.nested_get(x, self.source)
        # Note: when p == 0 no random number is consumed (short-circuit).
        flipped = bool(self.p > 0 and torch.rand(1) < self.p)
        if flipped:
            if type(data) in (tuple, list):
                data = [F.horizontal_flip(item) for item in data]
            else:
                data = F.horizontal_flip(data)
        self.nested_set(x, data, self.target)
        x[self.get_name(x)] = flipped
        return x
[docs]
class RandomResizedCrop(Transform, v2.RandomResizedCrop):
    """Crop a random portion of the ``source`` image(s) and resize to ``size``.

    A list/tuple input has the same sampled crop applied to every element.
    The sampled (top, left, height, width) are logged as a tensor under the
    transform's name key.
    """

    _NAMES = ["top", "left", "height", "width"]

    def __init__(
        self,
        size: Union[int, Sequence[int]],
        scale: Tuple[float, float] = (0.08, 1.0),
        ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
        interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
        antialias: Optional[bool] = True,
        source: str = "image",
        target: str = "image",
    ):
        super().__init__(size, scale, ratio, interpolation, antialias)
        self.source = source
        self.target = target

    def __call__(self, x):
        data = self.nested_get(x, self.source)
        params = self.make_params([data])
        if type(data) in (tuple, list):
            # Same crop parameters for every view.
            cropped = [self.transform(item, params) for item in data]
            self.nested_set(x, cropped, self.target)
        else:
            self.nested_set(x, self.transform(data, params), self.target)
        x[self.get_name(x)] = torch.Tensor(
            [params["top"], params["left"], params["height"], params["width"]]
        )
        return x
[docs]
class PatchMasking(Transform):
    """Randomly mask square patches of an image (MSE-style patch masking).

    Reads the image at ``source``, masks an *exact* fraction of its
    ``patch_size`` x ``patch_size`` patches, writes the masked image to
    ``target``, and stores a boolean patch mask (one entry per patch,
    True = kept, False = masked) under ``mask_key``.

    Note: the value written to ``target`` is always a float tensor in [0, 1];
    PIL inputs are converted and NOT converted back. Remainder pixels beyond
    the patch grid (right/bottom edges) are never masked.

    Args:
        patch_size (int): Side length in pixels of each square patch.
        drop_ratio (float): Exact fraction of patches to mask.
        source (str): Key in the input dictionary from which to read the image.
        target (str): Key under which the masked image is written.
        fill_value (float, optional): Value written into masked patches.
            Defaults to 0.0 when None.
        mask_key (str): Key under which the boolean patch mask is written.

    Raises:
        ValueError: If ``drop_ratio`` is outside [0, 1] or ``patch_size`` <= 0.
    """

    def __init__(
        self,
        patch_size: int = 16,
        drop_ratio: float = 0.5,
        source: str = "image",
        target: str = "image",
        fill_value: float = None,
        mask_key: str = "patch_mask",
    ):
        super().__init__()
        if not 0.0 <= drop_ratio <= 1.0:
            raise ValueError(f"drop_ratio must be in [0, 1], got {drop_ratio}")
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        self.mask_key = mask_key
        self.patch_size = patch_size
        self.drop_ratio = drop_ratio
        self.source = source
        self.target = target
        self.fill_value = fill_value

    def __call__(self, x):
        """Mask random patches of the image stored under ``source``."""
        img = self.nested_get(x, self.source)
        img_tensor = self._to_tensor(img)
        _, H, W = img_tensor.shape
        # Patch grid dimensions (remainder pixels are left unmasked).
        n_patches_h = H // self.patch_size
        n_patches_w = W // self.patch_size
        total_patches = n_patches_h * n_patches_w
        # Exact (not probabilistic) number of masked patches.
        n_masked = int(total_patches * self.drop_ratio)
        perm = torch.randperm(total_patches)
        mask_flat = torch.ones(total_patches, dtype=torch.bool)
        mask_flat[perm[:n_masked]] = False  # False = masked
        mask = mask_flat.reshape(n_patches_h, n_patches_w)
        # Fill value for masked pixels (0.0 unless overridden).
        fill = self.fill_value if self.fill_value is not None else 0.0
        fill = torch.tensor(fill, dtype=img_tensor.dtype, device=img_tensor.device)
        # Vectorized masking: upsample the patch mask to pixel resolution.
        # Pixels beyond the patch grid stay True (kept).
        full_mask = torch.ones(H, W, dtype=torch.bool, device=img_tensor.device)
        upsampled = mask.repeat_interleave(self.patch_size, dim=0).repeat_interleave(
            self.patch_size, dim=1
        )
        full_mask[: upsampled.shape[0], : upsampled.shape[1]] = upsampled
        self.nested_set(x, torch.where(full_mask, img_tensor, fill), self.target)
        self.nested_set(x, mask, self.mask_key)
        return x

    @staticmethod
    def _to_tensor(img):
        """Return ``img`` as a float CHW tensor in [0, 1]."""
        if isinstance(img, torch.Tensor):
            if img.dtype == torch.uint8:
                img = img.float() / 255.0
            if img.ndim == 3:
                return img
            if img.ndim == 2:
                return img.unsqueeze(0)
            # Fix: previously fell through and returned None for other ranks,
            # crashing later with a confusing unpack error.
            raise ValueError(f"Expected a 2D or 3D tensor, got {img.ndim}D")
        if isinstance(img, Image.Image):
            return F.pil_to_tensor(img).float() / 255.0
        raise TypeError("Unsupported image type")
[docs]
class CenterCrop(Transform, v2.CenterCrop):
    """Crop the central region of the ``source`` image to ``size``."""

    _NAMES = []

    def __init__(self, size, source: str = "image", target: str = "image"):
        super().__init__(size)
        self.source = source
        self.target = target

    def __call__(self, x):
        cropped = self.transform(self.nested_get(x, self.source), [])
        self.nested_set(x, cropped, self.target)
        return x
[docs]
def set_seed(seeds):
    """Seed or restore the python, numpy, torch (and optionally CUDA) RNGs.

    Each entry of ``seeds`` is either an integer seed or a previously
    captured state object (detected via ``__len__``):
      - seeds[0]: python ``random`` state tuple, or an int seed
      - seeds[1]: numpy state, or an int seed
      - seeds[2]: torch CPU state tensor, or an int seed
      - seeds[3]: optional; torch CUDA per-device states, or an int seed
    """
    py_state, np_state, torch_state = seeds[0], seeds[1], seeds[2]
    if hasattr(py_state, "__len__"):
        version, state, gauss = py_state
        # random.setstate requires the inner state to be a tuple.
        setstate((version, tuple(state), gauss))
    else:
        rseed(py_state)
    if hasattr(np_state, "__len__"):
        np.random.set_state(np_state)
    else:
        np.random.seed(np_state)
    if hasattr(torch_state, "__len__"):
        torch.set_rng_state(torch_state)
    else:
        torch.manual_seed(torch_state)
    if len(seeds) == 4:
        cuda_state = seeds[3]
        if hasattr(cuda_state, "__len__"):
            torch.cuda.set_rng_state_all(cuda_state)
        else:
            torch.cuda.manual_seed(cuda_state)
[docs]
@contextmanager
def random_seed(seed):
    """Context manager running its body under a fixed RNG seed.

    Saves the python, numpy, and torch CPU RNG states, seeds all three with
    ``int(seed)``, and restores the saved states on exit. Restoration now
    happens in a ``finally`` block, so the surrounding RNG streams are
    recovered even if the body raises (previously an exception left the
    seeded state in place).

    Args:
        seed: Anything convertible with ``int()``; used for all three RNGs.
    """
    saved_py = getstate()
    saved_np = np.random.get_state()
    saved_torch = torch.get_rng_state()
    # CUDA state handling is deliberately disabled (as in the original
    # `if False: torch.cuda.get_rng_state_all()` branch); re-enable with care.
    rseed(int(seed))
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))
    try:
        yield
    finally:
        setstate(saved_py)
        np.random.set_state(saved_np)
        torch.set_rng_state(saved_torch)
[docs]
class Conditional(Transform):
    """Apply a wrapped transform only when a boolean key in the sample allows it.

    The wrapped transform runs when the truthiness of ``x[condition_key]``
    equals ``apply_on_true``. When it does not run, the wrapped transform's
    name key is still populated with its ``BYPASS_VALUE`` so that downstream
    ``collate_fn`` sees a consistent set of keys.
    """

    def __init__(self, transform, condition_key, apply_on_true=True):
        super().__init__()
        self._transform = transform
        self.condition_key = condition_key
        self.apply_on_true = apply_on_true

    def __call__(self, x):
        should_apply = bool(x[self.condition_key]) == bool(self.apply_on_true)
        if should_apply:
            return self._transform(x)
        # Not applied: still record a placeholder for the collate_fn.
        x[self._transform.get_name(x)] = self._transform.BYPASS_VALUE
        return x
[docs]
class AdditiveGaussian(Transform):
    """Add zero-mean Gaussian noise (std ``sigma``) to ``x["image"]`` in place.

    A boolean under the transform's name key records whether noise was added.
    """

    BYPASS_VALUE = False

    def __init__(self, sigma, p=1):
        super().__init__()
        # Normalize sigma to a 0-d tensor so mul_ broadcasts cleanly.
        if not torch.is_tensor(sigma):
            sigma = torch.Tensor([sigma])[0]
        self.sigma = sigma
        self.p = p

    def __call__(self, x):
        # Note: when p == 0 no random number is consumed (short-circuit).
        if self.p == 0 or self.p < torch.rand(1):
            x[self.get_name(x)] = self.BYPASS_VALUE
            return x
        x[self.get_name(x)] = True
        noise = torch.randn_like(x["image"]).mul_(self.sigma)
        x["image"] = x["image"].add_(noise)
        return x
[docs]
class Compose(v2.Transform):
    """Chain several transforms, applying them to the sample left to right."""

    def __init__(self, *args):
        super().__init__()
        self.args = args

    def __call__(self, sample):
        for transform in self.args:
            sample = transform(sample)
        return sample
[docs]
class ContextTargetsMultiBlockMask(Transform):
    """Add multi-block masks to the sample: several target blocks plus one
    context block made disjoint from every target.

    Args:
        patch_size: Patch size in pixels; the mask grid is
            (H // patch_size, W // patch_size).
        context_scale: Scale range of the context block.
        context_aspect_ratio: Aspect-ratio range of the context block.
        target_scales: One scale range per target block.
        target_aspect_ratios: One aspect-ratio range per target block; must
            match ``target_scales`` in length.
        min_keep: Minimum number of patches that must be in a block.
        source: Key of the image used to infer (H, W).
        target_context: Output key for the context-block patch indices.
        target_targets: Output key for the list of target-block patch indices.
    """

    def __init__(
        self,
        patch_size=16,
        context_scale=(0.85, 1.0),
        context_aspect_ratio=(1.0, 1.0),
        target_scales=((0.15, 0.2),) * 4,
        target_aspect_ratios=((0.75, 1.5),) * 4,
        min_keep=10,
        source: str = "image",
        target_context: str = "mask_context",
        target_targets: str = "masks_target",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.context_scale = context_scale
        self.context_aspect_ratio = context_aspect_ratio
        self.target_scales = target_scales
        self.target_aspect_ratios = target_aspect_ratios
        self.source = source
        self.target_context = target_context
        self.target_targets = target_targets
        if len(target_scales) != len(target_aspect_ratios):
            # Fix: the detail fragment was a plain string passed as a second
            # ValueError argument, so {len(target_scales)=} was never
            # formatted; it is now an f-string in a single message.
            raise ValueError(
                "Each scale must have its associated aspect ratio and vice versa. "
                f"Received {len(target_scales)=} {len(target_aspect_ratios)=}"
            )
        self.min_keep = min_keep

    def __call__(self, x):
        source = self.nested_get(x, self.source)
        if isinstance(source, PIL.Image.Image):
            W, H = source.size  # PIL is W,H
        elif isinstance(source, torch.Tensor):
            # assumes trailing dims are H, W
            H, W = source.shape[-2:]
        else:
            raise ValueError(
                f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead."
            )
        scales = [self.context_scale, *self.target_scales]
        aspect_ratios = [self.context_aspect_ratio, *self.target_aspect_ratios]
        context_mask, *target_masks = multi_block_mask(
            H // self.patch_size,
            W // self.patch_size,
            block_scales=scales,
            aspect_ratios=aspect_ratios,
            min_keep=self.min_keep,
        )
        # Make the context disjoint from every target block.
        for mask in target_masks:
            context_mask &= ~mask
        x[self.target_context] = torch.nonzero(context_mask.flatten()).squeeze()
        x[self.target_targets] = [
            torch.nonzero(mask.flatten()).squeeze() for mask in target_masks
        ]
        x[self.get_name(x)] = torch.tensor([scales, aspect_ratios])
        return x
[docs]
class RandomMask(Transform):
    r"""Create a random MAE-style visible/masked split over image patches.

    A random permutation of all patch indices is drawn; the first
    ``len_keep = int(num_patches * (1 - mask_ratio))`` indices form the
    'visible' set and the remainder the 'masked' set. ``ids_restore`` is the
    inverse permutation, usable to un-shuffle a patch sequence back to its
    original 2D grid order. All outputs are added as new keys to the sample.

    Example:
        >>> # xdoctest: +SKIP
        >>> transform = RandomMask(patch_size=16, mask_ratio=0.75)
        >>> sample = {"image": torch.randn(3, 224, 224)}
        >>> result = transform(sample)
        >>> sorted(result.keys())
        ['image', 'ids_restore', 'len_keep', 'mask_masked', 'mask_visible']
        >>> result["len_keep"]
        49
        >>> result["mask_visible"].shape
        torch.Size([49])

    Args:
        patch_size (int): Height and width of each square patch.
        mask_ratio (float): Fraction of patches to be masked (e.g., 0.75).
        source (str): Key of the source image in the sample dict.
        target_visible (str): Output key for visible patch indices.
        target_masked (str): Output key for masked patch indices.
        target_ids_restore (str): Output key for the restoration indices.
        target_len_keep (str): Output key for the visible-patch count.
    """

    def __init__(
        self,
        patch_size=16,
        mask_ratio=0.75,
        source: str = "image",
        target_visible: str = "mask_visible",
        target_masked: str = "mask_masked",
        target_ids_restore: str = "ids_restore",
        target_len_keep: str = "len_keep",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.source = source
        self.target_visible = target_visible
        self.target_masked = target_masked
        self.target_ids_restore = target_ids_restore
        self.target_len_keep = target_len_keep

    def __call__(self, x):
        source = self.nested_get(x, self.source)
        if isinstance(source, PIL.Image.Image):
            W, H = source.size  # PIL reports (W, H)
        elif isinstance(source, torch.Tensor):
            # NOTE assumes trailing dims are H, W
            H, W = source.shape[-2:]
        else:
            raise ValueError(
                f"Source must be a PIL.Image.Image or a torch.Tensor, but got {type(source)} instead."
            )
        num_patches = (H // self.patch_size) * (W // self.patch_size)
        len_keep = int(num_patches * (1 - self.mask_ratio))
        # MAE-style shuffle: argsort of uniform noise is a random permutation.
        ids_shuffle = torch.argsort(torch.rand(num_patches))
        ids_restore = torch.argsort(ids_shuffle)  # inverse permutation
        # First len_keep shuffled indices are visible, the rest are masked.
        x[self.target_visible] = ids_shuffle[:len_keep]
        x[self.target_masked] = ids_shuffle[len_keep:]
        x[self.target_ids_restore] = ids_restore
        x[self.target_len_keep] = len_keep
        return x
# class RandomClassSwitch(v2.Transform):
# def __init__(
# self,
# label_key: str,
# new_key: str,
# p: float,
# low: int = -2147483648,
# high: int = 0,
# ):
# super().__init__()
# self.p = p
# self.label_key = label_key
# self.new_key = new_key
# self.low = low
# self.high = high
# def __call__(self, sample: dict):
# assert type(sample) is dict
# assert self.label_key in sample
# assert self.new_key not in sample
# if self.p > 0 and torch.rand(1) < self.p:
# if torch.is_tensor(sample[self.label_key]):
# sample[self.new_key] = torch.randint(
# low=self.low, high=self.high, size=()
# )
# else:
# sample[self.new_key] = np.random.randint(low=self.low, high=self.high)
# else:
# sample[self.new_key] = sample[self.label_key]
# return sample