Source code for stable_datasets.images.e_mnist

import os

import datasets
import numpy as np
import scipy.io as sio


class EMNISTConfig(datasets.BuilderConfig):
    """Builder configuration that remembers which EMNIST variant to load.

    Args:
        variant: Variant name (e.g. ``"byclass"`` or ``"letters"``). The
            builder uses it to pick the ``emnist-<variant>.mat`` file.
        **kwargs: Passed straight through to ``datasets.BuilderConfig``
            (for example ``name``).
    """

    def __init__(self, variant, **kwargs):
        # Every variant shares one fixed dataset version.
        pinned_version = datasets.Version("1.0.0", "")
        super().__init__(version=pinned_version, **kwargs)
        self.variant = variant
class EMNIST(datasets.GeneratorBasedBuilder):
    """EMNIST handwritten characters, built from NIST's MATLAB (``.mat``) release.

    There is one builder config per EMNIST variant. Every variant exposes a
    ``train`` and a ``test`` split. Each example is a 28x28 image paired with
    an integer class label.
    """

    # Number of classes per variant. This table is also the single source of
    # truth for which variants (and therefore which builder configs) exist.
    _NUM_CLASSES = {
        "byclass": 62,
        "bymerge": 47,
        "balanced": 47,
        "letters": 26,
        "digits": 10,
        "mnist": 10,
    }

    # Amount subtracted from the raw .mat labels so they start at 0, which
    # ClassLabel requires. The "letters" variant stores labels 1..26 (a..z).
    # Without this shift, label 26 would exceed num_classes=26.
    _LABEL_OFFSET = {"letters": 1}

    BUILDER_CONFIGS = [EMNISTConfig(name=variant, variant=variant) for variant in _NUM_CLASSES]

    def _info(self):
        """Describe the features, supervision keys and provenance of the variant.

        Raises:
            ValueError: If the config names a variant missing from ``_NUM_CLASSES``.
        """
        variant = self.config.variant
        num_classes = self._NUM_CLASSES.get(variant)
        if num_classes is None:
            # Fail early with an actionable message rather than an
            # UnboundLocalError further down.
            raise ValueError(
                f"Unknown EMNIST variant {variant!r}; expected one of {sorted(self._NUM_CLASSES)}."
            )
        return datasets.DatasetInfo(
            description="EMNIST dataset",
            features=datasets.Features(
                {"image": datasets.Image(), "label": datasets.ClassLabel(num_classes=num_classes)}
            ),
            supervised_keys=("image", "label"),
            homepage="https://www.nist.gov/itl/iad/image-group/emnist-dataset",
            citation=(
                "@misc{cohen2017emnistextensionmnisthandwritten, "
                "title={EMNIST: an extension of MNIST to handwritten letters}, "
                "author={Gregory Cohen and Saeed Afshar and Jonathan Tapson and André van Schaik}, "
                "year={2017}, "
                "eprint={1702.05373}, "
                "archivePrefix={arXiv}, "
                "primaryClass={cs.CV}, "
                "url={https://arxiv.org/abs/1702.05373}, "
                "}"
            ),
        )

    def _split_generators(self, dl_manager):
        """Download the MATLAB archive and declare the train/test splits.

        Both splits read the same ``emnist-<variant>.mat`` file. They differ
        only in which struct field (``"train"`` or ``"test"``) is read.
        """
        variant = self.config.variant
        extracted_path = dl_manager.download_and_extract("https://biometrics.nist.gov/cs_links/EMNIST/matlab.zip")
        # This path assumes the archive unpacks into a top-level "matlab"
        # directory holding one emnist-<variant>.mat per variant.
        mat_path = os.path.join(extracted_path, "matlab", f"emnist-{variant}.mat")
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"mat_path": mat_path, "split": "train"}),
            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"mat_path": mat_path, "split": "test"}),
        ]

    def _generate_examples(self, mat_path, split):
        """Yield ``(index, example)`` pairs for one split of the ``.mat`` file.

        Args:
            mat_path: Path to the ``emnist-<variant>.mat`` file.
            split: Name of the struct field to read, ``"train"`` or ``"test"``.

        Raises:
            ValueError: If the image and label counts differ, or if a label
                falls outside ``[0, num_classes)`` after the offset is applied.
        """
        variant = self.config.variant
        data = sio.loadmat(mat_path)
        subset = data["dataset"][0, 0][split][0, 0]

        # Each row holds 784 pixels in MATLAB column-major order. A C-order
        # reshape would produce transposed glyphs, so read them with order="F".
        images = np.asarray(subset["images"], dtype=np.uint8).reshape(-1, 28, 28, order="F")
        labels = np.asarray(subset["labels"], dtype=np.int64).ravel()
        labels = labels - self._LABEL_OFFSET.get(variant, 0)

        if len(images) != len(labels):
            raise ValueError(
                f"EMNIST {variant}/{split}: {len(images)} images but {len(labels)} labels in {mat_path}."
            )
        num_classes = self._NUM_CLASSES[variant]
        if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
            raise ValueError(
                f"EMNIST {variant}/{split}: labels span [{labels.min()}, {labels.max()}], "
                f"outside the expected range [0, {num_classes})."
            )

        for idx, (img, lbl) in enumerate(zip(images, labels)):
            # An order="F" reshape can leave strided views; hand each image
            # over as a small C-contiguous array instead.
            yield idx, {"image": np.ascontiguousarray(img), "label": int(lbl)}