import os
from dataclasses import dataclass
import logging
import numpy as np
import hydra
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_pil_image
from .augmentations import TransformsConfig
@dataclass
class DatasetConfig:
    """Configuration for the data used for training the model.

    Parameters
    ----------
    dir : str, optional
        Path to the directory containing the training data.
        Default is "data".
    name : str, optional
        Name of the dataset to use (e.g., "CIFAR10", "CIFAR100").
        Default is "CIFAR10".
    split : str, optional
        Name of the dataset split to use (e.g., "train", "test").
        Default is "train".
    num_workers : int, optional
        Number of workers to use for data loading.
        Default is -1 (use all available CPUs).
    batch_size : int, optional
        Batch size for training. Default is 256.
    transforms : dict, optional
        Dictionary of transformations to apply to the data. Default is None.
    drop_last : bool, optional
        Whether to drop the last incomplete batch. Default is False.
    shuffle : bool, optional
        Whether to shuffle the data. Default is False.
    """

    dir: str = "data"
    name: str = "CIFAR10"
    split: str = "train"
    num_workers: int = -1
    batch_size: int = 256
    # Received as a {name: spec} mapping (or None); __post_init__ converts it
    # into a list of TransformsConfig objects.
    transforms: dict = None
    drop_last: bool = False
    shuffle: bool = False

    def __post_init__(self):
        """Normalize `transforms` into a list of TransformsConfig objects."""
        if self.transforms is None:
            # No transforms requested: fall back to a single no-op transform.
            self.transforms = [TransformsConfig("None")]
        else:
            self.transforms = [
                TransformsConfig(name, t) for name, t in self.transforms.items()
            ]

    @property
    def num_classes(self):
        """Return the number of classes in the dataset.

        Raises
        ------
        ValueError
            If the class count is unknown for `self.name` (previously this
            silently returned None).
        """
        if self.name == "CIFAR10":
            return 10
        elif self.name == "CIFAR100":
            return 100
        raise ValueError(f"Unknown number of classes for dataset {self.name}.")

    @property
    def resolution(self):
        """Return the resolution of the images in the dataset.

        Raises
        ------
        ValueError
            If the resolution is unknown for `self.name` (previously this
            silently returned None).
        """
        if self.name in ["CIFAR10", "CIFAR100"]:
            return 32
        raise ValueError(f"Unknown resolution for dataset {self.name}.")

    @property
    def data_path(self):
        """Return the on-disk path to the dataset, rooted at the hydra cwd."""
        return os.path.join(hydra.utils.get_original_cwd(), self.dir, self.name)

    def get_dataset(self):
        """Load a dataset from torchvision.datasets.

        Returns
        -------
        torch.utils.data.Dataset
            The dataset, with this config's transforms applied via `Sampler`.

        Raises
        ------
        ValueError
            If the dataset is not found in torchvision.datasets.
        """
        if not hasattr(torchvision.datasets, self.name):
            raise ValueError(f"Dataset {self.name} not found in torchvision.datasets.")
        torchvision_dataset = getattr(torchvision.datasets, self.name)
        return torchvision_dataset(
            root=self.data_path,
            train=self.split == "train",
            download=True,
            transform=Sampler(self.transforms),
        )

    def get_dataloader(self):
        """Return a DataLoader for the dataset.

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader object for the dataset.
        """
        dataset = self.get_dataset()
        # FIXME: handle distributed training (DistributedSampler and a
        # per-device batch size derived from world_size) once the hardware
        # config is wired through here.
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self._resolve_num_workers(),
            pin_memory=True,
            sampler=None,
            shuffle=self.shuffle,
            drop_last=self.drop_last,
        )
        return loader

    def _resolve_num_workers(self):
        """Resolve `num_workers`, expanding -1 to all available CPUs."""
        if self.num_workers != -1:
            return self.num_workers
        if os.environ.get("SLURM_JOB_ID"):
            # BUG FIX: environment values are strings; the old code handed the
            # raw string straight to DataLoader(num_workers=...). SLURM may
            # also report e.g. "8(x2)", so keep only the leading integer.
            raw = os.environ.get("SLURM_JOB_CPUS_PER_NODE", "1")
            num_workers = int(raw.split("(")[0])
        else:
            num_workers = os.cpu_count()
        logging.info(
            f"Using {num_workers} workers (maximum available) for data loading."
        )
        return num_workers
@dataclass
class DataConfig:
    """Configuration for multiple datasets used for training the model.

    Parameters
    ----------
    train_on : str
        The dataset to train on.
    datasets : dict[str, DatasetConfig]
        A dictionary of dataset configurations.
    """

    train_on: str
    # Maps dataset name -> DatasetConfig.
    datasets: dict

    def __init__(self, train_on, *args, **datasets):
        """Initialize DataConfig.

        Parameters
        ----------
        train_on : str
            The dataset to train on.
        datasets : dict
            A dictionary of dataset configurations, one keyword per dataset.

        Raises
        ------
        TypeError
            If extra positional arguments are given (previously an `assert`,
            which is stripped under ``python -O``).
        """
        if args:
            raise TypeError(
                f"DataConfig takes no extra positional arguments, got {args}."
            )
        self.train_on = train_on
        self.datasets = {name: DatasetConfig(**d) for name, d in datasets.items()}

    def get_datasets(self):
        """Get datasets for training and validation.

        Returns
        -------
        dict
            A dictionary mapping each dataset name to its dataset object.
        """
        return {name: d.get_dataset() for name, d in self.datasets.items()}

    def get_dataloaders(self):
        """Get dataloaders for the datasets.

        Returns
        -------
        dict
            A dictionary mapping each dataset name to its dataloader.
        """
        return {name: d.get_dataloader() for name, d in self.datasets.items()}
class Sampler:
    """Apply every configured transform to an input.

    With a single transform, its output is returned directly; with several,
    a list holding one view per transform is returned.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, x):
        views = [transform(x) for transform in self.transforms]
        return views[0] if len(views) == 1 else views
# def load_dataset(dataset_name, data_path, train=True):
# """
# Load a dataset from torchvision.datasets.
# Uses PositivePairSampler for training and ValSampler for validation.
# If coeff_imbalance is not None, create an imbalanced version of the dataset with
# the specified coefficient (exponential imbalance).
# """
# if not hasattr(torchvision.datasets, dataset_name):
# raise ValueError(f"Dataset {dataset_name} not found in torchvision.datasets.")
# torchvision_dataset = getattr(torchvision.datasets, dataset_name)
# if train:
# return torchvision_dataset(
# root=data_path,
# train=True,
# download=True,
# transform=Sampler(dataset=dataset_name),
# )
# return torchvision_dataset(
# root=data_path,
# train=False,
# download=True,
# transform=ValSampler(dataset=dataset_name),
# )
# def imbalance_torchvision_dataset(
# data_path, dataset, dataset_name, coeff_imbalance=2.0
# ):
# save_path = os.path.join(data_path, f"imbalanced_coeff_{coeff_imbalance}.pt")
# if not os.path.exists(save_path):
# data, labels = from_torchvision(data_path=data_path, dataset=dataset)
# imbalanced_data, imbalanced_labels = resample_classes(
# data, labels, coeff_imbalance=coeff_imbalance
# )
# imbalanced_dataset = {"features": imbalanced_data, "labels": imbalanced_labels}
# save_path = os.path.join(data_path, f"imbalanced_coeff_{coeff_imbalance}.pt")
# torch.save(imbalanced_dataset, save_path)
# print(f"[stable-SSL] Subsampling : imbalanced dataset saved to {save_path}.")
# return CustomTorchvisionDataset(
# root=save_path, transform=PositivePairSampler(dataset=dataset_name)
# )
def from_torchvision(data_path, dataset, transform=None):
    """Load dataset features and labels from torchvision.

    Parameters
    ----------
    data_path : str
        Path to the dataset.
    dataset : torch.utils.data.Dataset
        The dataset class from torchvision.
    transform : callable, optional
        Per-sample transform applied when loading. Defaults to
        ``transforms.ToTensor()`` (the previous hard-coded behavior).

    Returns
    -------
    tuple
        Tuple of stacked features and labels tensors.
    """
    if transform is None:
        transform = transforms.ToTensor()
    ds = dataset(root=data_path, train=True, download=True, transform=transform)
    # Index each sample exactly once: ds[i] decodes and transforms the image,
    # so the previous double indexing (once for the feature, once for the
    # label) did that work twice per sample.
    samples = [ds[i] for i in range(len(ds))]
    features = torch.stack([feature for feature, _ in samples])
    labels = torch.tensor([label for _, label in samples])
    return features, labels
def resample_classes(dataset, samples_or_freq, random_seed=None):
    """Create an exponential class imbalance.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The input dataset; must expose its labels via a `labels` or
        `targets` attribute.
    samples_or_freq : iterable
        Number of samples or frequency for each class in the new dataset.
        Frequencies (summing to <= 1) are rescaled so that the most
        constrained class uses all of its available samples.
    random_seed : int, optional
        The random seed for reproducibility. Default is None.

    Returns
    -------
    torch.utils.data.Subset
        Subset of the dataset with the resampled classes.

    Raises
    ------
    ValueError
        If the dataset does not have 'labels' or 'targets' attributes, if
        `samples_or_freq` contains negative values or sums to an
        unsupported total, or if more samples are requested per class than
        are available.
    """
    if hasattr(dataset, "labels"):
        labels = dataset.labels
    elif hasattr(dataset, "targets"):
        labels = dataset.targets
    else:
        raise ValueError("dataset does not have `labels`")

    classes, class_inverse, class_counts = np.unique(
        labels, return_counts=True, return_inverse=True
    )
    logging.info(f"Subsampling : original class counts: {list(class_counts)}")

    if np.min(samples_or_freq) < 0:
        raise ValueError(
            "You can't have negative values in `samples_or_freq`, "
            f"got {samples_or_freq}."
        )
    elif np.sum(samples_or_freq) <= 1:
        # Interpreted as per-class frequencies.
        target_class_counts = np.array(samples_or_freq) * len(dataset)
    elif np.sum(samples_or_freq) == len(dataset):
        # Interpreted as absolute per-class sample counts.
        freq = np.array(samples_or_freq) / np.sum(samples_or_freq)
        target_class_counts = freq * len(dataset)
        if (target_class_counts / class_counts).max() > 1:
            raise ValueError("specified more samples per class than available")
    else:
        raise ValueError(
            f"samples_or_freq needs to sum to <= 1 or len(dataset) ({len(dataset)}), "
            f"got {np.sum(samples_or_freq)}."
        )

    # Rescale so the most over-subscribed class exactly exhausts its samples.
    target_class_counts = (
        target_class_counts / (target_class_counts / class_counts).max()
    ).astype(int)
    logging.info(f"Subsampling : target class counts: {list(target_class_counts)}")

    keep_indices = []
    generator = np.random.Generator(np.random.PCG64(seed=random_seed))
    for class_idx, count in enumerate(target_class_counts):
        # BUG FIX: `class_inverse` holds indices into `classes` (0..k-1), not
        # the raw label values, so it must be compared against the class
        # *index*. The old comparison against the label value only worked when
        # labels happened to be exactly 0..k-1.
        cl_indices = np.flatnonzero(class_inverse == class_idx)
        cl_indices = generator.choice(cl_indices, size=count, replace=False)
        keep_indices.extend(cl_indices)

    return torch.utils.data.Subset(dataset, indices=keep_indices)
class CustomTorchvisionDataset(Dataset):
    """A custom dataset class for loading torchvision datasets.

    Parameters
    ----------
    root : str
        Path to the dataset (a `.pt` file).
    transform : callable, optional
        Transformation function to apply to the data. Default is None.
    """

    def __init__(self, root, transform=None):
        """Load the saved features/labels dictionary from `root`."""
        self.transform = transform
        # NOTE(review): torch.load unpickles arbitrary objects — only use
        # with trusted .pt files. The file stores {"features", "labels"}.
        saved = torch.load(root)
        self.features = saved["features"]
        self.labels = saved["labels"]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.features)

    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Parameters
        ----------
        idx : int
            Index of the sample to retrieve.

        Returns
        -------
        tuple
            The (possibly transformed) PIL image and its label.
        """
        image = to_pil_image(self.features[idx])
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]