Fix sample weight passing in KBinsDiscretizer #29907

Merged · 92 commits · Feb 8, 2025

Commits
8bc7bfd
add sample weights to resampling
snath-xoc Sep 22, 2024
09d82b2
fixed sample_weight handling in _indexing/resampling
snath-xoc Sep 22, 2024
f25e535
fixed sample_weight handling in _indexing/resampling
snath-xoc Sep 22, 2024
e1c3eae
fixed sample_weight handling in _indexing/resampling
snath-xoc Sep 22, 2024
81c20aa
fix _indexing
snath-xoc Sep 30, 2024
2a27aaa
added both weighted resampling and passing weights to binestimator op…
snath-xoc Oct 10, 2024
12f3ac7
Apply suggestions from code review
snath-xoc Oct 25, 2024
eb7248c
use only weighted resampling
snath-xoc Oct 25, 2024
33b3b39
Update sklearn/preprocessing/_discretization.py
snath-xoc Oct 25, 2024
d8a1dc2
Update sklearn/utils/_indexing.py
snath-xoc Oct 25, 2024
e494a53
added sample weights to validate params and weighted_percentiles
snath-xoc Oct 25, 2024
71ba4c0
add test_resample_weighted
snath-xoc Oct 25, 2024
2e6daba
add tests for resampling
snath-xoc Oct 25, 2024
092119f
add stricter tests for resampling
snath-xoc Oct 25, 2024
03a5275
fixed n_samples in test_resample_weighted
snath-xoc Oct 26, 2024
249e80d
Update sklearn/utils/tests/test_indexing.py
snath-xoc Oct 28, 2024
6658562
Update sklearn/utils/tests/test_indexing.py
snath-xoc Oct 28, 2024
1883cef
Update sklearn/utils/_indexing.py
snath-xoc Oct 28, 2024
214be7d
add check sample weight
snath-xoc Oct 31, 2024
3d95892
fix assertion error
snath-xoc Oct 31, 2024
25804c4
added further suggestions
snath-xoc Nov 6, 2024
8803914
Trigger Build
snath-xoc Oct 31, 2024
0f8eb3c
fix changelog
snath-xoc Nov 6, 2024
9417cfe
edited changelog
snath-xoc Nov 1, 2024
3c09a50
rebase
snath-xoc Nov 6, 2024
4f1e101
added changes to upcoming changes
snath-xoc Nov 6, 2024
fc59d67
fix sample weight check
snath-xoc Nov 6, 2024
d2f0930
Apply suggestions from code review
snath-xoc Nov 7, 2024
db00df1
Add support for sample_weight for the uniform strategy
ogrisel Nov 7, 2024
ae76737
New changelog entry for the uniform strategy
ogrisel Nov 7, 2024
c42b972
Update sklearn/preprocessing/_discretization.py
ogrisel Nov 7, 2024
d7ad9ec
updated tests for notimplementederror
snath-xoc Nov 7, 2024
288bd12
moved tests for -ve mean
snath-xoc Nov 7, 2024
1ca1bc1
Update sklearn/utils/_indexing.py
snath-xoc Nov 7, 2024
cc6aebf
set method inverted_cdf np.percentile
snath-xoc Nov 7, 2024
447ee16
updated doc string
snath-xoc Nov 7, 2024
3379d8c
removed unrelated formatting change
snath-xoc Nov 7, 2024
6babc16
Merge branch 'main' into fix_kbinsdiscretizer
ogrisel Nov 14, 2024
7aa8052
Update sklearn/preprocessing/_discretization.py
snath-xoc Nov 14, 2024
46953f8
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Nov 15, 2024
eb1eff6
passed quantile method on to np percentile, fixed instance generator
snath-xoc Nov 15, 2024
f9316e4
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Nov 15, 2024
79f862a
minor formatting
snath-xoc Nov 15, 2024
ce95a4f
added futurewarning and valueerror
snath-xoc Nov 15, 2024
5268419
added warn as default quantile method
snath-xoc Nov 18, 2024
0fc5fb2
update per_estimater_check_params
snath-xoc Nov 18, 2024
6dc0eae
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Nov 18, 2024
0eea4f8
Apply suggestions from code review
snath-xoc Nov 22, 2024
27dbbbe
Apply suggestions from code review
snath-xoc Nov 25, 2024
6fe58f7
moved quantile_method to local variable
snath-xoc Nov 25, 2024
f1b313f
checked sample_weight tests for quantile with averaged_inverted_cdf
snath-xoc Nov 25, 2024
d31dd88
modified to use averaged inverted cdf
snath-xoc Nov 25, 2024
6446ffc
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Dec 2, 2024
0fbf8d8
Apply suggestions from code review
snath-xoc Dec 2, 2024
621f445
Apply suggestions from code review
snath-xoc Dec 5, 2024
7fc7280
added avergaed_weighted tests
snath-xoc Dec 5, 2024
5dadf43
corrected tests for quantile_method
snath-xoc Dec 5, 2024
3ed3233
modify test_permutation_importance
snath-xoc Dec 5, 2024
d941c85
modify _averaged_Weighted_quantile tests
snath-xoc Dec 5, 2024
a2ab009
modify test_perm
snath-xoc Dec 6, 2024
2bbd6e0
modify minmax test
snath-xoc Dec 6, 2024
903dd42
further modified tests
snath-xoc Dec 6, 2024
a2bcd8e
fix_tests
snath-xoc Dec 6, 2024
a34756a
fixed instance generator
snath-xoc Dec 7, 2024
7f41bbd
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Dec 7, 2024
9e5c0a1
further modify tests for polynomial and target encoder
snath-xoc Dec 7, 2024
71e182f
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Dec 19, 2024
bf4479a
add changelog
snath-xoc Dec 7, 2024
4e7478b
added handing of np version < 1.22
snath-xoc Dec 19, 2024
6eb06ce
added handing of np version < 1.22 dtype
snath-xoc Dec 19, 2024
a3e47d4
modify test_averaged_weighted_percentile for np version handling
snath-xoc Dec 19, 2024
8746a25
TST use @pytest.mark.skipif
ogrisel Dec 27, 2024
6d8de61
Backward compat with older numpy versions
ogrisel Dec 27, 2024
94aa394
TST cover ValueError
ogrisel Dec 27, 2024
d74456d
TST fix common tests and remove XFAIL mark
ogrisel Dec 27, 2024
d6b548f
FIX typo in inline comment
ogrisel Dec 27, 2024
676abb4
STYLE improve formatting
ogrisel Dec 27, 2024
5e9bb4b
DOC add comment about quantile_method='warn' in tests
ogrisel Dec 27, 2024
e6e79fd
MAINT add missing TODO comment
ogrisel Dec 27, 2024
43365b5
Apply suggestions from code review
snath-xoc Jan 2, 2025
1b40d29
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Jan 6, 2025
5e837ed
fix linting issues
snath-xoc Jan 6, 2025
5bb2886
fix_linting
snath-xoc Jan 6, 2025
cf3c720
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Jan 7, 2025
03f4c95
Apply suggestions from code review
snath-xoc Jan 13, 2025
8e567de
move sample weight check before subsample check
snath-xoc Jan 13, 2025
2d3822c
Apply suggestions from code review
snath-xoc Jan 24, 2025
4027771
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Feb 5, 2025
7aaed10
Update sklearn/preprocessing/tests/test_discretization.py
snath-xoc Feb 6, 2025
626f5e3
Merge branch 'main' into fix_kbinsdiscretizer
snath-xoc Feb 6, 2025
c6fcad3
update changelog
snath-xoc Feb 6, 2025
16c2dbd
Update sklearn/preprocessing/tests/test_discretization.py
jeremiedbb Feb 7, 2025
@@ -0,0 +1,6 @@
- :class:`preprocessing.KBinsDiscretizer` with `strategy="uniform"` now
accepts `sample_weight`. Additionally, with `strategy="quantile"`, the
`quantile_method` can now be specified (in the future,
`quantile_method="averaged_inverted_cdf"` will become the default).
:pr:`29907` by :user:`Shruti Nath <snath-xoc>` and :user:`Olivier Grisel
<ogrisel>`
@@ -0,0 +1,6 @@
- :class:`preprocessing.KBinsDiscretizer` now uses weighted resampling when
sample weights are given and subsampling is used. This may change the
absolute values of the results even when sample weights are not provided,
but it does not change their statistical properties.
:pr:`29907` by :user:`Shruti Nath <snath-xoc>` and :user:`Jérémie du Boisberranger
<jeremiedbb>`
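Weighted resampling with replacement, as now used internally, can be sketched in plain numpy; the helper below is illustrative and is not scikit-learn's `resample`:

```python
import numpy as np

def weighted_resample(X, sample_weight, n_samples, seed=0):
    """Draw n_samples rows with replacement, with probability proportional
    to sample_weight (illustrative sketch, not sklearn.utils.resample)."""
    rng = np.random.default_rng(seed)
    p = np.asarray(sample_weight, dtype=float)
    idx = rng.choice(len(X), size=n_samples, replace=True, p=p / p.sum())
    return X[idx]

X = np.arange(10.0).reshape(-1, 1)
w = np.zeros(10)
w[:2] = 1.0  # only the first two rows carry any weight

Xs = weighted_resample(X, w, n_samples=5)
# zero-weight rows can never be drawn, so Xs only contains 0.0 and 1.0
```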
@@ -0,0 +1,7 @@

- :func:`resample` now handles sample weights, which allows weighted
resampling.
- :func:`_averaged_weighted_percentile` has been added; it implements an
averaged inverted-CDF calculation of percentiles.
Comment on lines +4 to +5
Member:

@snath-xoc I missed that in the review, but we should not mention private API related names in the changelog. It's confusing: the primary users of the changelog are library users, and they should not be incentivized to use private API in their code, since it can break without deprecation cycles when upgrading scikit-learn.

Can you please open a new PR to delete those lines?

Contributor Author:

O.K. will do, thanks for letting me know!

:pr:`29907` by :user:`Shruti Nath <snath-xoc>` and :user:`Olivier Grisel
<ogrisel>`
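As a hedged illustration of what an averaged inverted-CDF percentile computes (this is a sketch, not the private `_averaged_weighted_percentile` helper itself), the estimate can be written as a symmetrized inverted-CDF percentile:

```python
import numpy as np

def weighted_percentile(a, w, p):
    # inverted CDF: smallest value whose cumulative weight reaches p% of
    # the total weight
    order = np.argsort(a)
    a = np.asarray(a, dtype=float)[order]
    w = np.asarray(w, dtype=float)[order]
    cdf = np.cumsum(w)
    idx = np.searchsorted(cdf, p / 100.0 * cdf[-1])
    return a[min(idx, len(a) - 1)]

def averaged_weighted_percentile(a, w, p):
    # average the estimate with its mirror image to symmetrize the
    # inverted CDF, analogous to numpy's "averaged_inverted_cdf" method
    a = np.asarray(a, dtype=float)
    return (weighted_percentile(a, w, p) - weighted_percentile(-a, w, 100 - p)) / 2

a = np.array([1.0, 2.0, 3.0, 4.0])
w = np.ones(4)
# with unit weights this matches
# np.percentile(a, 50, method="averaged_inverted_cdf")
print(averaged_weighted_percentile(a, w, 50))  # 2.5
```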
@@ -568,7 +568,9 @@ def make_missing_value_data(n_samples=int(1e4), seed=0):
# Pre-bin the data to ensure a deterministic handling by the 2
# strategies and also make it easier to insert np.nan in a structured
# way:
X = KBinsDiscretizer(n_bins=42, encode="ordinal").fit_transform(X)
X = KBinsDiscretizer(
n_bins=42, encode="ordinal", quantile_method="averaged_inverted_cdf"
).fit_transform(X)

# First feature has missing values completely at random:
rnd_mask = rng.rand(X.shape[0]) > 0.9
6 changes: 5 additions & 1 deletion sklearn/inspection/tests/test_permutation_importance.py
@@ -311,7 +311,11 @@ def test_permutation_importance_equivalence_array_dataframe(n_jobs, max_samples)
X_df = pd.DataFrame(X)

# Add a categorical feature that is statistically linked to y:
binner = KBinsDiscretizer(n_bins=3, encode="ordinal")
binner = KBinsDiscretizer(
n_bins=3,
encode="ordinal",
quantile_method="averaged_inverted_cdf",
)
cat_column = binner.fit_transform(y.reshape(-1, 1))

# Concatenate the extra column to the numpy array: integers will be
141 changes: 124 additions & 17 deletions sklearn/preprocessing/_discretization.py
@@ -11,7 +11,8 @@
from ..utils import resample
from ..utils._param_validation import Interval, Options, StrOptions
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
from ..utils.stats import _weighted_percentile
from ..utils.fixes import np_version, parse_version
from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile
from ..utils.validation import (
_check_feature_names_in,
_check_sample_weight,
@@ -57,6 +58,17 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
For an example of the different strategies see:
:ref:`sphx_glr_auto_examples_preprocessing_plot_discretization_strategies.py`.

quantile_method : {"inverted_cdf", "averaged_inverted_cdf",
"closest_observation", "interpolated_inverted_cdf", "hazen",
"weibull", "linear", "median_unbiased", "normal_unbiased"},
default="linear"
Method to pass on to np.percentile calculation when using
strategy="quantile". Only `averaged_inverted_cdf` and `inverted_cdf`
support the use of `sample_weight != None` when subsampling is not
active.

.. versionadded:: 1.7

dtype : {np.float32, np.float64}, default=None
The desired data-type for the output. If None, output dtype is
consistent with input dtype. Only np.float32 and np.float64 are
@@ -175,6 +187,22 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
"n_bins": [Interval(Integral, 2, None, closed="left"), "array-like"],
"encode": [StrOptions({"onehot", "onehot-dense", "ordinal"})],
"strategy": [StrOptions({"uniform", "quantile", "kmeans"})],
"quantile_method": [
StrOptions(
{
"warn",
"inverted_cdf",
"averaged_inverted_cdf",
"closest_observation",
"interpolated_inverted_cdf",
"hazen",
"weibull",
"linear",
"median_unbiased",
"normal_unbiased",
}
)
],
"dtype": [Options(type, {np.float64, np.float32}), None],
"subsample": [Interval(Integral, 1, None, closed="left"), None],
"random_state": ["random_state"],
@@ -186,13 +214,15 @@ def __init__(
*,
encode="onehot",
strategy="quantile",
quantile_method="warn",
dtype=None,
subsample=200_000,
random_state=None,
):
self.n_bins = n_bins
self.encode = encode
self.strategy = strategy
self.quantile_method = quantile_method
self.dtype = dtype
self.subsample = subsample
self.random_state = random_state
@@ -213,10 +243,12 @@ def fit(self, X, y=None, sample_weight=None):

sample_weight : ndarray of shape (n_samples,)
Contains weight values to be associated with each sample.
Cannot be used when `strategy` is set to `"uniform"`.

.. versionadded:: 1.3

.. versionchanged:: 1.7
Added support for strategy="uniform".

Returns
-------
self : object
@@ -231,32 +263,74 @@

n_samples, n_features = X.shape

if sample_weight is not None and self.strategy == "uniform":
raise ValueError(
"`sample_weight` was provided but it cannot be "
"used with strategy='uniform'. Got strategy="
f"{self.strategy!r} instead."
)
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

if self.subsample is not None and n_samples > self.subsample:
# Take a subsample of `X`
# When resampling, it is important to subsample **with replacement** to
# preserve the distribution, in particular in the presence of a few data
# points with large weights. You can check this by setting `replace=False`
# in sklearn.utils.tests.test_indexing.test_resample_weighted and check that
# it fails as a justification for this claim.
X = resample(
X,
replace=False,
replace=True,
n_samples=self.subsample,
random_state=self.random_state,
sample_weight=sample_weight,
)
# Since we already used the weights when resampling when provided,
# we set them back to `None` to avoid accounting for the weights twice
# in subsequent operations to compute weight-aware bin edges with
# quantiles or k-means.
sample_weight = None

n_features = X.shape[1]
n_bins = self._validate_n_bins(n_features)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

bin_edges = np.zeros(n_features, dtype=object)

# TODO(1.9): remove and switch to quantile_method="averaged_inverted_cdf"
# by default.
quantile_method = self.quantile_method
if self.strategy == "quantile" and quantile_method == "warn":
warnings.warn(
"The current default behavior, quantile_method='linear', will be "
"changed to quantile_method='averaged_inverted_cdf' in "
"scikit-learn version 1.9 to naturally support sample weight "
"equivalence properties by default. Pass "
Comment on lines +301 to +302
Member:

What is "sample weight equivalence properties" - could we rephrase this (or am I the only one who doesn't know it)?

Member:

the weight of a data point =n <=> this data point is repeated n times

It's described in the glossary https://scikit-learn.org/dev/glossary.html#term-sample_weight, but you'll notice that this section is not complete and not very precise :)

Contributor Author (@snath-xoc, Dec 6, 2024):

@betatim thank you for the comments, I can imagine it's a lot of conversation to follow. As @jeremiedbb said, we conducted some tests with the gist here.

@jeremiedbb found that when specifying an odd case of samples (i.e., with one large outlier), the resampling with replacement strategy held up best in terms of sample weight equivalence (i.e., equivalent properties when weighting vs. repeating a sample n times, where n=0 means omitting the sample altogether).

We also had to modify the "uniform" strategy with @ogrisel's help and implement an averaged weighted quantile strategy with accompanying tests. Let me know if you need further clarification :) .

Member (@ogrisel, Dec 11, 2024):

@jeremiedbb I will open a PR to improve the glossary entry. This is an excellent place to document the expected semantics for sample_weight in scikit-learn.

@betatim Meanwhile, the precise definition of that equivalence is given in check_sample_weight_equivalence_on_dense_data. Note that this check is expected to fail for estimators which have a fit method whose outcome depends on the value of a random_state param (aka estimators with a non-deterministic fit).

The fact that an estimator has a deterministic fit or not, can both depend on the choice of the hyper-parameters (e.g. subsample, max_features, ...) and sometimes dataset size (e.g. subsample=200_000 will enable random subsampling only on training sets with more than 200_000 rows).

For estimators with a stochastic fit, we have to run a more computationally intensive statistical test to check if the equivalence property holds. This is currently done in external notebooks from this gist: https://gist.github.com/snath-xoc/fb28feab39403a1e66b00b5b28f1dcbf . We don't have specific plans to refactor those notebooks as a common test at this time because:

  • they can be slow to run;
  • the fact that an estimator passes or fails a statistical test depends on the choice of a p-value threshold, which is quite arbitrary, especially since we have to do multiple comparisons in our case and the individual subtests are dependent in non-trivial ways, which makes adjusting for multiple comparisons challenging.

So for now, we keep those notebooks as a guide to help us detect stochastic estimators with suspicious sample_weight handling. Bugs are manually confirmed by auditing their source code, starting with the worst offenders, ranked by ascending p-values.

Member (@ogrisel, Dec 31, 2024):

I will open a PR to improve the glossary entry. This is an excellent place to document the expected semantics for sample_weight in scikit-learn.

I opened the PR: #30564.
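The equivalence discussed in this thread (a weight of n behaving like n repetitions, and a weight of 0 like omission) can be checked with a small numpy example; this is an illustrative sketch, not one of scikit-learn's tests:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
w = np.array([2, 0, 1])

# weight n on a sample should act like repeating it n times
# (and weight 0 like dropping the sample altogether)
repeated = np.repeat(x, w)

# the weighted mean and the mean over the repeated samples agree
print(np.average(x, weights=w), repeated.mean())
```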

"quantile_method='averaged_inverted_cdf' explicitly to silence this "
"warning.",
FutureWarning,
)
quantile_method = "linear"

if (
self.strategy == "quantile"
and quantile_method not in ["inverted_cdf", "averaged_inverted_cdf"]
and sample_weight is not None
):
raise ValueError(
"When fitting with strategy='quantile' and sample weights, "
"quantile_method should either be set to 'averaged_inverted_cdf' or "
f"'inverted_cdf', got quantile_method='{quantile_method}' instead."
)

if self.strategy != "quantile" and sample_weight is not None:
# Prepare a mask to filter out zero-weight samples when extracting
# the min and max values of each column, which are needed for the
# "uniform" and "kmeans" strategies.
nnz_weight_mask = sample_weight != 0
else:
# Otherwise, all samples are used. Use a slice to avoid creating a
# new array.
nnz_weight_mask = slice(None)

for jj in range(n_features):
column = X[:, jj]
col_min, col_max = column.min(), column.max()
col_min = column[nnz_weight_mask].min()
col_max = column[nnz_weight_mask].max()

if col_min == col_max:
warnings.warn(
@@ -270,14 +344,47 @@ def fit(self, X, y=None, sample_weight=None):
bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1)

elif self.strategy == "quantile":
quantiles = np.linspace(0, 100, n_bins[jj] + 1)
percentile_levels = np.linspace(0, 100, n_bins[jj] + 1)

# TODO: simplify the following when numpy min version >= 1.22.

# method="linear" is the implicit default for any numpy
# version. So we keep it version independent in that case by
# using an empty param dict.
percentile_kwargs = {}
if quantile_method != "linear" and sample_weight is None:
if np_version < parse_version("1.22"):
if quantile_method in ["averaged_inverted_cdf", "inverted_cdf"]:
# The method parameter is not supported in numpy <
# 1.22 but we can define unit sample weight to use
# our own implementation instead:
sample_weight = np.ones(X.shape[0], dtype=X.dtype)
else:
raise ValueError(
f"quantile_method='{quantile_method}' is not "
"supported with numpy < 1.22"
)
else:
percentile_kwargs["method"] = quantile_method

if sample_weight is None:
bin_edges[jj] = np.asarray(np.percentile(column, quantiles))
bin_edges[jj] = np.asarray(
np.percentile(column, percentile_levels, **percentile_kwargs),
dtype=np.float64,
)
else:
# TODO: make _weighted_percentile and
# _averaged_weighted_percentile accept an array of
# quantiles instead of calling it multiple times and
# sorting the column multiple times as a result.
percentile_func = {
"inverted_cdf": _weighted_percentile,
"averaged_inverted_cdf": _averaged_weighted_percentile,
}[quantile_method]
bin_edges[jj] = np.asarray(
[
_weighted_percentile(column, sample_weight, q)
for q in quantiles
percentile_func(column, sample_weight, percentile=p)
for p in percentile_levels
],
dtype=np.float64,
)
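For integer weights, an inverted-CDF weighted percentile should match the plain percentile computed on data where each sample is repeated `weight` times. The helper below is an illustrative sketch of that idea, not scikit-learn's private `_weighted_percentile`; the reference comparison uses `np.percentile`'s `method` parameter, which requires numpy >= 1.22:

```python
import numpy as np

def weighted_inverted_cdf(a, w, p):
    """Illustrative weighted inverted-CDF percentile: smallest value whose
    cumulative weight reaches p% of the total weight."""
    order = np.argsort(a)
    a = np.asarray(a, dtype=float)[order]
    w = np.asarray(w, dtype=float)[order]
    cdf = np.cumsum(w)
    idx = np.searchsorted(cdf, p / 100.0 * cdf[-1])
    return a[min(idx, len(a) - 1)]

column = np.array([1.0, 2.0, 5.0, 9.0])
sample_weight = np.array([3, 1, 2, 1])

# materialize the integer weights as sample repetitions
repeated = np.repeat(column, sample_weight)

for p in (25, 50, 75, 100):
    assert weighted_inverted_cdf(column, sample_weight, p) == np.percentile(
        repeated, p, method="inverted_cdf"
    )
```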