diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 3633479672cde..1db6031e8d702 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -88,6 +88,7 @@ def _yield_non_meta_checks(name, estimator): yield check_dtype_object yield check_sample_weights_pandas_series yield check_sample_weights_list + yield check_sample_weights_invariance yield check_estimators_fit_returns_self yield partial(check_estimators_fit_returns_self, readonly_memmap=True) yield check_complex_data @@ -554,6 +555,40 @@ def check_sample_weights_list(name, estimator_orig): estimator.fit(X, y, sample_weight=sample_weight) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) +def check_sample_weights_invariance(name, estimator_orig): + # check that the estimators yield same results for + # unit weights and no weights + if (has_fit_parameter(estimator_orig, "sample_weight") and + not (hasattr(estimator_orig, "_pairwise") + and estimator_orig._pairwise)): + # We skip pairwise because the data is not pairwise + + estimator1 = clone(estimator_orig) + estimator2 = clone(estimator_orig) + set_random_state(estimator1, random_state=0) + set_random_state(estimator2, random_state=0) + + X = np.array([[1, 3], [1, 3], [1, 3], [1, 3], + [2, 1], [2, 1], [2, 1], [2, 1], + [3, 3], [3, 3], [3, 3], [3, 3], + [4, 1], [4, 1], [4, 1], [4, 1]], dtype=np.dtype('float')) + y = np.array([1, 1, 1, 1, 2, 2, 2, 2, + 1, 1, 1, 1, 2, 2, 2, 2], dtype=np.dtype('int')) + + estimator1.fit(X, y=y, sample_weight=np.ones(shape=len(y))) + estimator2.fit(X, y=y, sample_weight=None) + + for method in ["predict", "transform"]: + if hasattr(estimator_orig, method): + X_pred1 = getattr(estimator1, method)(X) + X_pred2 = getattr(estimator2, method)(X) + assert_allclose(X_pred1, X_pred2, rtol=0.5, + err_msg="For %s sample_weight=None is not" + " equivalent to sample_weight=ones" + % name) + + @ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning)) def check_dtype_object(name, estimator_orig): # check that estimators treat dtype object as numeric if possible