Skip to content

More sensitive sample weight estimator check #30143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 117 additions & 48 deletions sklearn/utils/_test_common/instance_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from functools import partial
from inspect import isfunction

import numpy as np

from sklearn import clone, config_context
from sklearn.calibration import CalibratedClassifierCV
from sklearn.cluster import (
Expand Down Expand Up @@ -111,6 +113,7 @@
RANSACRegressor,
Ridge,
RidgeClassifier,
RidgeClassifierCV,
RidgeCV,
SGDClassifier,
SGDOneClassSVM,
Expand Down Expand Up @@ -537,6 +540,10 @@
max_iter=20, n_components=1, transform_algorithm="lasso_lars"
)
},
ElasticNetCV: {
"check_sample_weight_equivalence_on_dense_data": dict(max_iter=100, tol=1e-2),
"check_sample_weight_equivalence_on_sparse_data": dict(max_iter=100, tol=1e-2),
},
FactorAnalysis: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
FastICA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
FeatureAgglomeration: {"check_dict_unchanged": dict(n_clusters=1)},
Expand All @@ -554,38 +561,88 @@
},
GammaRegressor: {
"check_sample_weight_equivalence_on_dense_data": [
dict(solver="newton-cholesky"),
dict(solver="lbfgs"),
dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
],
},
GaussianMixture: {"check_dict_unchanged": dict(max_iter=5, n_init=2)},
GaussianRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
HuberRegressor: {
"check_sample_weight_equivalence_on_dense_data": dict(
tol=1e-12, max_iter=1_000
),
"check_sample_weight_equivalence_on_sparse_data": dict(
tol=1e-12, max_iter=1_000
),
},
IncrementalPCA: {"check_dict_unchanged": dict(batch_size=10, n_components=1)},
Isomap: {"check_dict_unchanged": dict(n_components=1)},
KMeans: {"check_dict_unchanged": dict(max_iter=5, n_clusters=1, n_init=2)},
KernelPCA: {"check_dict_unchanged": dict(n_components=1)},
LassoLars: {"check_non_transformer_estimators_n_iter": dict(alpha=0.0)},
LassoCV: {
"check_sample_weight_equivalence_on_dense_data": dict(max_iter=100, tol=1e-2),
"check_sample_weight_equivalence_on_sparse_data": dict(max_iter=100, tol=1e-2),
},
LatentDirichletAllocation: {
"check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
},
LinearDiscriminantAnalysis: {"check_dict_unchanged": dict(n_components=1)},
LinearRegression: {
"check_estimator_sparse_tag": [dict(positive=False), dict(positive=True)],
LinearSVC: {
"check_sample_weight_equivalence_on_dense_data": [
dict(positive=False),
dict(positive=True),
dict(dual=False, max_iter=1_000, tol=1e-12),
# XXX: the dual solver has trouble converging on the repeated test
# data with a lower tolerance. Furthermore, the solver is not
# deterministic with dual=True. We would need a statistical test
# to check weight/repetition equivalence instead.
# dict(dual=True, max_iter=1_000, tol=1e-3),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(dual=False, max_iter=1_000, tol=1e-12),
],
},
LinearSVR: {
"check_sample_weight_equivalence_on_dense_data": [
dict(max_iter=1_000, tol=1e-8),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(max_iter=1_000, tol=1e-8),
],
},
LocallyLinearEmbedding: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
LogisticRegression: {
"check_sample_weight_equivalence_on_dense_data": [
dict(solver="lbfgs"),
dict(solver="liblinear"),
dict(solver="newton-cg"),
dict(solver="newton-cholesky"),
dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
dict(solver="newton-cg", max_iter=1_000, tol=1e-12),
dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
# liblinear has more problems with higher regularization apparently...
dict(solver="liblinear", C=0.01, max_iter=1_000, tol=1e-12),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(solver="liblinear"),
# liblinear has more problems with higher regularization apparently...
dict(solver="liblinear", C=0.01, max_iter=1_000, tol=1e-12),
],
},
LogisticRegressionCV: {
"check_sample_weight_equivalence_on_dense_data": [
dict(
solver="newton-cholesky",
Cs=np.logspace(-3, 3, 5),
max_iter=1_000,
tol=1e-12,
),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(
solver="newton-cholesky",
Cs=np.logspace(-3, 3, 5),
max_iter=1_000,
tol=1e-12,
),
],
},
MDS: {"check_dict_unchanged": dict(max_iter=5, n_components=1, n_init=2)},
Expand Down Expand Up @@ -614,8 +671,12 @@
PLSSVD: {"check_dict_unchanged": dict(n_components=1)},
PoissonRegressor: {
"check_sample_weight_equivalence_on_dense_data": [
dict(solver="newton-cholesky"),
dict(solver="lbfgs"),
dict(solver="newton-cholesky", max_iter=100),
dict(solver="lbfgs", max_iter=100),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(solver="newton-cholesky", max_iter=100),
dict(solver="lbfgs", max_iter=100),
],
},
PolynomialCountSketch: {"check_dict_unchanged": dict(n_components=1)},
Expand All @@ -632,27 +693,40 @@
"check_sample_weight_equivalence_on_dense_data": [
dict(solver="svd"),
dict(solver="cholesky"),
dict(solver="sparse_cg"),
dict(solver="lsqr"),
dict(solver="sparse_cg", tol=1e-12),
dict(solver="lsqr", tol=1e-12),
dict(solver="lbfgs", positive=True),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(solver="sparse_cg"),
dict(solver="lsqr"),
dict(solver="sparse_cg", tol=1e-12),
dict(solver="lsqr", tol=1e-12),
],
},
RidgeClassifier: {
"check_sample_weight_equivalence_on_dense_data": [
dict(solver="svd"),
dict(solver="cholesky"),
dict(solver="sparse_cg"),
dict(solver="lsqr"),
dict(solver="sparse_cg", tol=1e-12),
dict(solver="lsqr", tol=1e-12),
dict(solver="lbfgs", positive=True),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(solver="sparse_cg"),
dict(solver="lsqr"),
dict(solver="sparse_cg", tol=1e-12),
dict(solver="lsqr", tol=1e-12),
],
},
RidgeCV: {
# XXX: the default grid (0.1, 1, 10.) is not wide and fine enough to
# detect discrepancies that impact the choice of the best alpha.
"check_sample_weight_equivalence_on_dense_data": dict(
alphas=np.logspace(-3, 3, 5)
),
},
RidgeClassifierCV: {
"check_sample_weight_equivalence_on_dense_data": dict(
alphas=np.logspace(-3, 3, 5)
),
},
SkewedChi2Sampler: {"check_dict_unchanged": dict(n_components=1)},
SparsePCA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
SparseRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
Expand All @@ -677,8 +751,12 @@
TruncatedSVD: {"check_dict_unchanged": dict(n_components=1)},
TweedieRegressor: {
"check_sample_weight_equivalence_on_dense_data": [
dict(solver="newton-cholesky"),
dict(solver="lbfgs"),
dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
],
"check_sample_weight_equivalence_on_sparse_data": [
dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
],
},
}
Expand Down Expand Up @@ -830,9 +908,9 @@ def _yield_instances_for_check(check, estimator_orig):
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
# "check_sample_weight_equivalence_on_sparse_data": (
# "sample_weight is not equivalent to removing/repeating samples."
# ),
},
BernoulliRBM: {
"check_methods_subset_invariance": ("fails for the decision_function method"),
Expand Down Expand Up @@ -996,34 +1074,25 @@ def _yield_instances_for_check(check, estimator_orig):
),
},
LinearSVC: {
# TODO: replace by a statistical test when _dual=True, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
# TODO: replace by a statistical test when dual=True, see meta-issue #16298
# "check_sample_weight_equivalence_on_dense_data": (
# "sample_weight is not equivalent to removing/repeating samples."
# ),
# "check_sample_weight_equivalence_on_sparse_data": (
# "sample_weight is not equivalent to removing/repeating samples."
# ),
"check_non_transformer_estimators_n_iter": (
"n_iter_ cannot be easily accessed."
),
},
LinearSVR: {
# TODO: replace by a statistical test, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
},
LogisticRegression: {
# TODO: fix sample_weight handling of this estimator, see meta-issue #16298
"check_sample_weight_equivalence_on_dense_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
"check_sample_weight_equivalence_on_sparse_data": (
"sample_weight is not equivalent to removing/repeating samples."
),
# "check_sample_weight_equivalence_on_dense_data": (
# "sample_weight is not equivalent to removing/repeating samples."
# ),
# "check_sample_weight_equivalence_on_sparse_data": (
# "sample_weight is not equivalent to removing/repeating samples."
# ),
},
MiniBatchKMeans: {
# TODO: replace by a statistical test, see meta-issue #16298
Expand Down
63 changes: 54 additions & 9 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
make_regression,
)
from ..exceptions import (
ConvergenceWarning,
DataConversionWarning,
EstimatorCheckFailedWarning,
NotFittedError,
Expand Down Expand Up @@ -1512,11 +1513,38 @@ def _check_sample_weight_equivalence(name, estimator_orig, sparse_container):
set_random_state(estimator_repeated, random_state=0)

rng = np.random.RandomState(42)
n_samples = 15
X = rng.rand(n_samples, n_samples * 2)
y = rng.randint(0, 3, size=n_samples)

# Generate some random data with 3 classes that could either be used for
# classification or regression. We use a large number of features to give
# more freedom to the estimator when fitting and be more sensitive to
# train data weighting/resampling as a result.
n_samples_with_small_weights = 16
n_features = n_samples_with_small_weights * 2
X, y = make_classification(
n_samples=n_samples_with_small_weights,
n_classes=2,
n_features=n_features,
n_informative=3 * n_features // 4,
random_state=rng,
)
# Use random integers (including zero) as weights.
sw = rng.randint(0, 5, size=n_samples)
sw = rng.randint(0, 3, size=n_samples_with_small_weights)

# Add a third class with a few data points but with much heavier weights,
# tightly clustered around the mean of the class 0 samples.
n_samples_with_large_weights = 4
X_with_large_weights = rng.normal(
loc=X[y == 0].mean(axis=0),
scale=0.01,
size=(n_samples_with_large_weights, n_features),
)
X = np.vstack([X, X_with_large_weights])
y = np.hstack([y, [2] * n_samples_with_large_weights])
sw = np.hstack([sw, [100] * n_samples_with_large_weights])

tags = get_tags(estimator_orig)
if tags.input_tags.positive_only:
X -= X.min(axis=0)

X_weighted = X
y_weighted = y
Expand Down Expand Up @@ -1558,19 +1586,36 @@ def _check_sample_weight_equivalence(name, estimator_orig, sparse_container):
X_weighted = sparse_container(X_weighted)
X_repeated = sparse_container(X_repeated)

estimator_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
estimator_weighted.fit(X_weighted, y=y_weighted, sample_weight=sw)
with warnings.catch_warnings(record=True):
# Ensure we converge, otherwise debugging sample_weight equivalence
# failures can be very misleading.
warnings.simplefilter("error", category=ConvergenceWarning)

estimator_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
estimator_weighted.fit(X_weighted, y=y_weighted, sample_weight=sw)

X_test = rng.uniform(low=X.min(), high=X.max(), size=(300, n_features))
if sparse_container is not None:
X_test = sparse_container(X_test)

for method in ["predict_proba", "decision_function", "predict", "transform"]:
if hasattr(estimator_orig, method):
X_pred1 = getattr(estimator_repeated, method)(X)
X_pred2 = getattr(estimator_weighted, method)(X)
X_pred1 = getattr(estimator_repeated, method)(X_test)
X_pred2 = getattr(estimator_weighted, method)(X_test)
err_msg = (
f"Comparing the output of {name}.{method} revealed that fitting "
"with `sample_weight` is not equivalent to fitting with removed "
"or repeated data points."
)
assert_allclose_dense_sparse(X_pred1, X_pred2, err_msg=err_msg)

# We use a larger tolerance than usual because this check is pushing
# the solvers to their limits and it is acceptable to tolerate some
# cumulative rounding errors after many iterations. But if the
# `sample_weight` is not equivalent to removing or repeating data
# points, the error will be large and the test will fail.
assert_allclose_dense_sparse(
X_pred1, X_pred2, err_msg=err_msg, rtol=1e-5, atol=1e-6
)


def check_sample_weight_equivalence_on_dense_data(name, estimator_orig):
Expand Down
Loading