Skip to content

ENH allows to overwrite read_csv parameter in fetch_openml #25488

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ Changelog
is now deprecated and will be removed in v1.5.
:pr:`25251` by :user:`Gleb Levitski <glevv>`.

:mod:`sklearn.datasets`
.......................

- |Enhancement| Allows overwriting the parameters used to read the ARFF file using
  the parameter `read_csv_kwargs` in :func:`datasets.fetch_openml` when using the
  pandas parser.
  :pr:`25488` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.decomposition`
............................

Expand Down
37 changes: 27 additions & 10 deletions sklearn/datasets/_arff_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def _pandas_arff_parser(
openml_columns_info,
feature_names_to_select,
target_names_to_select,
read_csv_kwargs=None,
):
"""ARFF parser using `pandas.read_csv`.

Expand Down Expand Up @@ -331,6 +332,12 @@ def _pandas_arff_parser(
target_names_to_select : list of str
A list of the target names to be selected to build `y`.

read_csv_kwargs : dict, default=None
    Keyword arguments to pass to `pandas.read_csv`. It allows overwriting
    the default options.

.. versionadded:: 1.3

Returns
-------
X : {ndarray, sparse matrix, dataframe}
Expand Down Expand Up @@ -364,17 +371,19 @@ def _pandas_arff_parser(
elif column_dtype.lower() == "nominal":
dtypes[name] = "category"

default_read_csv_kwargs = {
"header": None,
"na_values": ["?"], # missing values are represented by `?`
"comment": "%", # skip line starting by `%` since they are comments
"quotechar": '"', # delimiter to use for quoted strings
"names": [name for name in openml_columns_info],
"dtype": dtypes,
"skipinitialspace": True, # skip spaces after delimiter to follow ARFF specs
}
read_csv_kwargs = {**default_read_csv_kwargs, **(read_csv_kwargs or {})}

# ARFF represents missing values with "?"
frame = pd.read_csv(
gzip_file,
header=None,
na_values=["?"], # missing values are represented by `?`
comment="%", # skip line starting by `%` since they are comments
quotechar='"', # delimiter to use for quoted strings
names=[name for name in openml_columns_info],
dtype=dtypes,
skipinitialspace=True, # skip spaces after delimiter to follow ARFF specs
)
frame = pd.read_csv(gzip_file, **read_csv_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, if there is an exception while reading the data, one would need to enter a debugger to find out where the file is and what the read_csv_kwargs are. I think it would be helpful to reraise an exception that outputs the read_csv_kwargs and gzip_file to help with debugging the issue.


columns_to_select = feature_names_to_select + target_names_to_select
columns_to_keep = [col for col in frame.columns if col in columns_to_select]
Expand Down Expand Up @@ -431,6 +440,7 @@ def load_arff_from_gzip_file(
feature_names_to_select,
target_names_to_select,
shape=None,
read_csv_kwargs=None,
):
"""Load a compressed ARFF file using a given parser.

Expand Down Expand Up @@ -461,6 +471,12 @@ def load_arff_from_gzip_file(
target_names_to_select : list of str
A list of the target names to be selected.

read_csv_kwargs : dict, default=None
    Keyword arguments to pass to `pandas.read_csv`. It allows overwriting
    the default options.

.. versionadded:: 1.3

Returns
-------
X : {ndarray, sparse matrix, dataframe}
Expand Down Expand Up @@ -493,6 +509,7 @@ def load_arff_from_gzip_file(
openml_columns_info,
feature_names_to_select,
target_names_to_select,
read_csv_kwargs,
)
else:
raise ValueError(
Expand Down
33 changes: 33 additions & 0 deletions sklearn/datasets/_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def _load_arff_response(
md5_checksum: str,
n_retries: int = 3,
delay: float = 1.0,
read_csv_kwargs: Optional[Dict] = None,
):
"""Load the ARFF data associated with the OpenML URL.

Expand Down Expand Up @@ -461,6 +462,12 @@ def _load_arff_response(
md5_checksum : str
The MD5 checksum provided by OpenML to check the data integrity.

read_csv_kwargs : dict, default=None
    Keyword arguments to pass to `pandas.read_csv`. It allows overwriting
    the default options.

.. versionadded:: 1.3

Returns
-------
X : {ndarray, sparse matrix, dataframe}
Expand Down Expand Up @@ -503,6 +510,7 @@ def _load_arff_response(
feature_names_to_select=feature_names_to_select,
target_names_to_select=target_names_to_select,
shape=shape,
read_csv_kwargs=read_csv_kwargs,
)

return X, y, frame, categories
Expand All @@ -522,6 +530,7 @@ def _download_data_to_bunch(
n_retries: int = 3,
delay: float = 1.0,
parser: str,
read_csv_kwargs: Optional[Dict] = None,
):
"""Download ARFF data, load it to a specific container and create to Bunch.

Expand Down Expand Up @@ -568,6 +577,21 @@ def _download_data_to_bunch(
parser : {"liac-arff", "pandas"}
The parser used to parse the ARFF file.

read_csv_kwargs : dict, default=None
Keyword arguments to pass to `pandas.read_csv`. It allows overwriting
the default options. Internally, we use the default parameters of
:func:`pandas.read_csv` except for the following parameters:

- `header`: set to `None`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be in fetch_openml as part of the public API?

- `na_values`: set to `["?"]`
- `comment`: set to `%`
- `quotechar`: set to `"` (double quote)
- `names`: set to the list of column names given by OpenML
- `dtype`: set to the dtypes provided by OpenML
- `skipinitialspace`: set to `True`

.. versionadded:: 1.3

Returns
-------
data : :class:`~sklearn.utils.Bunch`
Expand Down Expand Up @@ -619,6 +643,7 @@ def _download_data_to_bunch(
md5_checksum=md5_checksum,
n_retries=n_retries,
delay=delay,
read_csv_kwargs=read_csv_kwargs,
)

return Bunch(
Expand Down Expand Up @@ -687,6 +712,7 @@ def fetch_openml(
n_retries: int = 3,
delay: float = 1.0,
parser: Optional[str] = "warn",
read_csv_kwargs: Optional[Dict] = None,
):
"""Fetch dataset from openml by name or dataset id.

Expand Down Expand Up @@ -791,6 +817,12 @@ def fetch_openml(
warning. Therefore, an `ImportError` will be raised from 1.4 if
the dataset is dense and pandas is not installed.

read_csv_kwargs : dict, default=None
    Keyword arguments passed to `pandas.read_csv` when loading the data
    from an ARFF file. It allows overwriting some default parameters.

.. versionadded:: 1.3

Returns
-------
data : :class:`~sklearn.utils.Bunch`
Expand Down Expand Up @@ -1054,6 +1086,7 @@ def fetch_openml(
n_retries=n_retries,
delay=delay,
parser=parser_,
read_csv_kwargs=read_csv_kwargs,
)

if return_X_y:
Expand Down
28 changes: 28 additions & 0 deletions sklearn/datasets/tests/test_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,34 @@ def test_dataset_with_openml_warning(monkeypatch, gzip_response):
fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff")


def test_fetch_openml_overwrite_default_params_read_csv(monkeypatch):
    """Check that we can overwrite the default parameters of `read_csv`."""
    pytest.importorskip("pandas")
    data_id = 1590
    _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False)

    fetch_kwargs = {
        "data_id": data_id,
        "as_frame": True,
        "cache": False,
        "parser": "pandas",
    }

    # The default `skipinitialspace=True` strips spaces that follow the
    # delimiter. Passing `skipinitialspace=False` via `read_csv_kwargs`
    # should overwrite that default and keep the leading spaces.
    adult_without_spaces = fetch_openml(**fetch_kwargs)
    adult_with_spaces = fetch_openml(
        **fetch_kwargs, read_csv_kwargs={"skipinitialspace": False}
    )

    categories_with_spaces = adult_with_spaces.frame["class"].cat.categories
    categories_without_spaces = adult_without_spaces.frame["class"].cat.categories
    assert all(category.startswith(" ") for category in categories_with_spaces)
    assert all(
        not category.startswith(" ") for category in categories_without_spaces
    )


###############################################################################
# Test cache, retry mechanisms, checksum, etc.

Expand Down