Source code for stable_datasets.images.cassava
#!/usr/bin/env python
"""Legacy Cassava loader (to be refactored into a BaseDatasetBuilder).
This module was moved under `stable_datasets.images` to align the repository layout.
It still exposes the original imperative `cassava.load(...)` API for now.
"""
import io
import os
import time
import urllib
import zipfile
import matplotlib.image as mpimg
import numpy as np
__author__ = "Randall Balestriero"
[docs]
class cassava:
    """Plant images classification.

    The loader expects the archive members to be laid out as
    ``<root>/<split>/<class>/<image>.jpg`` where ``<split>`` is one of
    ``train``, ``validation`` or ``test`` and ``<class>`` is one of the names
    in ``classes``. An image's label is the index of its class folder in
    ``classes``.
    """

    # Order matters: the position of a folder name in this list is its label.
    classes = ["cbb", "cmd", "cbsd", "cgm", "healthy"]

    @staticmethod
    def download(path):
        """Download the cassava archive under ``path`` unless it already exists.

        The archive is stored as ``<path>/cassava/cassavaleafdata.zip``. The
        ``cassava`` directory (and any missing parents) is created on demand.

        Args:
            path: Directory that holds (or will hold) the ``cassava`` folder.
                A trailing path separator is optional.
        """
        # ``import urllib`` alone does not guarantee that the ``request``
        # submodule is loaded, so import it explicitly where it is used.
        import urllib.request

        folder = os.path.join(path, "cassava")
        # Check if directory exists
        if not os.path.isdir(folder):
            print("Creating cassava Directory")
            os.makedirs(folder, exist_ok=True)

        # Check if file exists
        archive = os.path.join(folder, "cassavaleafdata.zip")
        if not os.path.exists(archive):
            url = "https://storage.googleapis.com/emcassavadata/" + "cassavaleafdata.zip"
            # Download to a temporary name and rename afterwards so that an
            # interrupted transfer never leaves a truncated file that a later
            # call would mistake for a complete archive.
            partial = archive + ".part"
            try:
                urllib.request.urlretrieve(url, partial)
                os.replace(partial, archive)
            finally:
                # Only still present when the download or rename failed.
                if os.path.exists(partial):
                    os.remove(partial)

    @staticmethod
    def _pack_images(images):
        """Turn a list of decoded images into a single numpy array.

        When every image has the same shape the result is the regular stacked
        array. Otherwise a 1-D ``object`` array with one image per element is
        returned, because ``np.array`` raises ``ValueError`` on inhomogeneous
        input with NumPy >= 1.24.
        """
        if len({img.shape for img in images}) <= 1:
            return np.array(images)
        packed = np.empty(len(images), dtype=object)
        # Element-wise assignment avoids numpy trying to broadcast the images.
        for position, img in enumerate(images):
            packed[position] = img
        return packed

    @staticmethod
    def load(path=None):
        """Load the cassava dataset, downloading the archive first if needed.

        Args:
            path: Directory that holds the ``cassava`` folder. When ``None``,
                the ``DATASET_PATH`` environment variable is used.

        Returns:
            A 6-tuple ``(train_images, train_labels, valid_images,
            valid_labels, test_images, test_labels)`` of numpy arrays. Labels
            are indices into ``cassava.classes``.

        Raises:
            KeyError: If ``path`` is ``None`` and ``DATASET_PATH`` is not set,
                or if an image sits in a split folder other than ``train``,
                ``validation`` or ``test``.
            ValueError: If an image sits in a folder that is not one of
                ``cassava.classes``.
        """
        if path is None:
            path = os.environ["DATASET_PATH"]
        cassava.download(path)

        t0 = time.time()

        # Loading the file: for each split, [images, labels]
        data = {"train": [[], []], "test": [[], []], "validation": [[], []]}
        archive_path = os.path.join(path, "cassava", "cassavaleafdata.zip")
        with zipfile.ZipFile(archive_path) as archive:
            for filename in archive.namelist():
                if ".jpg" not in filename:
                    continue
                # Member paths look like <root>/<split>/<class>/<image>.jpg
                setname, foldername = filename.split("/")[1:3]
                img = mpimg.imread(io.BytesIO(archive.read(filename)), "jpg")
                data[setname][0].append(img)
                data[setname][1].append(cassava.classes.index(foldername))

        train_images = cassava._pack_images(data["train"][0])
        test_images = cassava._pack_images(data["test"][0])
        valid_images = cassava._pack_images(data["validation"][0])

        train_labels = np.array(data["train"][1])
        test_labels = np.array(data["test"][1])
        valid_labels = np.array(data["validation"][1])

        print(f"Dataset cassava loaded in {time.time() - t0:.2f}s.")

        return (
            train_images,
            train_labels,
            valid_images,
            valid_labels,
            test_images,
            test_labels,
        )