Source code for stable_datasets.images.med_mnist

import datasets
import numpy as np

from stable_datasets.utils import BaseDatasetBuilder


# Dataset version shared by every MedMNISTConfig and by the MedMNIST builder.
MEDMNIST_VERSION = datasets.Version("1.0.0")


class MedMNISTConfig(datasets.BuilderConfig):
    """Builder configuration describing a single MedMNIST variant.

    Stores the per-variant metadata that ``MedMNIST._info()`` reads when it
    assembles the feature schema.

    Args:
        num_classes: Number of label classes of the variant.
        is_3d: Whether the variant is one of the 3D datasets.
        multi_label: Whether the variant is multi-label instead of multi-class.
        **kwargs: Forwarded unchanged to ``datasets.BuilderConfig``
            (e.g. ``name`` and ``description``).
    """

    def __init__(
        self,
        *,
        num_classes: int,
        is_3d: bool = False,
        multi_label: bool = False,
        **kwargs,
    ):
        # Every variant is pinned to the module-wide dataset version.
        super().__init__(version=MEDMNIST_VERSION, **kwargs)

        # Variant metadata consumed later by the builder.
        self.multi_label = multi_label
        self.is_3d = is_3d
        self.num_classes = num_classes
class MedMNIST(BaseDatasetBuilder):
    """MedMNIST, a large-scale MNIST-like collection of standardized biomedical images,
    including 12 datasets for 2D and 6 datasets for 3D.

    Each variant is exposed as one ``MedMNISTConfig`` in ``BUILDER_CONFIGS``.
    A variant is distributed as a single ``.npz`` archive that holds every
    split; arrays are looked up under the keys ``"{split}_images"`` and
    ``"{split}_labels"``.
    """

    VERSION = MEDMNIST_VERSION

    BUILDER_CONFIGS = [
        # 2D Datasets
        MedMNISTConfig(name="pathmnist", description="MedMNIST PathMNIST (2D)", num_classes=9),
        MedMNISTConfig(
            name="chestmnist",
            description="MedMNIST ChestMNIST (2D, multi-label)",
            num_classes=14,
            multi_label=True,
        ),
        MedMNISTConfig(name="dermamnist", description="MedMNIST DermaMNIST (2D)", num_classes=7),
        MedMNISTConfig(name="octmnist", description="MedMNIST OCTMNIST (2D)", num_classes=4),
        MedMNISTConfig(name="pneumoniamnist", description="MedMNIST PneumoniaMNIST (2D)", num_classes=2),
        MedMNISTConfig(name="retinamnist", description="MedMNIST RetinaMNIST (2D)", num_classes=5),
        MedMNISTConfig(name="breastmnist", description="MedMNIST BreastMNIST (2D)", num_classes=2),
        MedMNISTConfig(name="bloodmnist", description="MedMNIST BloodMNIST (2D)", num_classes=8),
        MedMNISTConfig(name="tissuemnist", description="MedMNIST TissueMNIST (2D)", num_classes=8),
        MedMNISTConfig(name="organamnist", description="MedMNIST OrganAMNIST (2D)", num_classes=11),
        MedMNISTConfig(name="organcmnist", description="MedMNIST OrganCMNIST (2D)", num_classes=11),
        MedMNISTConfig(name="organsmnist", description="MedMNIST OrganSMNIST (2D)", num_classes=11),
        # 3D Datasets
        MedMNISTConfig(name="organmnist3d", description="MedMNIST OrganMNIST3D (3D)", num_classes=11, is_3d=True),
        MedMNISTConfig(name="nodulemnist3d", description="MedMNIST NoduleMNIST3D (3D)", num_classes=2, is_3d=True),
        MedMNISTConfig(name="adrenalmnist3d", description="MedMNIST AdrenalMNIST3D (3D)", num_classes=2, is_3d=True),
        MedMNISTConfig(name="fracturemnist3d", description="MedMNIST FractureMNIST3D (3D)", num_classes=3, is_3d=True),
        MedMNISTConfig(name="vesselmnist3d", description="MedMNIST VesselMNIST3D (3D)", num_classes=2, is_3d=True),
        MedMNISTConfig(name="synapsemnist3d", description="MedMNIST SynapseMNIST3D (3D)", num_classes=2, is_3d=True),
    ]

    def _source(self) -> dict:
        """Variant-aware source definition (computed from self.config at runtime).

        Returns:
            A dict with the keys ``"homepage"``, ``"assets"`` (split name ->
            download URL) and ``"citation"``.
        """
        variant = self.config.name
        url = f"https://zenodo.org/records/10519652/files/{variant}.npz?download=1"

        # Single NPZ contains all splits; we map each split name to the same URL.
        return {
            "homepage": "https://medmnist.com/",
            "assets": {"train": url, "test": url, "val": url},
            "citation": (
                "@article{medmnistv2, "
                "title={MedMNIST v2-A large-scale lightweight benchmark for 2D and 3D "
                "biomedical image classification}, "
                "author={Yang, Jiancheng and Shi, Rui and Wei, Donglai and Liu, Zequan "
                "and Zhao, Lin and Ke, Bilian and Pfister, Hanspeter and Ni, Bingbing}, "
                "journal={Scientific Data}, volume={10}, number={1}, pages={41}, "
                "year={2023}, publisher={Nature Publishing Group UK London} }"
            ),
        }

    def _info(self):
        """Build the ``datasets.DatasetInfo`` for the selected variant.

        The label feature is a sequence of ``int8`` values for multi-label
        variants and a ``ClassLabel`` otherwise. The image feature is a fixed
        ``(28, 28, 28)`` ``uint8`` array for 3D variants and a
        ``datasets.Image`` otherwise.
        """
        variant = self.config.name
        source = self._source()

        if getattr(self.config, "multi_label", False):
            # multi-label instead of multi-class
            label_feature = datasets.Sequence(datasets.Value("int8"))
        else:
            label_feature = datasets.ClassLabel(num_classes=self.config.num_classes)

        if getattr(self.config, "is_3d", False):
            image_feature = datasets.Array3D(shape=(28, 28, 28), dtype="uint8")
        else:
            image_feature = datasets.Image()

        return datasets.DatasetInfo(
            description=f"MedMNIST variant: {variant} dataset.",
            features=datasets.Features(
                {
                    "image": image_feature,
                    "label": label_feature,
                }
            ),
            supervised_keys=("image", "label"),
            homepage=source["homepage"],
            license="CC BY 4.0",
            citation=source["citation"],
        )

    def _generate_examples(self, data_path, split):
        """Yield ``(index, example)`` pairs for one split of the NPZ archive.

        Args:
            data_path: Path (or file object) handed to ``np.load``; must be an
                ``.npz`` archive because arrays are looked up by key.
            split: Split prefix used to build the array keys
                ``"{split}_images"`` and ``"{split}_labels"``.
                NOTE(review): presumably one of the ``assets`` keys from
                ``_source()`` ("train", "test", "val") - confirm against
                ``BaseDatasetBuilder``.

        Yields:
            ``(idx, {"image": image, "label": label})`` for every sample, in
            the order stored in the archive.
        """
        # Read both arrays while the archive is open, then close it right away
        # so the underlying file handle is not leaked while the generator is
        # being consumed (or if it is abandoned part-way through).
        with np.load(data_path) as archive:
            images = archive[f"{split}_images"]
            labels = archive[f"{split}_labels"]

        # Drop singleton label axes, e.g. (N, 1) -> (N,), but never the leading
        # sample axis: an unrestricted squeeze() turns a one-sample split into
        # a 0-d array (or, for multi-label data, into a flat vector of classes)
        # and breaks the per-sample iteration below.
        singleton_axes = tuple(
            axis for axis in range(1, labels.ndim) if labels.shape[axis] == 1
        )
        if singleton_axes:
            labels = labels.squeeze(axis=singleton_axes)

        for idx, (image, label) in enumerate(zip(images, labels)):
            yield idx, {"image": image, "label": label}