Skip to content

[MRG+2] Add a test for sample weights for estimators #11558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
35 changes: 35 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def _yield_non_meta_checks(name, estimator):
yield check_dtype_object
yield check_sample_weights_pandas_series
yield check_sample_weights_list
yield check_sample_weights_invariance
yield check_estimators_fit_returns_self
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
yield check_complex_data
Expand Down Expand Up @@ -554,6 +555,40 @@ def check_sample_weights_list(name, estimator_orig):
estimator.fit(X, y, sample_weight=sample_weight)


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
def check_sample_weights_invariance(name, estimator_orig):
    """Check that unit sample weights behave like no sample weights.

    Two clones of ``estimator_orig`` are fitted on the same small data set,
    one with ``sample_weight`` set to an array of ones and the other with
    ``sample_weight=None``. For each of ``predict`` and ``transform`` that
    the estimator exposes, the outputs of the two fitted clones are compared
    with ``assert_allclose``.

    The check is a no-op when ``fit`` has no ``sample_weight`` parameter or
    when the estimator has a truthy ``_pairwise`` attribute.

    Parameters
    ----------
    name : str
        Estimator name, interpolated into the assertion error message.
    estimator_orig : estimator instance
        Estimator under test; only clones of it are fitted.
    """
    if not has_fit_parameter(estimator_orig, "sample_weight"):
        return
    if getattr(estimator_orig, "_pairwise", False):
        # Pairwise estimators are skipped: the fixture below is not
        # pairwise data.
        return

    # Four distinct points, each repeated four times in a row, with labels
    # alternating 1, 2, 1, 2 across the groups.
    X = np.repeat(
        np.array([[1, 3], [2, 1], [3, 3], [4, 1]], dtype=np.dtype('float')),
        4, axis=0)
    y = np.repeat(np.array([1, 2, 1, 2], dtype=np.dtype('int')), 4)

    # Identical clones with identical seeds, so the only difference between
    # the two fits is the sample_weight argument.
    est_unit_weights = clone(estimator_orig)
    est_no_weights = clone(estimator_orig)
    set_random_state(est_unit_weights, random_state=0)
    set_random_state(est_no_weights, random_state=0)

    est_unit_weights.fit(X, y=y, sample_weight=np.ones(shape=len(y)))
    est_no_weights.fit(X, y=y, sample_weight=None)

    for method_name in ("predict", "transform"):
        if not hasattr(estimator_orig, method_name):
            continue
        out_unit_weights = getattr(est_unit_weights, method_name)(X)
        out_no_weights = getattr(est_no_weights, method_name)(X)
        # NOTE(review): rtol=0.5 is a very loose tolerance for an
        # equivalence check -- confirm whether it is intentional.
        assert_allclose(
            out_unit_weights, out_no_weights, rtol=0.5,
            err_msg=("For %s sample_weight=None is not"
                     " equivalent to sample_weight=ones") % name)


@ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning))
def check_dtype_object(name, estimator_orig):
# check that estimators treat dtype object as numeric if possible
Expand Down