Skip to content

TST check sample_weight shape added to common tests #11598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
4 changes: 0 additions & 4 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,6 @@ def test_check_inputs():
clf = GradientBoostingClassifier(n_estimators=100, random_state=1)
assert_raises(ValueError, clf.fit, X, y + [0, 1])

clf = GradientBoostingClassifier(n_estimators=100, random_state=1)
assert_raises(ValueError, clf.fit, X, y,
sample_weight=([1] * len(y)) + [0, 1])

weight = [0, 0, 0, 1, 1, 1]
clf = GradientBoostingClassifier(n_estimators=100, random_state=1)
msg = ("y contains 1 class after sample_weight trimmed classes with "
Expand Down
4 changes: 0 additions & 4 deletions sklearn/neighbors/tests/test_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,6 @@ def test_sample_weight_invalid():
kde = KernelDensity()
data = np.reshape([1., 2., 3.], (-1, 1))

sample_weight = [0.1, 0.2]
with pytest.raises(ValueError):
kde.fit(data, sample_weight=sample_weight)

sample_weight = [0.1, -0.2, 0.3]
expected_err = "sample_weight must have positive values"
with pytest.raises(ValueError, match=expected_err):
Expand Down
5 changes: 0 additions & 5 deletions sklearn/svm/tests/test_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,6 @@ def test_bad_input():
with pytest.raises(ValueError):
clf.fit(X, Y)

# sample_weight bad dimensions
clf = svm.SVC()
with pytest.raises(ValueError):
clf.fit(X, Y, sample_weight=range(len(X) - 1))

# predict with sparse input when trained with dense
clf = svm.SVC().fit(X, Y)
with pytest.raises(ValueError):
Expand Down
8 changes: 0 additions & 8 deletions sklearn/tree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,14 +1125,6 @@ def test_sample_weight_invalid():
with pytest.raises(TypeError, match=expected_err):
clf.fit(X, y, sample_weight=sample_weight)

sample_weight = np.ones(101)
with pytest.raises(ValueError):
clf.fit(X, y, sample_weight=sample_weight)

sample_weight = np.ones(99)
with pytest.raises(ValueError):
clf.fit(X, y, sample_weight=sample_weight)


def check_class_weights(name):
"""Check class_weights resemble sample_weights behavior."""
Expand Down
26 changes: 26 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _yield_checks(name, estimator):
yield check_sample_weights_pandas_series
yield check_sample_weights_not_an_array
yield check_sample_weights_list
yield check_sample_weights_shape
yield check_sample_weights_invariance
yield check_estimators_fit_returns_self
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
Expand Down Expand Up @@ -764,6 +765,31 @@ def check_sample_weights_list(name, estimator_orig):
estimator.fit(X, y, sample_weight=sample_weight)


@ignore_warnings(category=FutureWarning)
def check_sample_weights_shape(name, estimator_orig):
    # Estimators that accept a ``sample_weight`` fit parameter must raise a
    # ValueError when the weight array's shape does not match the data.
    if not has_fit_parameter(estimator_orig, "sample_weight"):
        return
    # Pairwise estimators expect a precomputed (n, n) matrix as X, so this
    # generic shape check does not apply to them.
    if getattr(estimator_orig, "_pairwise", False):
        return

    estimator = clone(estimator_orig)

    # Four well-separated points, repeated four times each, with labels
    # alternating between the two classes per group of four samples.
    X = np.array([[1, 3]] * 4 + [[2, 1]] * 4 +
                 [[3, 3]] * 4 + [[4, 1]] * 4)
    y = np.array([1] * 4 + [2] * 4 + [1] * 4 + [2] * 4)
    y = _enforce_estimator_tags_y(estimator, y)

    n_samples = len(y)

    # A correctly-shaped weight vector must be accepted.
    estimator.fit(X, y, sample_weight=np.ones(n_samples))

    # A 1-D weight vector of the wrong length must be rejected.
    assert_raises(ValueError, estimator.fit, X, y,
                  sample_weight=np.ones(2 * n_samples))

    # A 2-D weight array must be rejected even if its first axis matches.
    assert_raises(ValueError, estimator.fit, X, y,
                  sample_weight=np.ones((n_samples, 2)))


@ignore_warnings(category=FutureWarning)
def check_sample_weights_invariance(name, estimator_orig):
# check that the estimators yield same results for
Expand Down