diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 249cb022f8e87..1044eb666e39e 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -78,6 +78,7 @@ def _yield_checks(name, estimator): yield check_sample_weights_pandas_series yield check_sample_weights_list yield check_sample_weights_invariance + yield check_sample_weights_equivalence_sampling yield check_estimators_fit_returns_self yield partial(check_estimators_fit_returns_self, readonly_memmap=True) @@ -631,6 +632,43 @@ def check_sample_weights_invariance(name, estimator_orig): % name) +@ignore_warnings(category=(DeprecationWarning, FutureWarning)) +def check_sample_weights_equivalence_sampling(name, estimator_orig): + # check that the estimators yield same results for + # over-sample dataset by indice filtering and using sample_weight + if (has_fit_parameter(estimator_orig, "sample_weight") and + not (hasattr(estimator_orig, "_pairwise") + and estimator_orig._pairwise)): + # We skip pairwise because the data is not pairwise + + estimator1 = clone(estimator_orig) + estimator2 = clone(estimator_orig) + set_random_state(estimator1, random_state=0) + set_random_state(estimator2, random_state=0) + + if is_classifier(estimator1): + X, y = load_iris(return_X_y=True) + else: + X, y = load_boston(return_X_y=True) + y = enforce_estimator_tags_y(estimator1, y) + + step = 2 + indices = np.arange(start=0, stop=y.size, step=step) + sample_weight = np.zeros((y.size,)) + sample_weight[::step] = 1. + + estimator1.fit(X, y=y, sample_weight=sample_weight) + estimator2.fit(X[indices], y[indices]) + + err_msg = ("For {} does not yield to the same results when given " + "sample_weight and an up-sampled dataset") + for method in ["predict", "transform"]: + if hasattr(estimator_orig, method): + X_pred1 = getattr(estimator1, method)(X) + X_pred2 = getattr(estimator2, method)(X) + assert_allclose_dense_sparse(X_pred1, X_pred2, err_msg=err_msg) + + @ignore_warnings(category=(DeprecationWarning, FutureWarning, UserWarning)) def check_dtype_object(name, estimator_orig): # check that estimators treat dtype object as numeric if possible