
Refactor check_sample_weights_invariance into a more general repetition/reweighting equivalence check #29818


Merged on Oct 11, 2024 (42 commits)

Commits
1e8958d
split test in two
antoinebaker Sep 9, 2024
e0c2238
remove deprecated comment
antoinebaker Sep 9, 2024
f29a5a9
Merge remote-tracking branch 'upstream/main' into split_check_sample_…
antoinebaker Sep 9, 2024
a6b13e2
more difficult X, y
antoinebaker Sep 10, 2024
ab888b6
Merge branch 'main' into split_check_sample_weight
antoinebaker Sep 11, 2024
d6449e3
edit xfail_checks tags
antoinebaker Sep 11, 2024
7731542
fix test decision tree
antoinebaker Sep 11, 2024
acedf82
Apply suggestions from code review
antoinebaker Sep 16, 2024
7740f11
fix dtype X
antoinebaker Sep 16, 2024
bfc51dc
add xfails tags
antoinebaker Sep 16, 2024
24fb0b7
Merge branch 'main' into split_check_sample_weight
antoinebaker Sep 16, 2024
7f02c5e
use PER_ESTIMATOR_CHECK_PARAMS
antoinebaker Sep 16, 2024
63ebfd7
changelog
antoinebaker Sep 16, 2024
d8a92c5
better error message
antoinebaker Sep 17, 2024
0f000b2
remove deprecated test
antoinebaker Sep 17, 2024
6b72397
add more tests
antoinebaker Sep 18, 2024
3fa7a9f
change xfail message
antoinebaker Sep 18, 2024
d8a6a74
edit xfail tags
antoinebaker Sep 18, 2024
962ef5f
remove check_unit_sample_weights
antoinebaker Sep 18, 2024
ce872a9
fewer xpassing tests
antoinebaker Sep 18, 2024
e91eb05
fix typo changelog
antoinebaker Sep 18, 2024
e535df2
Merge remote-tracking branch 'upstream/main' into split_check_sample_…
antoinebaker Sep 18, 2024
94df7e1
remove deprecated QuantileRegressor
antoinebaker Sep 19, 2024
e9e4ce6
remove internal checks
antoinebaker Sep 19, 2024
3178074
Improve sensitivity of check_sample_weights_invariance by increasing …
ogrisel Sep 20, 2024
d587015
Update estimator tags to xfail the newly discovered bugs
ogrisel Sep 20, 2024
5caf1f0
multiclass y
antoinebaker Sep 20, 2024
6e3d414
rename test
antoinebaker Sep 20, 2024
daa812e
Merge branch 'main' into split_check_sample_weight
antoinebaker Sep 20, 2024
d525b68
skip xfails in check_estimator
antoinebaker Sep 23, 2024
a428271
fix GLM newton-cholesky solver
antoinebaker Sep 24, 2024
dafe9b1
Merge branch 'main' into split_check_sample_weight
antoinebaker Sep 24, 2024
5e23c83
make hessian_warning sample weight aware
antoinebaker Sep 27, 2024
23e1d68
fix multinomial loss
antoinebaker Sep 27, 2024
984cf8c
use n_classes
antoinebaker Sep 27, 2024
e46ab2e
Update sklearn/linear_model/_glm/glm.py
antoinebaker Sep 27, 2024
1578e07
changelog
antoinebaker Sep 27, 2024
6bac927
typos changelog
antoinebaker Sep 27, 2024
4535e50
changelog
antoinebaker Sep 30, 2024
b3aefc0
Update sklearn/linear_model/tests/test_ridge.py
antoinebaker Oct 11, 2024
0b3f051
Apply suggestions from code review
antoinebaker Oct 11, 2024
49f991b
Apply suggestions from code review
antoinebaker Oct 11, 2024
Files changed
16 changes: 15 additions & 1 deletion doc/whats_new/v1.6.rst
@@ -277,9 +277,15 @@ Changelog

  - |Fix| :class:`linear_model.LassoCV` and :class:`linear_model.ElasticNetCV` now
    take sample weights into account to define the search grid for the internally tuned
-   `alpha` hyper-parameter. :pr:`29442` by :user:`John Hopfensperger <s-banach> and
+   `alpha` hyper-parameter. :pr:`29442` by :user:`John Hopfensperger <s-banach>` and
    :user:`Shruti Nath <snath-xoc>`.

+ - |Fix| :class:`linear_model.LogisticRegression`, :class:`linear_model.PoissonRegressor`,
+   :class:`linear_model.GammaRegressor`, :class:`linear_model.TweedieRegressor`
+   now take sample weights into account to decide when to fall back to `solver='lbfgs'`
+   whenever `solver='newton-cholesky'` becomes numerically unstable.
+   :pr:`29818` by :user:`Antoine Baker <antoinebaker>`.

:mod:`sklearn.manifold`
.......................

@@ -382,6 +388,14 @@ Changelog
`ensure_all_finite`. `force_all_finite` will be removed in 1.8.
:pr:`29404` by :user:`Jérémie du Boisberranger <jeremiedb>`.

+ :mod:`sklearn.utils.check_estimators`
+ .....................................
+
+ - |API| :func:`check_estimators.check_sample_weights_invariance` is replaced by
+   :func:`check_estimators.check_sample_weight_equivalence`, which uses
+   integer (including zero) weights.
+   :pr:`29818` by :user:`Antoine Baker <antoinebaker>`.

.. rubric:: Code and documentation contributors

Thanks to everyone who has contributed to the maintenance and improvement of
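To make the renamed check concrete: the property it asserts is that fitting with integer `sample_weight` gives the same model as fitting on a dataset where each sample is repeated `sample_weight` times (a weight of zero removes the sample). The following is a minimal sketch of that invariant, not the actual implementation in `sklearn.utils.estimator_checks`; `Ridge` is chosen here only as an example of an estimator expected to pass:

```python
import numpy as np
from numpy.testing import assert_allclose
from sklearn.linear_model import Ridge

rng = np.random.RandomState(42)
X = rng.normal(size=(20, 3))
y = rng.normal(size=20)
# Integer weights: 0 drops a sample, 2 counts it twice, and so on.
sw = rng.randint(0, 3, size=20)

# Fitting with integer sample_weight ...
est_weighted = Ridge(alpha=1.0).fit(X, y, sample_weight=sw)
# ... should match fitting on the explicitly repeated dataset.
est_repeated = Ridge(alpha=1.0).fit(np.repeat(X, sw, axis=0), np.repeat(y, sw))

assert_allclose(est_weighted.coef_, est_repeated.coef_)
```

The many `_xfail_checks` tags added throughout this diff mark estimators for which this deterministic equality does not (yet) hold.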
5 changes: 3 additions & 2 deletions sklearn/cluster/_kmeans.py
@@ -1179,9 +1179,10 @@ def score(self, X, y=None, sample_weight=None):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
+ # TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
- "check_sample_weights_invariance": (
-     "zero sample_weight is not equivalent to removing samples"
+ "check_sample_weight_equivalence": (
+     "sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
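The same tag and message recur for many estimators below. KMeans illustrates why a deterministic check is too strict for some of them: for a fixed `random_state`, the RNG consumes draws differently on a weighted dataset than on the equivalent repeated one, so the fitted centers need not match exactly. A hedged sketch (the exact behaviour depends on the initialization):

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.RandomState(0)
X = rng.normal(size=(30, 2))
sw = rng.randint(0, 3, size=30)

km_weighted = KMeans(n_clusters=3, n_init=1, random_state=0).fit(X, sample_weight=sw)
km_repeated = KMeans(n_clusters=3, n_init=1, random_state=0).fit(np.repeat(X, sw, axis=0))

# The two sets of centers may disagree even though the two problems are
# mathematically equivalent, which is why only a statistical test (see the
# TODO above) could validate such estimators.
print(km_weighted.cluster_centers_)
print(km_repeated.cluster_centers_)
```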
6 changes: 6 additions & 0 deletions sklearn/ensemble/_bagging.py
@@ -643,6 +643,12 @@ def _get_estimator(self):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


30 changes: 30 additions & 0 deletions sklearn/ensemble/_forest.py
@@ -1556,6 +1556,16 @@ def __init__(
self.monotonic_cst = monotonic_cst
self.ccp_alpha = ccp_alpha

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class RandomForestRegressor(ForestRegressor):
"""
@@ -1915,6 +1925,16 @@ def __init__(
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class ExtraTreesClassifier(ForestClassifier):
"""
@@ -2985,3 +3005,13 @@ def transform(self, X):
"""
check_is_fitted(self)
return self.one_hot_encoder_.transform(self.apply(X))

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
20 changes: 20 additions & 0 deletions sklearn/ensemble/_gb.py
@@ -1723,6 +1723,16 @@ def staged_predict_proba(self, X):
"loss=%r does not support predict_proba" % self.loss
) from e

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: investigate this failure, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting):
"""Gradient Boosting for regression.
@@ -2177,3 +2187,13 @@ def apply(self, X):
leaves = super().apply(X)
leaves = leaves.reshape(X.shape[0], self.estimators_.shape[0])
return leaves

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: investigate this failure, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
6 changes: 6 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -1414,6 +1414,12 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.allow_nan = True
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags

@abstractmethod
5 changes: 3 additions & 2 deletions sklearn/ensemble/_iforest.py
@@ -633,9 +633,10 @@ def _compute_score_samples(self, X, subsample_features):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
+ # TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
- "check_sample_weights_invariance": (
-     "zero sample_weight is not equivalent to removing samples"
+ "check_sample_weight_equivalence": (
+     "sample_weight is not equivalent to removing/repeating samples."
),
}
tags.input_tags.allow_nan = True
21 changes: 20 additions & 1 deletion sklearn/ensemble/_weight_boosting.py
@@ -966,6 +966,16 @@ def predict_log_proba(self, X):
"""
return np.log(self.predict_proba(X))

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


class AdaBoostRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseWeightBoosting):
"""An AdaBoost regressor.
@@ -1139,7 +1149,6 @@ def _boost(self, iboost, X, y, sample_weight, random_state):
`random_state` attribute.
Controls also the bootstrap of the weights used to train the weak
learner.
- replacement.

Returns
-------
@@ -1275,3 +1284,13 @@ def staged_predict(self, X):

for i, _ in enumerate(self.estimators_, 1):
yield self._get_median_predict(X, limit=i)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
16 changes: 16 additions & 0 deletions sklearn/linear_model/_base.py
@@ -681,6 +681,22 @@ def rmatvec(b):
self._set_intercept(X_offset, y_offset, X_scale)
return self

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: investigate this failure, see meta-issue #16298
#
# Note: this model should converge to the minimum norm solution of the
# least squares problem and as a result be numerically stable enough when
# running the equivalence check even if n_features > n_samples. Maybe
# this is not the case and a different choice of solver could fix
# this problem.
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


def _check_precomputed_gram_matrix(
X, precompute, X_offset, X_scale, rtol=None, atol=1e-5
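The identity the note above relies on can be spelled out with plain numpy: scaling rows by the square root of the weights and solving ordinary least squares is algebraically the same as repeating rows according to integer weights, at least when the solution is unique. A sketch under that assumption (this is not `LinearRegression`'s actual code path):

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.normal(size=(10, 4))  # n_samples > n_features: unique solution
y = rng.normal(size=10)
w = rng.randint(0, 3, size=10)

sqw = np.sqrt(w)
coef_weighted, *_ = np.linalg.lstsq(X * sqw[:, None], y * sqw, rcond=None)
coef_repeated, *_ = np.linalg.lstsq(
    np.repeat(X, w, axis=0), np.repeat(y, w), rcond=None
)
np.testing.assert_allclose(coef_weighted, coef_repeated)
```

When n_features > n_samples both problems are instead solved in the minimum norm sense; the comment above expects that to remain stable as well, and the tag records that in practice it currently is not.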
10 changes: 10 additions & 0 deletions sklearn/linear_model/_bayes.py
@@ -430,6 +430,16 @@ def _log_marginal_likelihood(

return score

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: fix sample_weight handling of this estimator, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags


###############################################################################
# ARD (Automatic Relevance Determination) regression
11 changes: 0 additions & 11 deletions sklearn/linear_model/_coordinate_descent.py
@@ -1834,17 +1834,6 @@ def fit(self, X, y, sample_weight=None, **params):
self.n_iter_ = model.n_iter_
return self

- def __sklearn_tags__(self):
-     tags = super().__sklearn_tags__()
-     # Note: check_sample_weights_invariance(kind='ones') should work, but
-     # currently we can only mark a whole test as xfail.
-     tags._xfail_checks = {
-         "check_sample_weights_invariance": (
-             "zero sample_weight is not equivalent to removing samples"
-         ),
-     }
-     return tags

def get_metadata_routing(self):
"""Get metadata routing of this object.

12 changes: 11 additions & 1 deletion sklearn/linear_model/_linear_loss.py
@@ -456,7 +456,17 @@ def gradient_hessian(
# For non-canonical link functions and far away from the optimum, the pointwise
# hessian can be negative. We take care that 75% of the hessian entries are
# positive.
- hessian_warning = np.mean(hess_pointwise <= 0) > 0.25
+ sw = np.ones(n_samples) if sample_weight is None else sample_weight
+ n_classes = 1
+ # For multi_class loss, hess_pointwise.shape = (n_samples, n_classes).
+ # We need to reshape sample_weight for broadcasting.
+ if self.base_loss.is_multiclass:
+     n_classes = self.base_loss.n_classes
+     sw = sw[:, np.newaxis]
+ negative_hessian_proportion = np.sum(sw * (hess_pointwise < 0)) / (
+     sw.sum() * n_classes
+ )
+ hessian_warning = negative_hessian_proportion > 0.25
hess_pointwise = np.abs(hess_pointwise)

if not self.base_loss.is_multiclass:
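In isolation, the new criterion in the hunk above computes the weighted fraction of negative pointwise Hessian entries, so that zero-weight samples can no longer trigger the fallback; with unit weights it reduces to the previous unweighted mean (modulo `< 0` replacing `<= 0`). A standalone sketch with assumed shapes, mirroring the multiclass branch:

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_classes = 8, 3
hess_pointwise = rng.normal(size=(n_samples, n_classes))
sample_weight = rng.randint(0, 3, size=n_samples).astype(float)

sw = sample_weight[:, np.newaxis]  # one weight per row, broadcast over classes
negative_hessian_proportion = np.sum(sw * (hess_pointwise < 0)) / (
    sw.sum() * n_classes
)
hessian_warning = negative_hessian_proportion > 0.25

# With unit weights this is exactly the plain fraction of negative entries.
unit = np.ones((n_samples, 1))
assert np.isclose(
    np.sum(unit * (hess_pointwise < 0)) / (n_samples * n_classes),
    np.mean(hess_pointwise < 0),
)
```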
10 changes: 10 additions & 0 deletions sklearn/linear_model/_perceptron.py
@@ -224,3 +224,13 @@ def __init__(
class_weight=class_weight,
n_jobs=n_jobs,
)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
# TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
"check_sample_weight_equivalence": (
"sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
5 changes: 3 additions & 2 deletions sklearn/linear_model/_ransac.py
@@ -723,9 +723,10 @@ def get_metadata_routing(self):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
+ # TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
- "check_sample_weights_invariance": (
-     "zero sample_weight is not equivalent to removing samples"
+ "check_sample_weight_equivalence": (
+     "sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
2 changes: 1 addition & 1 deletion sklearn/linear_model/_ridge.py
@@ -2729,7 +2729,7 @@ def fit(self, X, y, sample_weight=None, **params):
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags._xfail_checks = {
"check_sample_weights_invariance": (
"check_sample_weight_equivalence": (
"GridSearchCV does not forward the weights to the scorer by default."
),
}
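The message above names the root cause: the internal model selection scores candidates without the sample weights, and an unweighted score is generally not invariant under reweighting or repetition, even when the metric itself handles weights consistently. A minimal standalone illustration with `r2_score` (not the estimator's actual scoring code):

```python
import numpy as np
from sklearn.metrics import r2_score

rng = np.random.RandomState(0)
y_true = rng.normal(size=10)
y_pred = y_true + rng.normal(scale=0.5, size=10)
sw = rng.randint(0, 3, size=10)

weighted = r2_score(y_true, y_pred, sample_weight=sw)
repeated = r2_score(np.repeat(y_true, sw), np.repeat(y_pred, sw))
unweighted = r2_score(y_true, y_pred)

assert np.isclose(weighted, repeated)  # the metric itself is consistent ...
print(unweighted)  # ... but the default scorer computes this instead
```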
15 changes: 9 additions & 6 deletions sklearn/linear_model/_stochastic_gradient.py
@@ -1376,9 +1376,10 @@ def predict_log_proba(self, X):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
+ # TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
- "check_sample_weights_invariance": (
-     "zero sample_weight is not equivalent to removing samples"
+ "check_sample_weight_equivalence": (
+     "sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
@@ -2060,9 +2061,10 @@ def __init__(

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
+ # TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
- "check_sample_weights_invariance": (
-     "zero sample_weight is not equivalent to removing samples"
+ "check_sample_weight_equivalence": (
+     "sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
@@ -2640,9 +2642,10 @@ def predict(self, X):

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
+ # TODO: replace by a statistical test, see meta-issue #16298
tags._xfail_checks = {
- "check_sample_weights_invariance": (
-     "zero sample_weight is not equivalent to removing samples"
+ "check_sample_weight_equivalence": (
+     "sample_weight is not equivalent to removing/repeating samples."
),
}
return tags
6 changes: 3 additions & 3 deletions sklearn/linear_model/tests/test_base.py
@@ -700,7 +700,7 @@ def test_linear_regression_sample_weight_consistency(
"""Test that the impact of sample_weight is consistent.

Note that this test is stricter than the common test
- check_sample_weights_invariance alone and also tests sparse X.
+ check_sample_weight_equivalence alone and also tests sparse X.
It is very similar to test_enet_sample_weight_consistency.
"""
rng = np.random.RandomState(global_random_seed)
@@ -717,8 +717,8 @@
if fit_intercept:
intercept = reg.intercept_

- # 1) sample_weight=np.ones(..) must be equivalent to sample_weight=None
- # same check as check_sample_weights_invariance(name, reg, kind="ones"), but we also
+ # 1) sample_weight=np.ones(..) must be equivalent to sample_weight=None,
+ # a special case of check_sample_weight_equivalence(name, reg), but we also
# test with sparse input.
sample_weight = np.ones_like(y)
reg.fit(X, y, sample_weight=sample_weight)
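For completeness, the first property from the test above, `sample_weight=np.ones(...)` being equivalent to `sample_weight=None`, written out for sparse input (a sketch in the spirit of the real test; the estimator and shapes are assumptions):

```python
import numpy as np
from numpy.testing import assert_allclose
from scipy import sparse
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
X = sparse.csr_matrix(rng.normal(size=(10, 5)))
y = rng.normal(size=10)

reg_none = LinearRegression().fit(X, y, sample_weight=None)
reg_ones = LinearRegression().fit(X, y, sample_weight=np.ones_like(y))

assert_allclose(reg_none.coef_, reg_ones.coef_)
```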
2 changes: 1 addition & 1 deletion sklearn/linear_model/tests/test_coordinate_descent.py
@@ -1224,7 +1224,7 @@ def test_enet_sample_weight_consistency(
"""Test that the impact of sample_weight is consistent.

Note that this test is stricter than the common test
- check_sample_weights_invariance alone and also tests sparse X.
+ check_sample_weight_equivalence alone and also tests sparse X.
"""
rng = np.random.RandomState(global_random_seed)
n_samples, n_features = 10, 5