From a11d1e72997a00a9bcf7d370774729d68630de70 Mon Sep 17 00:00:00 2001 From: Daniel Cortez Stevenson Date: Fri, 8 Feb 2019 16:08:18 +0100 Subject: [PATCH] Replaced functional API in base.py with OO API Why? 1. Conform more with the estimator OO API 2. Standardize how datasets are returned 3. Make adding new datasets easy! 4. Make adding functionality to datasets easier - See issues #10972 and #11818 5. Integrate with other modules/packages well - create a proper `Dataset` class with methods Note: I do not think this implementation is complete or perfect, but I feel it's time to reach out to other developers about this project Work Done: 1. Introduced 3 abstract classes - `DatasetLoader` - single entrypoint method: `load` - abstract method: `_raw_data_to_bunch` - `LocalDatasetLoader` - abstract methods: `read_data` and `process` - `SimpleCSVLocalDatasetLoader` 2. Deprecated `load_*` functions for datasets - replaced by `load` method of `DatasetLoader` - new API is demonstrated in their docstrings 3. Made `load_data` more general and refactored - Now the `read_X_y_csv` static method of `SimpleCSVLocalDatasetLoader` 4. Added tests for new API 5. API relies on passing `Bunch` between methods 6. Bunch objects have standardized attributes - there was significant variation before 7. Typing for a sklearn Dataset - `Union[Bunch, Tuple[Any, Any]]` TODO: 1. Extend `DatasetLoader` for Remote classes - `RemoteDatasetLoader`, for example 2. Add `as_frame` param to `DatasetLoader.load` - to address #10972 and #11818 3. Improve OO design (this is a first go) 4.
Even more tests --- sklearn/datasets/__init__.py | 20 +- sklearn/datasets/base.py | 560 ++++++++++++++++------------ sklearn/datasets/tests/test_base.py | 105 ++++++ 3 files changed, 453 insertions(+), 232 deletions(-) diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py index 77dac99c1d970..b9436b232b7ed 100644 --- a/sklearn/datasets/__init__.py +++ b/sklearn/datasets/__init__.py @@ -3,6 +3,14 @@ including methods to load and fetch popular reference datasets. It also features some artificial data generators. """ +from .base import Boston +from .base import BreastCancer +from .base import Diabetes +from .base import Digits +from .base import Iris +from .base import Linnerud +from .base import Wine +from .base import SampleImages from .base import load_breast_cancer from .base import load_boston from .base import load_diabetes @@ -52,7 +60,11 @@ from .rcv1 import fetch_rcv1 -__all__ = ['clear_data_home', +__all__ = ['Boston', + 'BreastCancer', + 'clear_data_home', + 'Diabetes', + 'Digits', 'dump_svmlight_file', 'fetch_20newsgroups', 'fetch_20newsgroups_vectorized', @@ -67,6 +79,8 @@ 'fetch_kddcup99', 'fetch_openml', 'get_data_home', + 'Iris', + 'Linnerud', 'load_boston', 'load_diabetes', 'load_digits', @@ -99,4 +113,6 @@ 'make_sparse_uncorrelated', 'make_spd_matrix', 'make_swiss_roll', - 'mldata_filename'] + 'mldata_filename', + 'SampleImages', + 'Wine'] diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py index 34214bf3f58de..5448cfae8ee20 100644 --- a/sklearn/datasets/base.py +++ b/sklearn/datasets/base.py @@ -6,8 +6,6 @@ # 2010 Fabian Pedregosa # 2010 Olivier Grisel # License: BSD 3 clause -import os -import csv import sys import shutil from collections import namedtuple @@ -19,11 +17,17 @@ from ..utils import check_random_state import numpy as np +from numpy import ndarray from urllib.request import urlretrieve +from typing import Dict, Tuple, Union, Any +from abc import ABCMeta, abstractmethod +from ..utils import deprecated + 
RemoteFileMetadata = namedtuple('RemoteFileMetadata', ['filename', 'url', 'checksum']) +Dataset = Union[Bunch, Tuple[Any, Any]] def get_data_home(data_home=None): @@ -204,17 +208,16 @@ def load_files(container_path, description=None, categories=None, DESCR=description) -def load_data(module_path, data_file_name): +@deprecated("'load_data' was renamed to" + "'SimpleCSVLocalDatasetLoader.read_X_y_csv'" + "in version 0.21 and will be removed in 0.23.") +def load_data(path: str) -> Tuple[ndarray, ndarray, ndarray]: """Loads data from module_path/data/data_file_name. Parameters ---------- - module_path : string - The module path. - - data_file_name : string - Name of csv file to be loaded from - module_path/data/data_file_name. For example 'wine_data.csv'. + path : string + The data file path. Returns ------- @@ -230,22 +233,18 @@ def load_data(module_path, data_file_name): A 1D array containing the names of the classifications. For example target_names[0] is the name of the target[0] class. """ - with open(join(module_path, 'data', data_file_name)) as csv_file: - data_file = csv.reader(csv_file) - temp = next(data_file) - n_samples = int(temp[0]) - n_features = int(temp[1]) - target_names = np.array(temp[2:]) - data = np.empty((n_samples, n_features)) - target = np.empty((n_samples,), dtype=np.int) + return SimpleCSVLocalDatasetLoader.read_X_y_csv(path=path) - for i, ir in enumerate(data_file): - data[i] = np.asarray(ir[:-1], dtype=np.float64) - target[i] = np.asarray(ir[-1], dtype=np.int) - return data, target, target_names +def _attempt_cast_to_int(arr: ndarray) -> ndarray: + arri = arr.astype('int', casting='unsafe') + if (arr == arri).all(): + return arri + return arr +@deprecated("'load_wine' was renamed to 'Wine().load'" + "in version 0.21 and will be removed in 0.23.") def load_wine(return_X_y=False): """Load and return the wine dataset (classification). 
@@ -289,40 +288,18 @@ def load_wine(return_X_y=False): Let's say you are interested in the samples 10, 80, and 140, and want to know their class name. - >>> from sklearn.datasets import load_wine - >>> data = load_wine() + >>> from sklearn.datasets import Wine + >>> data = Wine().load() >>> data.target[[10, 80, 140]] array([0, 1, 2]) >>> list(data.target_names) ['class_0', 'class_1', 'class_2'] """ - module_path = dirname(__file__) - data, target, target_names = load_data(module_path, 'wine_data.csv') - - with open(join(module_path, 'descr', 'wine_data.rst')) as rst_file: - fdescr = rst_file.read() - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, - target_names=target_names, - DESCR=fdescr, - feature_names=['alcohol', - 'malic_acid', - 'ash', - 'alcalinity_of_ash', - 'magnesium', - 'total_phenols', - 'flavanoids', - 'nonflavanoid_phenols', - 'proanthocyanins', - 'color_intensity', - 'hue', - 'od280/od315_of_diluted_wines', - 'proline']) + return Wine().load(return_X_y=return_X_y) +@deprecated("'load_iris' was renamed to 'Iris().load'" + "in version 0.21 and will be removed in 0.23.") def load_iris(return_X_y=False): """Load and return the iris dataset (classification). @@ -373,31 +350,18 @@ def load_iris(return_X_y=False): Let's say you are interested in the samples 10, 25, and 50, and want to know their class name. 
- >>> from sklearn.datasets import load_iris - >>> data = load_iris() + >>> from sklearn.datasets import Iris + >>> data = Iris().load() >>> data.target[[10, 25, 50]] array([0, 0, 1]) >>> list(data.target_names) ['setosa', 'versicolor', 'virginica'] """ - module_path = dirname(__file__) - data, target, target_names = load_data(module_path, 'iris.csv') - iris_csv_filename = join(module_path, 'data', 'iris.csv') - - with open(join(module_path, 'descr', 'iris.rst')) as rst_file: - fdescr = rst_file.read() - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, - target_names=target_names, - DESCR=fdescr, - feature_names=['sepal length (cm)', 'sepal width (cm)', - 'petal length (cm)', 'petal width (cm)'], - filename=iris_csv_filename) + return Iris().load(return_X_y=return_X_y) +@deprecated("'load_breast_cancer' was renamed to 'BreastCancer().load'" + "in version 0.21 and will be removed in 0.23.") def load_breast_cancer(return_X_y=False): """Load and return the breast cancer wisconsin dataset (classification). @@ -445,46 +409,18 @@ def load_breast_cancer(return_X_y=False): Let's say you are interested in the samples 10, 50, and 85, and want to know their class name. 
- >>> from sklearn.datasets import load_breast_cancer - >>> data = load_breast_cancer() + >>> from sklearn.datasets import BreastCancer + >>> data = BreastCancer().load() >>> data.target[[10, 50, 85]] array([0, 1, 0]) >>> list(data.target_names) ['malignant', 'benign'] """ - module_path = dirname(__file__) - data, target, target_names = load_data(module_path, 'breast_cancer.csv') - csv_filename = join(module_path, 'data', 'breast_cancer.csv') - - with open(join(module_path, 'descr', 'breast_cancer.rst')) as rst_file: - fdescr = rst_file.read() - - feature_names = np.array(['mean radius', 'mean texture', - 'mean perimeter', 'mean area', - 'mean smoothness', 'mean compactness', - 'mean concavity', 'mean concave points', - 'mean symmetry', 'mean fractal dimension', - 'radius error', 'texture error', - 'perimeter error', 'area error', - 'smoothness error', 'compactness error', - 'concavity error', 'concave points error', - 'symmetry error', 'fractal dimension error', - 'worst radius', 'worst texture', - 'worst perimeter', 'worst area', - 'worst smoothness', 'worst compactness', - 'worst concavity', 'worst concave points', - 'worst symmetry', 'worst fractal dimension']) - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, - target_names=target_names, - DESCR=fdescr, - feature_names=feature_names, - filename=csv_filename) + return BreastCancer().load(return_X_y=return_X_y) +@deprecated("'load_digits' was renamed to 'Digits().load'" + "in version 0.21 and will be removed in 0.23.") def load_digits(n_class=10, return_X_y=False): """Load and return the digits dataset (classification). 
@@ -531,8 +467,8 @@ def load_digits(n_class=10, return_X_y=False): -------- To load the data and visualize the images:: - >>> from sklearn.datasets import load_digits - >>> digits = load_digits() + >>> from sklearn.datasets import Digits + >>> digits = Digits().load() >>> print(digits.data.shape) (1797, 64) >>> import matplotlib.pyplot as plt #doctest: +SKIP @@ -540,31 +476,11 @@ def load_digits(n_class=10, return_X_y=False): >>> plt.matshow(digits.images[0]) #doctest: +SKIP >>> plt.show() #doctest: +SKIP """ - module_path = dirname(__file__) - data = np.loadtxt(join(module_path, 'data', 'digits.csv.gz'), - delimiter=',') - with open(join(module_path, 'descr', 'digits.rst')) as f: - descr = f.read() - target = data[:, -1].astype(np.int) - flat_data = data[:, :-1] - images = flat_data.view() - images.shape = (-1, 8, 8) - - if n_class < 10: - idx = target < n_class - flat_data, target = flat_data[idx], target[idx] - images = images[idx] - - if return_X_y: - return flat_data, target - - return Bunch(data=flat_data, - target=target, - target_names=np.arange(10), - images=images, - DESCR=descr) + return Digits(n_class=n_class).load(return_X_y=return_X_y) +@deprecated("'load_diabetes' was renamed to 'Diabetes().load'" + "in version 0.21 and will be removed in 0.23.") def load_diabetes(return_X_y=False): """Load and return the diabetes dataset (regression). @@ -598,26 +514,11 @@ def load_diabetes(return_X_y=False): .. 
versionadded:: 0.18 """ - module_path = dirname(__file__) - base_dir = join(module_path, 'data') - data_filename = join(base_dir, 'diabetes_data.csv.gz') - data = np.loadtxt(data_filename) - target_filename = join(base_dir, 'diabetes_target.csv.gz') - target = np.loadtxt(target_filename) - - with open(join(module_path, 'descr', 'diabetes.rst')) as rst_file: - fdescr = rst_file.read() - - if return_X_y: - return data, target - - return Bunch(data=data, target=target, DESCR=fdescr, - feature_names=['age', 'sex', 'bmi', 'bp', - 's1', 's2', 's3', 's4', 's5', 's6'], - data_filename=data_filename, - target_filename=target_filename) + return Diabetes().load(return_X_y=return_X_y) +@deprecated("'load_linnerud' was renamed to 'Linnerud().load'" + "in version 0.21 and will be removed in 0.23.") def load_linnerud(return_X_y=False): """Load and return the linnerud dataset (multivariate regression). @@ -654,34 +555,11 @@ def load_linnerud(return_X_y=False): .. versionadded:: 0.18 """ - base_dir = join(dirname(__file__), 'data/') - data_filename = join(base_dir, 'linnerud_exercise.csv') - target_filename = join(base_dir, 'linnerud_physiological.csv') - - # Read data - data_exercise = np.loadtxt(data_filename, skiprows=1) - data_physiological = np.loadtxt(target_filename, skiprows=1) - - # Read header - with open(data_filename) as f: - header_exercise = f.readline().split() - with open(target_filename) as f: - header_physiological = f.readline().split() - - with open(dirname(__file__) + '/descr/linnerud.rst') as f: - descr = f.read() - - if return_X_y: - return data_exercise, data_physiological - - return Bunch(data=data_exercise, feature_names=header_exercise, - target=data_physiological, - target_names=header_physiological, - DESCR=descr, - data_filename=data_filename, - target_filename=target_filename) + return Linnerud().load(return_X_y=return_X_y) +@deprecated("'load_boston' was renamed to 'Boston().load'" + "in version 0.21 and will be removed in 0.23.") def 
load_boston(return_X_y=False): """Load and return the boston house-prices dataset (regression). @@ -722,43 +600,16 @@ def load_boston(return_X_y=False): Examples -------- - >>> from sklearn.datasets import load_boston - >>> boston = load_boston() + >>> from sklearn.datasets import Boston + >>> boston = Boston().load() >>> print(boston.data.shape) (506, 13) """ - module_path = dirname(__file__) - - fdescr_name = join(module_path, 'descr', 'boston_house_prices.rst') - with open(fdescr_name) as f: - descr_text = f.read() - - data_file_name = join(module_path, 'data', 'boston_house_prices.csv') - with open(data_file_name) as f: - data_file = csv.reader(f) - temp = next(data_file) - n_samples = int(temp[0]) - n_features = int(temp[1]) - data = np.empty((n_samples, n_features)) - target = np.empty((n_samples,)) - temp = next(data_file) # names of features - feature_names = np.array(temp) - - for i, d in enumerate(data_file): - data[i] = np.asarray(d[:-1], dtype=np.float64) - target[i] = np.asarray(d[-1], dtype=np.float64) - - if return_X_y: - return data, target - - return Bunch(data=data, - target=target, - # last column is target value - feature_names=feature_names[:-1], - DESCR=descr_text, - filename=data_file_name) + return Boston().load(return_X_y=return_X_y) +@deprecated("'load_sample_images' was renamed to 'SampleImages().load'" + "in version 0.21 and will be removed in 0.23.") def load_sample_images(): """Load sample images for image manipulation. 
@@ -777,8 +628,8 @@ def load_sample_images(): -------- To load the data and visualize the images: - >>> from sklearn.datasets import load_sample_images - >>> dataset = load_sample_images() #doctest: +SKIP + >>> from sklearn.datasets import SampleImages + >>> dataset = SampleImages().load() #doctest: +SKIP >>> len(dataset.images) #doctest: +SKIP 2 >>> first_img_data = dataset.images[0] #doctest: +SKIP @@ -787,23 +638,12 @@ def load_sample_images(): >>> first_img_data.dtype #doctest: +SKIP dtype('uint8') """ - # import PIL only when needed - from ..externals._pilutil import imread - - module_path = join(dirname(__file__), "images") - with open(join(module_path, 'README.txt')) as f: - descr = f.read() - filenames = [join(module_path, filename) - for filename in os.listdir(module_path) - if filename.endswith(".jpg")] - # Load image data for each image in the source folder. - images = [imread(filename) for filename in filenames] - - return Bunch(images=images, - filenames=filenames, - DESCR=descr) + return SampleImages().load() +@deprecated("'load_sample_image' was refactored" + "in version 0.21 and will be removed in 0.23." 
+ "Use 'SampleImages().load' instead") def load_sample_image(image_name): """Load the numpy array of a single sample image @@ -821,28 +661,288 @@ def load_sample_image(image_name): Examples --------- - - >>> from sklearn.datasets import load_sample_image - >>> china = load_sample_image('china.jpg') # doctest: +SKIP + >>> from sklearn.datasets import SampleImages + >>> china = SampleImages('china.jpg').load().images[0] # doctest: +SKIP >>> china.dtype # doctest: +SKIP dtype('uint8') >>> china.shape # doctest: +SKIP (427, 640, 3) - >>> flower = load_sample_image('flower.jpg') # doctest: +SKIP + >>> flower = SampleImages('flower.jpg').load().images[0] # doctest: +SKIP >>> flower.dtype # doctest: +SKIP dtype('uint8') >>> flower.shape # doctest: +SKIP (427, 640, 3) """ - images = load_sample_images() - index = None - for i, filename in enumerate(images.filenames): - if filename.endswith(image_name): - index = i - break - if index is None: - raise AttributeError("Cannot find sample image: %s" % image_name) - return images.images[index] + return SampleImages(image_name).load().images[0] + + +class DatasetLoader(object, metaclass=ABCMeta): + """Abstract class for all dataset loaders in scikit-learn.""" + + def load(self, return_X_y=False) -> Dataset: + bunch = self._raw_data_to_bunch() + if return_X_y: + return bunch.data, bunch.target + return bunch + + @abstractmethod + def _raw_data_to_bunch(self) -> Bunch: + raise NotImplementedError + + +class LocalDatasetLoader(DatasetLoader): + _module_path = dirname(__file__) + _data_dir = join(_module_path, 'data') + _descr_dir = join(_module_path, 'descr') + _images_dir = join(_module_path, 'images') + + @property + def X_file(self) -> Union[str, ndarray]: + raise NotImplementedError + + @property + def y_file(self) -> str: + return self.X_file + + @property + def descr_file(self): + return self.X_file.split('.', maxsplit=1)[0] + '.rst' + + @property + def local_data_paths(self) -> Dict[str, str]: + try: + return { + 'X': 
join(self._data_dir, self.X_file), + 'y': join(self._data_dir, self.y_file), + 'descr': join(self._descr_dir, self.descr_file), + } + except TypeError: + return { + 'X': np.array([join(self._data_dir, f) for f in self.X_file]), + 'y': np.array([join(self._data_dir, f) for f in self.y_file]), + 'descr': join(self._descr_dir, self.descr_file), + } + + @property + def feature_names(self) -> ndarray: + raise NotImplementedError + + @property + def target_names(self) -> ndarray: + raise NotImplementedError + + _attempt_cast_to_int = staticmethod(_attempt_cast_to_int) + + def _raw_data_to_bunch(self) -> Bunch: + bunch = self.read_data() + bunch.DESCR = self._read_description() + return self.process(bunch) + + @abstractmethod + def read_data(self) -> Bunch: + raise NotImplementedError + + @abstractmethod + def process(self, bunch: Bunch) -> Bunch: + bunch.target = self._attempt_cast_to_int(bunch.target) + return bunch + + def _read_description(self, path: str = None) -> str: + descr_path = path or self.local_data_paths['descr'] + with open(descr_path) as descr: + return descr.read() + + def _make_bunch(self, X, y, target_names, + description, images=None) -> Bunch: + return Bunch( + data=X, target=y, + feature_names=self.feature_names, + target_names=target_names, + DESCR=description, + images=images, + data_filename=self.local_data_paths['X'], + target_filename=self.local_data_paths['y'], + filename=self.local_data_paths['X']) # Backwards compatible (0.21) + + +class SimpleCSVLocalDatasetLoader(LocalDatasetLoader): + """Reads a .csv file with: + The first row containing: + - the sample size (int) + - the number of features (int) + - [Optional] the target names corresponding + to their 'int' representation in the y column + The second [Optional] row containing: + - feature variable names with type 'str' + X features in columns [:-1] with type 'int' or 'float' + y target in column [-1] with type 'int' or 'float' + + If you try to read a .csv file containing 'object' or + 
'str' type variable values, it will fail! + """ + def read_data(self) -> Bunch: + X, y, target_names = self.read_X_y_csv(self.local_data_paths['X']) + if self.target_names.size > 0: # class property takes precedence + target_names = self.target_names + return self._make_bunch(X, y, target_names, None, None) + + def process(self, bunch: Bunch) -> Bunch: + return super().process(bunch) + + @staticmethod + def read_X_y_csv(path: str) -> Tuple[ndarray, ndarray, ndarray]: + with open(path) as f: + firstline = f.readline().rstrip().split(',') + n_features = int(firstline[1]) + col_ixs = tuple(range(n_features + 1)) + target_names = np.array(firstline[2:]) + + csv_arr = np.genfromtxt(path, delimiter=',', usecols=col_ixs, + skip_header=1, dtype=np.float) + data, target = csv_arr[:, :n_features], csv_arr[:, n_features] + + mask_row_all_nan = ~np.isnan(data).all(axis=1) + data, target = data[mask_row_all_nan], target[mask_row_all_nan] + return data, target, target_names + + +class Wine(SimpleCSVLocalDatasetLoader): + X_file = 'wine_data.csv' + feature_names = np.array([ + 'alcohol', 'malic_acid', + 'ash', 'alcalinity_of_ash', + 'magnesium', 'total_phenols', + 'flavanoids', 'nonflavanoid_phenols', + 'proanthocyanins', 'color_intensity', + 'hue', 'od280/od315_of_diluted_wines', + 'proline']) + target_names = np.array([]) + + +class Iris(SimpleCSVLocalDatasetLoader): + X_file = 'iris.csv' + feature_names = np.array([ + 'sepal length (cm)', 'sepal width (cm)', + 'petal length (cm)', 'petal width (cm)']) + target_names = np.array([]) + + +class BreastCancer(SimpleCSVLocalDatasetLoader): + X_file = 'breast_cancer.csv' + feature_names = np.array([ + 'mean radius', 'mean texture', + 'mean perimeter', 'mean area', + 'mean smoothness', 'mean compactness', + 'mean concavity', 'mean concave points', + 'mean symmetry', 'mean fractal dimension', + 'radius error', 'texture error', + 'perimeter error', 'area error', + 'smoothness error', 'compactness error', + 'concavity error', 'concave 
points error', + 'symmetry error', 'fractal dimension error', + 'worst radius', 'worst texture', + 'worst perimeter', 'worst area', + 'worst smoothness', 'worst compactness', + 'worst concavity', 'worst concave points', + 'worst symmetry', 'worst fractal dimension']) + target_names = np.array([]) + + +class Digits(LocalDatasetLoader): + X_file = 'digits.csv.gz' + feature_names = np.array([list(map(lambda d: (d, x).__repr__(), + np.arange(1, 9, dtype='int'))) + for x in np.arange(1, 9, dtype='int')]) + target_names = np.arange(10) + + def __init__(self, n_class=10): + self.n_class = n_class + + def read_data(self): + X = np.loadtxt(self.local_data_paths['X'], delimiter=',') + y = X[:, -1].astype('int') + return self._make_bunch(X, y, self.target_names, None, None) + + def process(self, bunch: Bunch) -> Bunch: + X, y = bunch.data, bunch.target + flat_X = X[:, :-1] + images = flat_X.view() + images.shape = (-1, 8, 8) + if self.n_class < 10: + idx = y < self.n_class + flat_X, y = flat_X[idx], y[idx] + images = images[idx] + bunch.data, bunch.target, bunch.images = flat_X, y, images + return bunch + + +class Diabetes(LocalDatasetLoader): + X_file = 'diabetes_data.csv.gz' + y_file = 'diabetes_target.csv.gz' + descr_file = 'diabetes.rst' + feature_names = np.array([ + 'age', 'sex', 'bmi', 'bp', + 's1', 's2', 's3', 's4', 's5', 's6']) + target_names = np.array(['progression']) + + def read_data(self): + X = np.loadtxt(self.local_data_paths['X'], dtype='float') + y = np.loadtxt(self.local_data_paths['y'], dtype='float') + return self._make_bunch(X, y, self.target_names, None, None) + + def process(self, bunch: Bunch): + return super().process(bunch) + + +class Linnerud(LocalDatasetLoader): + X_file = 'linnerud_exercise.csv' + y_file = 'linnerud_physiological.csv' + descr_file = 'linnerud.rst' + feature_names = np.array(['Chins', 'Situps', 'Jumps']) + target_names = np.array(['Weight', 'Waist', 'Pulse']) + + def read_data(self) -> Bunch: + X = 
np.loadtxt(self.local_data_paths['X'], skiprows=1, dtype='float') + y = np.loadtxt(self.local_data_paths['y'], skiprows=1, dtype='float') + return self._make_bunch(X, y, self.target_names, None, None) + + def process(self, bunch: Bunch) -> Bunch: + return bunch + + +class Boston(SimpleCSVLocalDatasetLoader): + X_file = 'boston_house_prices.csv' + feature_names = np.array([ + 'CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', + 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']) + target_names = np.array(['MEDV']) + + +class SampleImages(LocalDatasetLoader): + X_file = np.array(['china.jpg', 'flower.jpg']) + feature_names = np.array([]) + target_names = np.array(['china', 'flower']) + _descr_dir = LocalDatasetLoader._images_dir + descr_file = 'README.txt' + + def __init__(self, image_name: str = None): + self.image_name = image_name + self._check_image_name() + + def _check_image_name(self): + if self.image_name and self.image_name not in self.X_file: + msg = 'Cannot find sample image: %s' % self.image_name + raise AttributeError(msg) + + def read_data(self) -> Bunch: + from ..externals._pilutil import imread # import PIL only when needed + image_files = [join(self._images_dir, file) for file in self.X_file] + images = [imread(img) for img in image_files] + return self._make_bunch(None, None, self.target_names, None, images) + + def process(self, bunch: Bunch): + bunch.filenames = bunch.filename # Backwards compatible (0.21) + return bunch def _pkl_filepath(*args, **kwargs): diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 08a6ba29413cf..60741145d7e1e 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -21,6 +21,14 @@ from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_boston from sklearn.datasets import load_wine +from sklearn.datasets import SampleImages # noqa: F401 +from sklearn.datasets import Digits # noqa: F401 +from sklearn.datasets import Diabetes 
# noqa: F401 +from sklearn.datasets import Linnerud # noqa: F401 +from sklearn.datasets import Iris # noqa: F401 +from sklearn.datasets import BreastCancer # noqa: F401 +from sklearn.datasets import Boston # noqa: F401 +from sklearn.datasets import Wine # noqa: F401 from sklearn.datasets.base import Bunch from sklearn.datasets.tests.test_common import check_return_X_y @@ -29,6 +37,8 @@ from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises +import numpy.testing.utils as np_test_util + def _remove_dir(path): if os.path.isdir(path): @@ -109,6 +119,7 @@ def test_load_files_w_categories_desc_and_encoding( assert_equal(res.data, ["Hello World!\n"]) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_files_wo_load_content( test_category_dir_1, test_category_dir_2, load_files_root): res = load_files(load_files_root, load_content=False) @@ -118,6 +129,7 @@ def test_load_files_wo_load_content( assert_equal(res.get('data'), None) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_sample_images(): try: res = load_sample_images() @@ -128,6 +140,7 @@ def test_load_sample_images(): warnings.warn("Could not load sample images, PIL is not available.") +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_digits(): digits = load_digits() assert_equal(digits.data.shape, (1797, 64)) @@ -137,12 +150,14 @@ def test_load_digits(): check_return_X_y(digits, partial(load_digits)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_digits_n_class_lt_10(): digits = load_digits(9) assert_equal(digits.data.shape, (1617, 64)) assert_equal(numpy.unique(digits.target).size, 9) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_sample_image(): try: china = load_sample_image('china.jpg') @@ -152,6 +167,7 @@ def test_load_sample_image(): warnings.warn("Could not load sample images, PIL is not available.") 
+@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_missing_sample_image_error(): if pillow_installed: assert_raises(AttributeError, load_sample_image, @@ -160,6 +176,7 @@ def test_load_missing_sample_image_error(): warnings.warn("Could not load sample images, PIL is not available.") +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_diabetes(): res = load_diabetes() assert_equal(res.data.shape, (442, 10)) @@ -171,6 +188,7 @@ def test_load_diabetes(): check_return_X_y(res, partial(load_diabetes)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_linnerud(): res = load_linnerud() assert_equal(res.data.shape, (20, 3)) @@ -184,6 +202,7 @@ def test_load_linnerud(): check_return_X_y(res, partial(load_linnerud)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_iris(): res = load_iris() assert_equal(res.data.shape, (150, 4)) @@ -196,6 +215,7 @@ def test_load_iris(): check_return_X_y(res, partial(load_iris)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_wine(): res = load_wine() assert_equal(res.data.shape, (178, 13)) @@ -207,6 +227,7 @@ def test_load_wine(): check_return_X_y(res, partial(load_wine)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_breast_cancer(): res = load_breast_cancer() assert_equal(res.data.shape, (569, 30)) @@ -219,6 +240,7 @@ def test_load_breast_cancer(): check_return_X_y(res, partial(load_breast_cancer)) +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_load_boston(): res = load_boston() assert_equal(res.data.shape, (506, 13)) @@ -259,7 +281,90 @@ def test_bunch_pickle_generated_with_0_16_and_read_with_0_17(): assert_equal(bunch_from_pkl['key'], 'changed') +@pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_bunch_dir(): # check that dir (important for autocomplete) shows attributes data = load_iris() assert "data" in dir(data) + + 
+@pytest.mark.parametrize('test_loader,expected_target_dtype', [ + ('Iris', 'int'), + ('Boston', 'float'), + ('BreastCancer', 'int'), + ('Digits', 'int'), + ('Diabetes', 'int'), + ('Linnerud', 'float'), + ('Wine', 'int') +]) +def test_dataset_loader_dtype(test_loader, expected_target_dtype): + assert eval(test_loader)().load().target.dtype == expected_target_dtype + + +@pytest.mark.parametrize('test_loader', [ + 'Iris', + 'Boston', + 'BreastCancer', + 'Digits', + 'Diabetes', + 'Linnerud', + 'Wine' +]) +def test_dataset_loader_bunch_paths(test_loader): + bunch = eval(test_loader)().load() + paths = bunch.filename, bunch.data_filename, bunch.target_filename + assert all(list(map(os.path.exists, paths))) is True + + +@pytest.mark.parametrize('test_loader,exp_features,exp_targets,exp_n', [ + ('Iris', 4, 3, 150), + ('Boston', 13, 1, 506), + ('BreastCancer', 30, 2, 569), + ('Digits', 64, 10, 1797), + ('Diabetes', 10, 1, 442), + ('Linnerud', 3, 3, 20), + ('Wine', 13, 3, 178)]) +def test_dataset_loader_shape(test_loader, exp_features, + exp_targets, exp_n): + bunch = eval(test_loader)().load() + n_features, m_features = bunch.data.shape[:2] + n_targets, m_targets = bunch.target.shape[0], bunch.target_names.size + assert (m_features == exp_features) and \ + (m_targets == exp_targets) and \ + (n_features == n_targets == exp_n) + + +@pytest.mark.parametrize('test_loader', [ + 'Iris', + 'Boston', + 'BreastCancer', + 'Digits', + 'Diabetes', + 'Linnerud', + 'Wine' +]) +def test_dataset_loader_check_nan(test_loader): + bunch = eval(test_loader)().load() + data, target = bunch.data, bunch.target + np_test_util.assert_equal(numpy.isnan(data).any(), False) + np_test_util.assert_equal(numpy.isnan(target).any(), False) + + +def test_load_data_deprecated(): + from ..base import load_data + iris_path = Iris().local_data_paths['X'] + pytest.deprecated_call(load_data, iris_path) + + +@pytest.mark.parametrize('test_deprecated_fun', [ + 'load_wine', + 'load_iris', + 'load_digits', + 
'load_diabetes', + 'load_boston', + 'load_breast_cancer', + 'load_linnerud', + 'load_sample_images' +]) +def test_functional_load_deprecated(test_deprecated_fun): + pytest.deprecated_call(eval(test_deprecated_fun))