From 52a063f8febeb10aee266bceeecaa8fd1f412814 Mon Sep 17 00:00:00 2001 From: Nick Sorros Date: Fri, 13 Apr 2018 09:25:18 +0100 Subject: [PATCH 1/4] Add as_frame flag in iris dataset to return data as dataframes --- sklearn/datasets/base.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py index 2305f60ebbb54..d8061910b6bae 100644 --- a/sklearn/datasets/base.py +++ b/sklearn/datasets/base.py @@ -21,6 +21,7 @@ from ..utils import check_random_state import numpy as np +import pandas as pd from sklearn.externals.six.moves.urllib.request import urlretrieve @@ -322,7 +323,7 @@ def load_wine(return_X_y=False): 'proline']) -def load_iris(return_X_y=False): +def load_iris(return_X_y=False, as_frame=False): """Load and return the iris dataset (classification). The iris dataset is a classic and very easy multi-class classification @@ -382,11 +383,17 @@ def load_iris(return_X_y=False): if return_X_y: return data, target + feature_names = ['sepal length (cm)', 'sepal width (cm)', + 'petal length (cm)', 'petal width (cm)'] + if as_frame: + data_frame = pd.DataFrame(data, columns=feature_names) + target_series = pd.Series(target, name="class") + return data_frame, target_series + return Bunch(data=data, target=target, target_names=target_names, DESCR=fdescr, - feature_names=['sepal length (cm)', 'sepal width (cm)', - 'petal length (cm)', 'petal width (cm)'], + feature_names=feature_names, filename=iris_csv_filename) From 241be2169c979e144e523ed5408385016c6571f3 Mon Sep 17 00:00:00 2001 From: Nick Sorros Date: Tue, 17 Apr 2018 09:07:56 +0100 Subject: [PATCH 2/4] partial import of pandas when used --- sklearn/datasets/base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py index d8061910b6bae..ef8dcb121d834 100644 --- a/sklearn/datasets/base.py +++ b/sklearn/datasets/base.py @@ -21,7 +21,6 @@ from ..utils import check_random_state import numpy as np -import pandas as pd from sklearn.externals.six.moves.urllib.request import urlretrieve @@ -385,10 +384,15 @@ def load_iris(return_X_y=False, as_frame=False): feature_names = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'] + if as_frame: - data_frame = pd.DataFrame(data, columns=feature_names) - target_series = pd.Series(target, name="class") - return data_frame, target_series + try: + from pandas import Series, DataFrame + data_frame = DataFrame(data, columns=feature_names) + target_series = Series(target, name="class") + return data_frame, target_series + except ImportError: + pass return Bunch(data=data, target=target, target_names=target_names, From 9dfe2b9c2bfc78fef90b4b338f87cfb55fdad046 Mon Sep 17 00:00:00 2001 From: Nick Sorros Date: Tue, 17 Apr 2018 09:20:07 +0100 Subject: [PATCH 3/4] Add test --- sklearn/datasets/tests/test_base.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index 091735b986a3f..d1b298b122fac 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -29,6 +29,7 @@ from sklearn.utils.testing import assert_true from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import SkipTest DATA_HOME = tempfile.mkdtemp(prefix="scikit_learn_data_home_test_") @@ -202,6 +203,17 @@ def test_load_iris(): check_return_X_y(res, partial(load_iris)) +def test_load_iris_as_frame(): + try: + data_frame, target_series = load_iris(as_frame=True) + assert_equal(data_frame.shape, (150, 4)) + assert_equal(target_series.shape[0], 150) + except SkipTest as message: + # the only SkipTest thrown currently results from not + # being able to import pandas. + warnings.warn(message, SkipTestWarning) + + def test_load_wine(): res = load_wine() assert_equal(res.data.shape, (178, 13)) From e089ae3f5614ec3693f02f66caadcb13d6b16342 Mon Sep 17 00:00:00 2001 From: Nick Sorros Date: Tue, 17 Apr 2018 09:53:32 +0100 Subject: [PATCH 4/4] import pandas or skip test --- sklearn/datasets/base.py | 11 ++++------- sklearn/datasets/tests/test_base.py | 6 ++---- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/sklearn/datasets/base.py b/sklearn/datasets/base.py index ef8dcb121d834..d37b308a6d5d8 100644 --- a/sklearn/datasets/base.py +++ b/sklearn/datasets/base.py @@ -386,13 +386,10 @@ def load_iris(return_X_y=False, as_frame=False): 'petal length (cm)', 'petal width (cm)'] if as_frame: - try: - from pandas import Series, DataFrame - data_frame = DataFrame(data, columns=feature_names) - target_series = Series(target, name="class") - return data_frame, target_series - except ImportError: - pass + from pandas import Series, DataFrame + data_frame = DataFrame(data, columns=feature_names) + target_series = pd.Series(target, name="class") + return data_frame, target_series return Bunch(data=data, target=target, target_names=target_names, diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index d1b298b122fac..61c495c01f1de 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -208,10 +208,8 @@ def test_load_iris_as_frame(): data_frame, target_series = load_iris(as_frame=True) assert_equal(data_frame.shape, (150, 4)) assert_equal(target_series.shape[0], 150) - except SkipTest as message: - # the only SkipTest thrown currently results from not - # being able to import pandas. - warnings.warn(message, SkipTestWarning) + except IOError as : + SkipTest("Pandas is needed to run the test") def test_load_wine():