diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index bcd163caf738d..9ec4fd2172c17 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -68,12 +68,17 @@ def get_data_home(data_home=None) -> str: return data_home +@validate_params( + { + "data_home": [str, os.PathLike, None], + } +) def clear_data_home(data_home=None): """Delete all the content of the data home cache. Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None The path to scikit-learn data directory. If `None`, the default path is `~/sklearn_learn_data`. """ diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 2bb6846dc4cbf..619180ea454c3 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -119,6 +119,7 @@ def _check_function_param_validation( "sklearn.covariance.empirical_covariance", "sklearn.covariance.ledoit_wolf_shrinkage", "sklearn.covariance.shrunk_covariance", + "sklearn.datasets.clear_data_home", "sklearn.datasets.dump_svmlight_file", "sklearn.datasets.fetch_20newsgroups", "sklearn.datasets.fetch_20newsgroups_vectorized",