
MAINT Parameters validation for sklearn.model_selection.cross_validate #26129


Merged: 5 commits merged on Apr 19, 2023
40 changes: 34 additions & 6 deletions sklearn/model_selection/_validation.py
@@ -15,6 +15,7 @@
import numbers
import time
from functools import partial
from numbers import Real
from traceback import format_exc
from contextlib import suppress
from collections import Counter
@@ -29,7 +30,14 @@
from ..utils.validation import _num_samples
from ..utils.parallel import delayed, Parallel
from ..utils.metaestimators import _safe_split
from ..utils._param_validation import (
    HasMethods,
    Integral,
    StrOptions,
    validate_params,
)
from ..metrics import check_scoring
from ..metrics import get_scorer_names
from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer
from ..exceptions import FitFailedWarning
from ._split import check_cv
@@ -46,6 +54,31 @@
]


@validate_params(
    {
        "estimator": [HasMethods("fit")],
        "X": ["array-like", "sparse matrix"],
        "y": ["array-like", None],
        "groups": ["array-like", None],
        "scoring": [
            StrOptions(set(get_scorer_names())),
            callable,
            list,
            tuple,
            dict,
            None,
        ],
        "cv": ["cv_object"],
        "n_jobs": [Integral, None],
        "verbose": ["verbose"],
        "fit_params": [dict, None],
        "pre_dispatch": [Integral, str],
        "return_train_score": ["boolean"],
        "return_estimator": ["boolean"],
        "return_indices": ["boolean"],
        "error_score": [StrOptions({"raise"}), Real],
    }
)
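# validate_params checks each argument against the constraints above before the
# function body runs, raising InvalidParameterError (a ValueError subclass) with
# a uniform "The '<param>' parameter of cross_validate must be ..." message.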
def cross_validate(
    estimator,
    X,
@@ -72,7 +105,7 @@ def cross_validate(
    estimator : estimator object implementing 'fit'
        The object to use to fit the data.

    X : array-like of shape (n_samples, n_features)
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The data to fit. Can be for example a list, or an array.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs), default=None
@@ -141,11 +174,6 @@ def cross_validate(
        explosion of memory consumption when more jobs get dispatched
        than CPUs can process. This parameter can be:

            - None, in which case all the jobs are immediately
              created and spawned. Use this for lightweight and
              fast-running jobs, to avoid delays due to on-demand
              spawning of the jobs

            - An int, giving the exact number of total jobs that are
              spawned
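With the decorator applied, an invalid argument is rejected before any model
fitting starts. A minimal sketch of the new behavior (assuming a scikit-learn
build that includes this change):

    from sklearn.datasets import make_classification
    from sklearn.model_selection import cross_validate
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=30, random_state=0)
    try:
        cross_validate(SVC(), X, y, error_score="unvalid-string")
    except ValueError as exc:
        print(exc)
    # Prints something like:
    # The 'error_score' parameter of cross_validate must be a str among
    # {'raise'} or an instance of 'numbers.Real'. Got 'unvalid-string' instead.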
37 changes: 21 additions & 16 deletions sklearn/model_selection/tests/test_validation.py
@@ -10,6 +10,7 @@
import pytest
import numpy as np
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse import issparse
from sklearn.exceptions import FitFailedWarning

from sklearn.model_selection.tests.test_search import FailingClassifier
@@ -354,18 +355,10 @@ def test_cross_validate_invalid_scoring_param():
    with pytest.raises(ValueError, match=error_message_regexp):
        cross_validate(estimator, X, y, scoring=[[make_scorer(precision_score)]])

    error_message_regexp = (
        ".*scoring is invalid.*Refer to the scoring glossary for details:.*"
    )

    # Empty dict should raise invalid scoring error
    with pytest.raises(ValueError, match="An empty dict"):
        cross_validate(estimator, X, y, scoring=(dict()))

    # And so should any other invalid entry
    with pytest.raises(ValueError, match=error_message_regexp):
        cross_validate(estimator, X, y, scoring=5)

    multiclass_scorer = make_scorer(precision_recall_fscore_support)

    # Multiclass Scorers that return multiple values are not supported yet
@@ -382,9 +375,6 @@ def test_cross_validate_invalid_scoring_param():
    with pytest.warns(UserWarning, match=warning_message):
        cross_validate(estimator, X, y, scoring={"foo": multiclass_scorer})

    with pytest.raises(ValueError, match="'mse' is not a valid scoring value."):
        cross_validate(SVC(), X, y, scoring="mse")


def test_cross_validate_nested_estimator():
    # Non-regression test to ensure that nested
@@ -405,7 +395,8 @@ def test_cross_validate_nested_estimator():
    assert all(isinstance(estimator, Pipeline) for estimator in estimators)


def test_cross_validate():
@pytest.mark.parametrize("use_sparse", [False, True])
def test_cross_validate(use_sparse: bool):
    # Compute train and test mse/r2 scores
    cv = KFold()

@@ -417,6 +408,10 @@ def test_cross_validate():
    X_clf, y_clf = make_classification(n_samples=30, random_state=0)
    clf = SVC(kernel="linear", random_state=0)

    if use_sparse:
        X_reg = csr_matrix(X_reg)
        X_clf = csr_matrix(X_clf)

    for X, y, est in ((X_reg, y_reg, reg), (X_clf, y_clf, clf)):
        # It's okay to evaluate regression metrics on classification too
        mse_scorer = check_scoring(est, scoring="neg_mean_squared_error")
@@ -510,7 +505,15 @@ def check_cross_validate_single_metric(clf, X, y, scores, cv):
        clf, X, y, scoring="neg_mean_squared_error", return_estimator=True, cv=cv
    )
    for k, est in enumerate(mse_scores_dict["estimator"]):
        assert_almost_equal(est.coef_, fitted_estimators[k].coef_)
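        # With sparse input, the fitted estimator stores coef_ as a sparse
        # matrix, so densify both sides before comparing element-wise.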
        est_coef = est.coef_.copy()
        if issparse(est_coef):
            est_coef = est_coef.toarray()

        fitted_est_coef = fitted_estimators[k].coef_.copy()
        if issparse(fitted_est_coef):
            fitted_est_coef = fitted_est_coef.toarray()

        assert_almost_equal(est_coef, fitted_est_coef)
        assert_almost_equal(est.intercept_, fitted_estimators[k].intercept_)


@@ -2104,10 +2107,12 @@ def test_fit_and_score_failing():
"error_score must be the string 'raise' or a numeric value. (Hint: if "
"using 'raise', please make sure that it has been spelled correctly.)"
)
    with pytest.raises(ValueError, match=error_message):
        cross_validate(failing_clf, X, cv=3, error_score="unvalid-string")

    with pytest.raises(ValueError, match=error_message):
    error_message_cross_validate = (
        "The 'error_score' parameter of cross_validate must be .*. Got .* instead."
    )

    with pytest.raises(ValueError, match=error_message_cross_validate):
        cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string")

    with pytest.raises(ValueError, match=error_message):
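The use_sparse parametrization above exercises the new "sparse matrix"
constraint end to end. A short usage sketch (illustrative, not taken from the
PR):

    from scipy.sparse import csr_matrix
    from sklearn.datasets import make_classification
    from sklearn.model_selection import cross_validate
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=30, random_state=0)
    # CSR input now passes parameter validation and reaches the estimator.
    scores = cross_validate(SVC(kernel="linear"), csr_matrix(X), y, cv=3)
    print(scores["test_score"])  # one score per fold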
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
@@ -245,6 +245,7 @@ def _check_function_param_validation(
"sklearn.metrics.roc_curve",
"sklearn.metrics.top_k_accuracy_score",
"sklearn.metrics.zero_one_loss",
"sklearn.model_selection.cross_validate",
"sklearn.model_selection.train_test_split",
"sklearn.neighbors.sort_graph_by_row_values",
"sklearn.preprocessing.add_dummy_feature",
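Adding "sklearn.model_selection.cross_validate" to this list opts the function
into the common test, which probes each declared parameter with invalid values
and asserts that the uniform InvalidParameterError is raised. The same
decorator can be illustrated on a toy function; a hedged sketch against the
private sklearn.utils._param_validation API as it existed around this PR (the
validate_params signature has since grown a prefer_skip_nested_validation
argument):

    from numbers import Integral

    from sklearn.utils._param_validation import Interval, validate_params

    @validate_params({"n": [Interval(Integral, 1, None, closed="left")]})
    def repeat(word, n):
        # Only 'n' is validated; 'word' has no declared constraint.
        return word * n

    repeat("ha", 3)  # returns 'hahaha'
    repeat("ha", 0)  # raises InvalidParameterError (0 is outside [1, inf))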