# Source code for stable_datasets.cassava

#!/usr/bin/env python
import io
import os
import time
import urllib
import zipfile

import matplotlib.image as mpimg
import numpy as np


# Module author metadata.
__author__ = "Randall Balestriero"


class cassava:
    """Plant images classification (Cassava leaf disease dataset).

    The data consists of two folders: a training folder that contains 5
    subfolders holding the images of the 5 classes, and a test folder
    containing test images. Participants are to train their models using
    the images in the training folder and provide a submission file like
    the sample provided, which contains the image name exactly matching the
    image name in the test folder and the corresponding class prediction,
    with labels corresponding to the disease categories: cmd, healthy, cgm,
    cbsd, cbb.

    ``load`` reads the downloaded archive and expects image entries laid out
    as ``<root>/<split>/<class>/<image>.jpg``, where ``<split>`` is one of
    ``train``, ``validation`` or ``test`` and ``<class>`` is one of
    :attr:`classes`.

    Please cite this paper if you use the dataset for your project:
    https://arxiv.org/pdf/1908.02900.pdf
    """

    # Label index of each class name: the integer label of an image is the
    # position of its class folder name in this list.
    classes = ["cbb", "cmd", "cbsd", "cgm", "healthy"]

    @staticmethod
    def download(path):
        """Download the cassava dataset archive into ``<path>/cassava``.

        Nothing is downloaded if ``<path>/cassava/cassavaleafdata.zip``
        already exists.

        Parameters
        ----------
        path: str
            the path where the downloaded files will be stored. If the
            directory (or any of its parents) does not exist, it is created.
        """
        # ``import urllib`` alone does not load the ``request`` submodule, so
        # ``urllib.request.urlretrieve`` could raise AttributeError without this.
        import urllib.request

        folder = os.path.join(path, "cassava")
        if not os.path.isdir(folder):
            print("Creating cassava Directory")
            os.makedirs(folder, exist_ok=True)

        archive = os.path.join(folder, "cassavaleafdata.zip")
        if os.path.exists(archive):
            return

        url = "https://storage.googleapis.com/emcassavadata/" + "cassavaleafdata.zip"
        # Download to a temporary name and rename only once complete, so an
        # interrupted transfer never leaves a truncated zip that later calls
        # would mistake for a finished download.
        partial = archive + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, archive)
        finally:
            if os.path.exists(partial):
                os.remove(partial)

    @staticmethod
    def load(path=None):
        """Load the cassava dataset, downloading it first if necessary.

        Parameters
        ----------
        path: str (optional)
            default ($DATASET_PATH), the path to look for the data and where
            the data will be downloaded if not present

        Returns
        -------
        train_images: array
        train_labels: array
        valid_images: array
        valid_labels: array
        test_images: array
        test_labels: array
            Images are decoded with ``matplotlib.image.imread`` and stacked
            with ``np.array`` in archive order. Labels are indices into
            :attr:`cassava.classes`.
        """
        if path is None:
            path = os.environ["DATASET_PATH"]
        cassava.download(path)

        t0 = time.time()

        # split name -> [images, labels]
        data = {"train": [[], []], "test": [[], []], "validation": [[], []]}
        archive = os.path.join(path, "cassava", "cassavaleafdata.zip")
        with zipfile.ZipFile(archive) as f:
            for filename in f.namelist():
                if not filename.endswith(".jpg"):
                    continue
                # entries look like "<root>/<split>/<class>/<image>.jpg"
                setname, foldername = filename.split("/")[1:3]
                img = mpimg.imread(io.BytesIO(f.read(filename)), "jpg")
                data[setname][0].append(img)
                data[setname][1].append(cassava.classes.index(foldername))

        train_images = np.array(data["train"][0])
        test_images = np.array(data["test"][0])
        valid_images = np.array(data["validation"][0])

        train_labels = np.array(data["train"][1])
        test_labels = np.array(data["test"][1])
        valid_labels = np.array(data["validation"][1])

        print(f"Dataset cassava loaded in {time.time() - t0:.2f}s.")

        return (
            train_images,
            train_labels,
            valid_images,
            valid_labels,
            test_images,
            test_labels,
        )