diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index e4e965aca02d8..8c0d709aee8cc 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -104,6 +104,14 @@ Changelog is now deprecated and will be removed in v1.5. :pr:`25251` by :user:`Gleb Levitski `. +:mod:`sklearn.datasets` +....................... + +- |Enhancement| Allows to overwrite the parameters used to open the ARFF file using + the parameter `read_csv_kwargs` in :func:`datasets.fetch_openml` when using the + pandas parser. + :pr:`25488` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.decomposition` ............................ diff --git a/sklearn/datasets/_arff_parser.py b/sklearn/datasets/_arff_parser.py index a624cbe367a97..b58f2e8ab64fa 100644 --- a/sklearn/datasets/_arff_parser.py +++ b/sklearn/datasets/_arff_parser.py @@ -302,6 +302,7 @@ def _pandas_arff_parser( openml_columns_info, feature_names_to_select, target_names_to_select, + read_csv_kwargs=None, ): """ARFF parser using `pandas.read_csv`. @@ -331,6 +332,12 @@ def _pandas_arff_parser( target_names_to_select : list of str A list of the target names to be selected to build `y`. + read_csv_kwargs : dict, default=None + Keyword arguments to pass to `pandas.read_csv`. It allows to overwrite + the default options. + + .. versionadded:: 1.3 + Returns ------- X : {ndarray, sparse matrix, dataframe} @@ -364,17 +371,19 @@ def _pandas_arff_parser( elif column_dtype.lower() == "nominal": dtypes[name] = "category" + default_read_csv_kwargs = { + "header": None, + "na_values": ["?"], # missing values are represented by `?` + "comment": "%", # skip line starting by `%` since they are comments + "quotechar": '"', # delimiter to use for quoted strings + "names": [name for name in openml_columns_info], + "dtype": dtypes, + "skipinitialspace": True, # skip spaces after delimiter to follow ARFF specs + } + read_csv_kwargs = {**default_read_csv_kwargs, **(read_csv_kwargs or {})} + # ARFF represents missing values with "?" 
- frame = pd.read_csv( - gzip_file, - header=None, - na_values=["?"], # missing values are represented by `?` - comment="%", # skip line starting by `%` since they are comments - quotechar='"', # delimiter to use for quoted strings - names=[name for name in openml_columns_info], - dtype=dtypes, - skipinitialspace=True, # skip spaces after delimiter to follow ARFF specs - ) + frame = pd.read_csv(gzip_file, **read_csv_kwargs) columns_to_select = feature_names_to_select + target_names_to_select columns_to_keep = [col for col in frame.columns if col in columns_to_select] @@ -431,6 +440,7 @@ def load_arff_from_gzip_file( feature_names_to_select, target_names_to_select, shape=None, + read_csv_kwargs=None, ): """Load a compressed ARFF file using a given parser. @@ -461,6 +471,12 @@ def load_arff_from_gzip_file( target_names_to_select : list of str A list of the target names to be selected. + read_csv_kwargs : dict, default=None + Keyword arguments to pass to `pandas.read_csv`. It allows to overwrite + the default options. + + .. versionadded:: 1.3 + Returns ------- X : {ndarray, sparse matrix, dataframe} @@ -493,6 +509,7 @@ def load_arff_from_gzip_file( openml_columns_info, feature_names_to_select, target_names_to_select, + read_csv_kwargs, ) else: raise ValueError( diff --git a/sklearn/datasets/_openml.py b/sklearn/datasets/_openml.py index be85e72d822b0..61907a910d5d9 100644 --- a/sklearn/datasets/_openml.py +++ b/sklearn/datasets/_openml.py @@ -419,6 +419,7 @@ def _load_arff_response( md5_checksum: str, n_retries: int = 3, delay: float = 1.0, + read_csv_kwargs: Optional[Dict] = None, ): """Load the ARFF data associated with the OpenML URL. @@ -461,6 +462,12 @@ def _load_arff_response( md5_checksum : str The MD5 checksum provided by OpenML to check the data integrity. + read_csv_kwargs : dict, default=None + Keyword arguments to pass to `pandas.read_csv`. It allows to overwrite + the default options. + + .. 
versionadded:: 1.3 + Returns ------- X : {ndarray, sparse matrix, dataframe} @@ -503,6 +510,7 @@ def _load_arff_response( feature_names_to_select=feature_names_to_select, target_names_to_select=target_names_to_select, shape=shape, + read_csv_kwargs=read_csv_kwargs, ) return X, y, frame, categories @@ -522,6 +530,7 @@ def _download_data_to_bunch( n_retries: int = 3, delay: float = 1.0, parser: str, + read_csv_kwargs: Optional[Dict] = None, ): """Download ARFF data, load it to a specific container and create to Bunch. @@ -568,6 +577,21 @@ parser : {"liac-arff", "pandas"} The parser used to parse the ARFF file. + read_csv_kwargs : dict, default=None + Keyword arguments to pass to `pandas.read_csv`. It allows to overwrite + the default options. Internally, we use the default parameters of + :func:`pandas.read_csv` except for the following parameters: + + - `header`: set to `None` + - `na_values`: set to `["?"]` + - `comment`: set to `%` + - `quotechar`: set to `"` (double quote) + - `names`: set to the list of column names given by OpenML + - `dtype`: set to the dtypes provided by OpenML + - `skipinitialspace`: set to `True` + + .. versionadded:: 1.3 + Returns ------- data : :class:`~sklearn.utils.Bunch` @@ -619,6 +643,7 @@ md5_checksum=md5_checksum, n_retries=n_retries, delay=delay, + read_csv_kwargs=read_csv_kwargs, ) return Bunch( @@ -687,6 +712,7 @@ def fetch_openml( n_retries: int = 3, delay: float = 1.0, parser: Optional[str] = "warn", + read_csv_kwargs: Optional[Dict] = None, ): """Fetch dataset from openml by name or dataset id. @@ -791,6 +817,12 @@ warning. Therefore, an `ImportError` will be raised from 1.4 if the dataset is dense and pandas is not installed. + read_csv_kwargs : dict, default=None + Keyword arguments passed to `pandas.read_csv` when loading the data + from an ARFF file. It allows to overwrite some default parameters. + + .. 
versionadded:: 1.3 + Returns ------- data : :class:`~sklearn.utils.Bunch` @@ -1054,6 +1086,7 @@ def fetch_openml( n_retries=n_retries, delay=delay, parser=parser_, + read_csv_kwargs=read_csv_kwargs, ) if return_X_y: diff --git a/sklearn/datasets/tests/test_openml.py b/sklearn/datasets/tests/test_openml.py index 73a94da6603c6..acf57aec88fd5 100644 --- a/sklearn/datasets/tests/test_openml.py +++ b/sklearn/datasets/tests/test_openml.py @@ -1348,6 +1348,34 @@ def test_dataset_with_openml_warning(monkeypatch, gzip_response): fetch_openml(data_id=data_id, cache=False, as_frame=False, parser="liac-arff") +def test_fetch_openml_overwrite_default_params_read_csv(monkeypatch): + """Check that we can overwrite the default parameters of `read_csv`.""" + pytest.importorskip("pandas") + data_id = 1590 + _monkey_patch_webbased_functions(monkeypatch, data_id=data_id, gzip_response=False) + + common_params = { + "data_id": data_id, + "as_frame": True, + "cache": False, + "parser": "pandas", + } + + # By default, the initial spaces are skipped. We checked that setting the parameter + # `skipinitialspace` to False will have an effect. + adult_without_spaces = fetch_openml(**common_params) + adult_with_spaces = fetch_openml( + **common_params, read_csv_kwargs={"skipinitialspace": False} + ) + assert all( + cat.startswith(" ") for cat in adult_with_spaces.frame["class"].cat.categories + ) + assert not any( + cat.startswith(" ") + for cat in adult_without_spaces.frame["class"].cat.categories + ) + + ############################################################################### # Test cache, retry mechanisms, checksum, etc.