Skip to content

[MRG+1] fetch_openml should support return_X_y #11840

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 19, 2018
11 changes: 10 additions & 1 deletion sklearn/datasets/openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def _verify_target_data_type(features_dict, target_columns):


def fetch_openml(name=None, version='active', data_id=None, data_home=None,
target_column='default-target', cache=True):
target_column='default-target', cache=True, return_X_y=False):
"""Fetch dataset from openml by name or dataset id.

Datasets are uniquely identified by either an integer ID or by a
Expand Down Expand Up @@ -395,6 +395,10 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
cache : boolean, default=True
Whether to cache downloaded datasets using joblib.

return_X_y : boolean, default=False
If True, returns ``(data, target)`` instead of a Bunch object. See
below for more information about the `data` and `target` objects.

Returns
-------

Expand All @@ -416,6 +420,8 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
details : dict
More metadata from OpenML

(data, target) : tuple if ``return_X_y`` is True

.. note:: EXPERIMENTAL

This interface is **experimental** as at version 0.20 and
Expand Down Expand Up @@ -557,6 +563,9 @@ def fetch_openml(name=None, version='active', data_id=None, data_home=None,
elif y.shape[1] == 0:
y = None

if return_X_y:
return X, y

bunch = Bunch(
data=X, target=y, feature_names=data_columns,
DESCR=description, details=data_description,
Expand Down
7 changes: 7 additions & 0 deletions sklearn/datasets/tests/test_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
assert_raise_message)
from sklearn.externals.six import string_types
from sklearn.externals.six.moves.urllib.error import HTTPError
from sklearn.datasets.tests.test_common import check_return_X_y
from functools import partial


currdir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -124,6 +126,11 @@ def _fetch_dataset_from_openml(data_id, data_name, data_version,
# np.isnan doesn't work on CSR matrix
assert (np.count_nonzero(np.isnan(data_by_id.data)) ==
expected_missing)

# test return_X_y option
fetch_func = partial(fetch_openml, data_id=data_id, cache=False,
target_column=target_column)
check_return_X_y(data_by_id, fetch_func)
return data_by_id


Expand Down