Source code for stable_ssl.data.module

import copy
import inspect
from typing import Iterable, Optional, Union

import hydra
import lightning as pl
import torch
from loguru import logger as logging
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, Dataset

from .sampler import RepeatedRandomSampler
from .utils import HFDataset


class DictFormat(Dataset):
    """Format dataset with named columns for dictionary-style access.

    Args:
        dataset (Iterable): Dataset to be wrapped.
        names (Iterable): Column names for the dataset.
    """

    def __init__(self, dataset: Iterable, names: Iterable):
        self.dataset = dataset
        self.names = names

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        values = self.dataset[idx]
        sample = {k: v for k, v in zip(self.names, values)}
        return sample
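
# A minimal usage sketch of DictFormat (hypothetical toy data, for illustration only):
# wrapping an indexable dataset that yields tuples so samples come back as dictionaries.
#
#     pairs = [(torch.zeros(3), 0), (torch.ones(3), 1)]
#     wrapped = DictFormat(pairs, names=["image", "label"])
#     wrapped[0]  # -> {"image": tensor([0., 0., 0.]), "label": 0}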


class DataModule(pl.LightningDataModule):
    """PyTorch Lightning DataModule for handling train/val/test/predict dataloaders."""

    def __init__(
        self,
        train: Optional[Union[dict, DictConfig, DataLoader]] = None,
        test: Optional[Union[dict, DictConfig, DataLoader]] = None,
        val: Optional[Union[dict, DictConfig, DataLoader]] = None,
        predict: Optional[Union[dict, DictConfig, DataLoader]] = None,
        **kwargs,
    ):
        super().__init__()
        if train is None and test is None and val is None and predict is None:
            raise ValueError(
                "At least one of `train`, `test`, `val` or `predict` must be provided"
            )
        logging.info("Setting up DataModule")
        if train is None:
            logging.warning(
                "⚠️⚠️⚠️ train was not passed to DataModule, it is required "
                "unless you only validate ⚠️⚠️⚠️"
            )
        self.train = self._format_data_conf(train, "train")
        self.test = self._format_data_conf(test, "test")
        if val is None:
            logging.warning(
                "⚠️⚠️⚠️ val was not passed to DataModule, it is required "
                "unless you set `num_sanity_val_steps=0` and `val_check_interval=0` ⚠️⚠️⚠️"
            )
        self.val = self._format_data_conf(val, "val")
        self.predict = self._format_data_conf(predict, "predict")
        for k, v in kwargs.items():
            setattr(self, k, v)
        self._trainer = None

    @staticmethod
    def _format_data_conf(conf: Union[dict, DictConfig], stage: str):
        if conf is None:
            return None
        if isinstance(conf, DataLoader):
            return conf
        elif isinstance(conf["dataset"], HFDataset):
            logging.info(f"\t{stage} already has an instantiated dataset! ✅")
        elif type(conf) is dict:
            logging.info(f"\t{stage} has `dict` type and no instantiated dataset!")
            conf = OmegaConf.create(conf)
            logging.info(f"\t{stage} created `DictConfig`! ✅")
        elif type(conf) is not DictConfig:
            raise ValueError(f"`{conf}` must be a dict or DictConfig")
        sign = inspect.signature(DataLoader)
        # we check that the user gives the required parameters
        for k, param in sign.parameters.items():
            if param.default is param.empty and k not in conf:
                raise ValueError(f"conf must specify a value for {k}")
        # we check that the user doesn't give extra parameters
        for k in conf.keys():
            if k not in sign.parameters:
                raise ValueError(f"{k} given in conf is not a DataLoader kwarg")
        conf = copy.deepcopy(conf)
        logging.info(f"\t{stage} conf is valid and saved! ✅")
        return conf
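
    # A minimal construction sketch (hypothetical config values, for illustration
    # only): each split is a mapping of DataLoader keyword arguments whose
    # ``dataset`` entry is a Hydra-instantiable target; the target paths below are
    # placeholders, not real recipes.
    #
    #     data = DataModule(
    #         train={
    #             "dataset": {"_target_": "my_project.data.build_train_set"},
    #             "batch_size": 256,
    #             "num_workers": 8,
    #             "shuffle": True,
    #         },
    #         val={
    #             "dataset": {"_target_": "my_project.data.build_val_set"},
    #             "batch_size": 256,
    #         },
    #     )
    #
    # ``_format_data_conf`` accepts these dicts because ``dataset`` is present and
    # every key is a valid ``torch.utils.data.DataLoader`` keyword argument.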

    def setup(self, stage):
        # TODO: should we move some to prepare_data?
        if stage not in ["fit", "validate", "test", "predict"]:
            raise ValueError(f"Invalid stage {stage}")
        d = None
        if stage == "fit" and not isinstance(self.train, DataLoader):
            self.train_dataset = d = hydra.utils.instantiate(
                self.train.dataset, _convert_="object"
            )
            # only build the val dataset here if one was configured and is not
            # already an instantiated DataLoader
            if self.val is not None and not isinstance(self.val, DataLoader):
                self.val_dataset = hydra.utils.instantiate(
                    self.val.dataset, _convert_="object"
                )
        elif stage == "test" and not isinstance(self.test, DataLoader):
            self.test_dataset = d = hydra.utils.instantiate(
                self.test.dataset, _convert_="object"
            )
        elif stage == "validate" and not isinstance(self.val, DataLoader):
            self.val_dataset = d = hydra.utils.instantiate(
                self.val.dataset, _convert_="object"
            )
        elif stage == "predict" and not isinstance(self.predict, DataLoader):
            self.predict_dataset = d = hydra.utils.instantiate(
                self.predict.dataset, _convert_="object"
            )
        logging.info(f"dataset for {stage} loaded! ✅")
        if d is not None:
            logging.info(f"\t● length: {len(d)}")
            logging.info(f"\t● columns: {d.column_names}")
        else:
            logging.info("\t● setup was done by user")

    def _get_loader_kwargs(self, config, dataset):
        kwargs = dict()
        for k, v in config.items():
            if k == "dataset":
                continue
            if type(v) in [dict, DictConfig] and "_target_" in v:
                kwargs[k] = hydra.utils.instantiate(v, _convert_="object")
                if "_partial_" in v:
                    kwargs[k] = kwargs[k](dataset)
            else:
                kwargs[k] = v
        return kwargs
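
    # Sketch of how a ``_partial_`` entry above is resolved (hypothetical config,
    # for illustration only): a loader argument declared as a partial target is
    # instantiated by Hydra into a ``functools.partial``, which
    # ``_get_loader_kwargs`` then calls with the instantiated dataset.
    #
    #     "sampler": {
    #         "_target_": "stable_ssl.data.sampler.RepeatedRandomSampler",
    #         "_partial_": True,
    #         "n_views": 2,
    #     }
    #     # hydra.utils.instantiate(...) -> functools.partial(RepeatedRandomSampler, n_views=2)
    #     # kwargs["sampler"] = kwargs["sampler"](dataset)  # dataset fills the remaining argument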

    def train_dataloader(self):
        if isinstance(self.train, DataLoader):
            loader = self.train
        else:
            kwargs = self._get_loader_kwargs(self.train, self.train_dataset)
            loader = DataLoader(dataset=self.train_dataset, **kwargs)
        loader.dataset.set_pl_trainer(self._trainer)
        if (
            self.trainer is not None
            and self.trainer.world_size > 1
            and isinstance(loader.sampler, RepeatedRandomSampler)
        ):
            # re-create the RepeatedRandomSampler now that the distributed world
            # size is known, then rebuild the DataLoader around it
            sampler = RepeatedRandomSampler(
                loader.sampler._data_source_len,
                loader.sampler.n_views,
                loader.sampler.replacement,
                loader.sampler.seed,
            )
            loader = DataLoader(
                dataset=loader.dataset,
                sampler=sampler,
                batch_size=loader.batch_size,
                shuffle=False,
                num_workers=loader.num_workers,
                collate_fn=loader.collate_fn,
                pin_memory=loader.pin_memory,
                drop_last=loader.drop_last,
                timeout=loader.timeout,
                worker_init_fn=loader.worker_init_fn,
                prefetch_factor=loader.prefetch_factor,
                persistent_workers=loader.persistent_workers,
                pin_memory_device=loader.pin_memory_device,
                in_order=loader.in_order,
            )
        return loader

    def val_dataloader(self):
        if isinstance(self.val, DataLoader):
            self.val.dataset.set_pl_trainer(self._trainer)
            return self.val
        kwargs = self._get_loader_kwargs(self.val, self.val_dataset)
        self.val_dataset.set_pl_trainer(self._trainer)
        return DataLoader(dataset=self.val_dataset, **kwargs)

    def test_dataloader(self):
        if isinstance(self.test, DataLoader):
            self.test.dataset.set_pl_trainer(self._trainer)
            return self.test
        kwargs = self._get_loader_kwargs(self.test, self.test_dataset)
        self.test_dataset.set_pl_trainer(self._trainer)
        return DataLoader(dataset=self.test_dataset, **kwargs)

    def predict_dataloader(self):
        if isinstance(self.predict, DataLoader):
            self.predict.dataset.set_pl_trainer(self._trainer)
            return self.predict
        kwargs = self._get_loader_kwargs(self.predict, self.predict_dataset)
        self.predict_dataset.set_pl_trainer(self._trainer)
        return DataLoader(dataset=self.predict_dataset, **kwargs)

    def teardown(self, stage: str):
        pass

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        return

    def set_pl_trainer(self, trainer: pl.Trainer):
        self._trainer = trainer