Skip to content

Commit bc3a19d

Browse files
MAINT Parameters validation for sklearn.datasets.load_files (#26203)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 18a4576 commit bc3a19d

File tree

3 files changed: 18 additions and 3 deletions.

3 files changed

+18
-3
lines changed

sklearn/datasets/_base.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from ..utils import check_random_state
2323
from ..utils import check_pandas_support
2424
from ..utils.fixes import _open_binary, _open_text, _read_text, _contents
25-
from ..utils._param_validation import validate_params, Interval
25+
from ..utils._param_validation import validate_params, Interval, StrOptions
2626

2727
import numpy as np
2828

@@ -104,6 +104,19 @@ def _convert_data_dataframe(
104104
return combined_df, X, y
105105

106106

107+
@validate_params(
108+
{
109+
"container_path": [str, os.PathLike],
110+
"description": [str, None],
111+
"categories": [list, None],
112+
"load_content": ["boolean"],
113+
"shuffle": ["boolean"],
114+
"encoding": [str, None],
115+
"decode_error": [StrOptions({"strict", "ignore", "replace"})],
116+
"random_state": ["random_state"],
117+
"allowed_extensions": [list, None],
118+
}
119+
)
107120
def load_files(
108121
container_path,
109122
*,

sklearn/datasets/tests/test_base.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,10 +98,11 @@ def test_default_load_files(test_category_dir_1, test_category_dir_2, load_files
9898
def test_load_files_w_categories_desc_and_encoding(
9999
test_category_dir_1, test_category_dir_2, load_files_root
100100
):
101-
category = os.path.abspath(test_category_dir_1).split("/").pop()
101+
category = os.path.abspath(test_category_dir_1).split(os.sep).pop()
102102
res = load_files(
103-
load_files_root, description="test", categories=category, encoding="utf-8"
103+
load_files_root, description="test", categories=[category], encoding="utf-8"
104104
)
105+
105106
assert len(res.filenames) == 1
106107
assert len(res.target_names) == 1
107108
assert res.DESCR == "test"

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,7 @@ def _check_function_param_validation(
134134
"sklearn.datasets.load_breast_cancer",
135135
"sklearn.datasets.load_diabetes",
136136
"sklearn.datasets.load_digits",
137+
"sklearn.datasets.load_files",
137138
"sklearn.datasets.load_iris",
138139
"sklearn.datasets.load_linnerud",
139140
"sklearn.datasets.load_svmlight_file",

0 commit comments

Comments
 (0)