diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index a28c69d0f7cc5..0c7f07929e370 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -345,10 +345,6 @@ def test_check_inputs(): clf = GradientBoostingClassifier(n_estimators=100, random_state=1) assert_raises(ValueError, clf.fit, X, y + [0, 1]) - clf = GradientBoostingClassifier(n_estimators=100, random_state=1) - assert_raises(ValueError, clf.fit, X, y, - sample_weight=([1] * len(y)) + [0, 1]) - weight = [0, 0, 0, 1, 1, 1] clf = GradientBoostingClassifier(n_estimators=100, random_state=1) msg = ("y contains 1 class after sample_weight trimmed classes with " diff --git a/sklearn/neighbors/tests/test_kde.py b/sklearn/neighbors/tests/test_kde.py index 69aca8e8f75b8..6687cfa475ce8 100644 --- a/sklearn/neighbors/tests/test_kde.py +++ b/sklearn/neighbors/tests/test_kde.py @@ -210,10 +210,6 @@ def test_sample_weight_invalid(): kde = KernelDensity() data = np.reshape([1., 2., 3.], (-1, 1)) - sample_weight = [0.1, 0.2] - with pytest.raises(ValueError): - kde.fit(data, sample_weight=sample_weight) - sample_weight = [0.1, -0.2, 0.3] expected_err = "sample_weight must have positive values" with pytest.raises(ValueError, match=expected_err): diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 191420d1d7147..fb811940c2971 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -639,11 +639,6 @@ def test_bad_input(): with pytest.raises(ValueError): clf.fit(X, Y) - # sample_weight bad dimensions - clf = svm.SVC() - with pytest.raises(ValueError): - clf.fit(X, Y, sample_weight=range(len(X) - 1)) - # predict with sparse input when trained with dense clf = svm.SVC().fit(X, Y) with pytest.raises(ValueError): diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 572bf8d01d57c..1149ceb8678d9 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1125,14 +1125,6 @@ def test_sample_weight_invalid(): with pytest.raises(TypeError, match=expected_err): clf.fit(X, y, sample_weight=sample_weight) - sample_weight = np.ones(101) - with pytest.raises(ValueError): - clf.fit(X, y, sample_weight=sample_weight) - - sample_weight = np.ones(99) - with pytest.raises(ValueError): - clf.fit(X, y, sample_weight=sample_weight) - def check_class_weights(name): """Check class_weights resemble sample_weights behavior.""" diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index ef9a4b1ca17f6..1e86f68d4ca3c 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -80,6 +80,7 @@ def _yield_checks(name, estimator): yield check_sample_weights_pandas_series yield check_sample_weights_not_an_array yield check_sample_weights_list + yield check_sample_weights_shape yield check_sample_weights_invariance yield check_estimators_fit_returns_self yield partial(check_estimators_fit_returns_self, readonly_memmap=True) @@ -764,6 +765,31 @@ def check_sample_weights_list(name, estimator_orig): estimator.fit(X, y, sample_weight=sample_weight) +@ignore_warnings(category=FutureWarning) +def check_sample_weights_shape(name, estimator_orig): + # check that estimators raise an error if sample_weight + # shape mismatches the input + if (has_fit_parameter(estimator_orig, "sample_weight") and + not (hasattr(estimator_orig, "_pairwise") + and estimator_orig._pairwise)): + estimator = clone(estimator_orig) + X = np.array([[1, 3], [1, 3], [1, 3], [1, 3], + [2, 1], [2, 1], [2, 1], [2, 1], + [3, 3], [3, 3], [3, 3], [3, 3], + [4, 1], [4, 1], [4, 1], [4, 1]]) + y = np.array([1, 1, 1, 1, 2, 2, 2, 2, + 1, 1, 1, 1, 2, 2, 2, 2]) + y = _enforce_estimator_tags_y(estimator, y) + + estimator.fit(X, y, sample_weight=np.ones(len(y))) + + assert_raises(ValueError, estimator.fit, X, y, + sample_weight=np.ones(2*len(y))) + + assert_raises(ValueError, estimator.fit, X, y, + sample_weight=np.ones((len(y), 2))) + + @ignore_warnings(category=FutureWarning) def check_sample_weights_invariance(name, estimator_orig): # check that the estimators yield same results for