Fix sample weight passing in KBinsDiscretizer
#29907
Changes from all commits
@@ -0,0 +1,6 @@
- :class:`preprocessing.KBinsDiscretizer` with `strategy="uniform"` now
  accepts `sample_weight`. Additionally, with `strategy="quantile"` the
  `quantile_method` can now be specified (in the future,
  `quantile_method="averaged_inverted_cdf"` will become the default).
  :pr:`29907` by :user:`Shruti Nath <snath-xoc>` and
  :user:`Olivier Grisel <ogrisel>`
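
As a quick illustration of the behaviors this entry describes, here is a hedged sketch (the toy data and weights are invented, and it assumes a scikit-learn build that includes this PR):

```python
# Hypothetical usage sketch for the changelog entry above; the toy data
# and weights are invented for illustration.
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer

X = np.array([[0.0], [1.0], [2.0], [3.0], [10.0]])
sample_weight = np.array([1.0, 1.0, 1.0, 1.0, 5.0])

# strategy="uniform" now accepts sample_weight: zero-weight samples no
# longer influence the min/max used to place the uniform bin edges.
uniform = KBinsDiscretizer(n_bins=3, encode="ordinal", strategy="uniform")
uniform.fit(X, sample_weight=sample_weight)

# strategy="quantile" now accepts quantile_method; passing it explicitly
# also silences the FutureWarning about the future default change.
quantile = KBinsDiscretizer(
    n_bins=3,
    encode="ordinal",
    strategy="quantile",
    quantile_method="averaged_inverted_cdf",
)
quantile.fit(X, sample_weight=sample_weight)
print(quantile.bin_edges_[0])
```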
@@ -0,0 +1,6 @@
- :class:`preprocessing.KBinsDiscretizer` now uses weighted resampling when
  sample weights are given and subsampling is used. This may change results
  even when not using sample weights, although only in absolute values and
  not in terms of statistical properties.
  :pr:`29907` by :user:`Shruti Nath <snath-xoc>` and
  :user:`Jérémie du Boisberranger <jeremiedbb>`
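
A sketch of how this weighted-resampling path is triggered; the random data and `subsample=500` below are arbitrary choices for illustration:

```python
# Hedged sketch: when subsample < n_samples, fit() now draws a weighted
# subsample with replacement before computing the bin edges.
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 1))
sample_weight = rng.uniform(size=1000)

est = KBinsDiscretizer(
    n_bins=5,
    encode="ordinal",
    strategy="quantile",
    quantile_method="averaged_inverted_cdf",
    subsample=500,  # smaller than n_samples, so weighted resampling kicks in
    random_state=0,
)
est.fit(X, sample_weight=sample_weight)
```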
@@ -0,0 +1,7 @@
- :func:`resample` now handles sample weights, which allows weighted
  resampling.
- :func:`_averaged_weighted_percentile` has been added; it implements an
  averaged inverted CDF calculation of percentiles.
  :pr:`29907` by :user:`Shruti Nath <snath-xoc>` and
  :user:`Olivier Grisel <ogrisel>`
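
The `resample` change can be seen directly. This is a hedged sketch assuming a scikit-learn version that includes this PR; the degenerate weights are chosen so the outcome does not depend on the random seed:

```python
# Sketch of weighted resampling via sklearn.utils.resample; the weights
# below are deliberately extreme so the effect is visible.
import numpy as np
from sklearn.utils import resample

X = np.arange(10).reshape(-1, 1)
sample_weight = np.zeros(10)
sample_weight[9] = 1.0  # put all of the mass on the last sample

X_resampled = resample(
    X,
    replace=True,
    n_samples=5,
    random_state=0,
    sample_weight=sample_weight,
)
# Samples are drawn with probability proportional to their weight, so the
# resample consists only of the single non-zero-weight point.
print(X_resampled.ravel())  # -> [9 9 9 9 9]
```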
@@ -11,7 +11,8 @@
from ..utils import resample
from ..utils._param_validation import Interval, Options, StrOptions
from ..utils.deprecation import _deprecate_Xt_in_inverse_transform
from ..utils.stats import _weighted_percentile
from ..utils.fixes import np_version, parse_version
from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile
from ..utils.validation import (
    _check_feature_names_in,
    _check_sample_weight,
@@ -57,6 +58,17 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
    For an example of the different strategies see:
    :ref:`sphx_glr_auto_examples_preprocessing_plot_discretization_strategies.py`.

    quantile_method : {"inverted_cdf", "averaged_inverted_cdf",
            "closest_observation", "interpolated_inverted_cdf", "hazen",
            "weibull", "linear", "median_unbiased", "normal_unbiased"},
            default="linear"
        Method to pass on to np.percentile calculation when using
        strategy="quantile". Only `averaged_inverted_cdf` and `inverted_cdf`
        support the use of `sample_weight != None` when subsampling is not
        active.

        .. versionadded:: 1.7

    dtype : {np.float32, np.float64}, default=None
        The desired data-type for the output. If None, output dtype is
        consistent with input dtype. Only np.float32 and np.float64 are
@@ -175,6 +187,22 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
        "n_bins": [Interval(Integral, 2, None, closed="left"), "array-like"],
        "encode": [StrOptions({"onehot", "onehot-dense", "ordinal"})],
        "strategy": [StrOptions({"uniform", "quantile", "kmeans"})],
        "quantile_method": [
            StrOptions(
                {
                    "warn",
                    "inverted_cdf",
                    "averaged_inverted_cdf",
                    "closest_observation",
                    "interpolated_inverted_cdf",
                    "hazen",
                    "weibull",
                    "linear",
                    "median_unbiased",
                    "normal_unbiased",
                }
            )
        ],
        "dtype": [Options(type, {np.float64, np.float32}), None],
        "subsample": [Interval(Integral, 1, None, closed="left"), None],
        "random_state": ["random_state"],
@@ -186,13 +214,15 @@ def __init__(
        *,
        encode="onehot",
        strategy="quantile",
        quantile_method="warn",
        dtype=None,
        subsample=200_000,
        random_state=None,
    ):
        self.n_bins = n_bins
        self.encode = encode
        self.strategy = strategy
        self.quantile_method = quantile_method
        self.dtype = dtype
        self.subsample = subsample
        self.random_state = random_state
@@ -213,10 +243,12 @@ def fit(self, X, y=None, sample_weight=None):

        sample_weight : ndarray of shape (n_samples,)
            Contains weight values to be associated with each sample.
            Cannot be used when `strategy` is set to `"uniform"`.

            .. versionadded:: 1.3

            .. versionchanged:: 1.7
                Added support for strategy="uniform".

        Returns
        -------
        self : object
@@ -231,32 +263,74 @@ def fit(self, X, y=None, sample_weight=None):

        n_samples, n_features = X.shape

        if sample_weight is not None and self.strategy == "uniform":
            raise ValueError(
                "`sample_weight` was provided but it cannot be "
                "used with strategy='uniform'. Got strategy="
                f"{self.strategy!r} instead."
            )
        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

        if self.subsample is not None and n_samples > self.subsample:
            # Take a subsample of `X`
            # When resampling, it is important to subsample **with replacement** to
            # preserve the distribution, in particular in the presence of a few data
            # points with large weights. You can check this by setting `replace=False`
            # in sklearn.utils.tests.test_indexing.test_resample_weighted and checking
            # that it fails, as justification for this claim.
            X = resample(
                X,
                replace=False,
                replace=True,
                n_samples=self.subsample,
                random_state=self.random_state,
                sample_weight=sample_weight,
            )
            # Since we already used the weights when resampling when provided,
            # we set them back to `None` to avoid accounting for the weights
            # twice in subsequent operations to compute weight-aware bin edges
            # with quantiles or k-means.
            sample_weight = None

        n_features = X.shape[1]
        n_bins = self._validate_n_bins(n_features)

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

        bin_edges = np.zeros(n_features, dtype=object)

        # TODO(1.9): remove and switch to quantile_method="averaged_inverted_cdf"
        # by default.
        quantile_method = self.quantile_method
        if self.strategy == "quantile" and quantile_method == "warn":
            warnings.warn(
                "The current default behavior, quantile_method='linear', will be "
                "changed to quantile_method='averaged_inverted_cdf' in "
                "scikit-learn version 1.9 to naturally support sample weight "
                "equivalence properties by default. Pass "

> **Comment on lines +301 to +302**

**betatim:** What is "sample weight equivalence properties"? Could we rephrase this (or am I the only one who doesn't know it)?

**jeremiedbb:** The weight of a data point = n <=> this data point is repeated n times. It's described in the glossary (https://scikit-learn.org/dev/glossary.html#term-sample_weight), but you'll notice that this section is not complete and not very precise :)

**snath-xoc:** @betatim thank you for the comments, I can imagine it's a lot of conversation to follow. As @jeremiedbb said, we conducted some tests with the gist here. @jeremiedbb found that when specifying an odd case of samples (i.e., with one large outlier), the resampling-with-replacement strategy held up best in terms of sample weight equivalence (i.e., equivalent properties when weighting vs. repeating a sample n times, where n=0 means omitting the sample altogether). We also had to modify the "uniform" strategy with @ogrisel's help and implement an averaged weighted quantile strategy with accompanying tests. Let me know if you need further clarification :)

**ogrisel:** @jeremiedbb I will open a PR to improve the glossary entry. This is an excellent place to document the expected semantics for […]. @betatim Meanwhile, the precise definition of that equivalence is given in […]. The fact that an estimator has a deterministic fit or not can both depend on the choice of the hyper-parameters (e.g. […]). For estimators with a stochastic fit, we have to run a more computationally intensive statistical test to check if the equivalence property holds. This is currently done in external notebooks from this gist: https://gist.github.com/snath-xoc/fb28feab39403a1e66b00b5b28f1dcbf. We don't have specific plans to refactor those notebooks as a common test at this time because: […] So for now, we keep those notebooks as a guide to help us detect stochastic estimators with suspicious […]

**ogrisel:** I opened the PR: #30564.

                "quantile_method='averaged_inverted_cdf' explicitly to silence this "
                "warning.",
                FutureWarning,
            )
            quantile_method = "linear"

        if (
            self.strategy == "quantile"
            and quantile_method not in ["inverted_cdf", "averaged_inverted_cdf"]
            and sample_weight is not None
        ):
            raise ValueError(
                "When fitting with strategy='quantile' and sample weights, "
                "quantile_method should either be set to 'averaged_inverted_cdf' or "
                f"'inverted_cdf', got quantile_method='{quantile_method}' instead."
            )

        if self.strategy != "quantile" and sample_weight is not None:
            # Prepare a mask to filter out zero-weight samples when extracting
            # the min and max values of each column, which are needed for the
            # "uniform" and "kmeans" strategies.
            nnz_weight_mask = sample_weight != 0
        else:
            # Otherwise, all samples are used. Use a slice to avoid creating a
            # new array.
            nnz_weight_mask = slice(None)

        for jj in range(n_features):
            column = X[:, jj]
            col_min, col_max = column.min(), column.max()
            col_min = column[nnz_weight_mask].min()
            col_max = column[nnz_weight_mask].max()

            if col_min == col_max:
                warnings.warn(
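
The "sample weight equivalence" property discussed in the thread above can be checked directly for the deterministic case (no subsampling). The following is a hedged illustration with invented data, not a test from this PR:

```python
# Illustrative check of the weighting-vs-repeating equivalence for a
# deterministic fit (subsample=None); the toy data is invented.
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer

X = np.array([[0.0], [1.0], [2.0], [3.0], [4.0], [100.0]])
sample_weight = np.array([1, 1, 1, 1, 1, 3])
# Repeating the last point 3 times should be equivalent to weighting it by 3.
X_repeated = np.repeat(X, sample_weight, axis=0)

def bin_edges(X, sw=None):
    est = KBinsDiscretizer(
        n_bins=3,
        encode="ordinal",
        strategy="quantile",
        quantile_method="averaged_inverted_cdf",
        subsample=None,  # keep the fit deterministic
    )
    return est.fit(X, sample_weight=sw).bin_edges_[0]

np.testing.assert_allclose(bin_edges(X, sample_weight), bin_edges(X_repeated))
```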
@@ -270,14 +344,47 @@ def fit(self, X, y=None, sample_weight=None):
                bin_edges[jj] = np.linspace(col_min, col_max, n_bins[jj] + 1)

            elif self.strategy == "quantile":
                quantiles = np.linspace(0, 100, n_bins[jj] + 1)
                percentile_levels = np.linspace(0, 100, n_bins[jj] + 1)

                # TODO: simplify the following when numpy min version >= 1.22.

                # method="linear" is the implicit default for any numpy
                # version. So we keep it version independent in that case by
                # using an empty param dict.
                percentile_kwargs = {}
                if quantile_method != "linear" and sample_weight is None:
                    if np_version < parse_version("1.22"):
                        if quantile_method in ["averaged_inverted_cdf", "inverted_cdf"]:
                            # The method parameter is not supported in numpy <
                            # 1.22 but we can define unit sample weight to use
                            # our own implementation instead:
                            sample_weight = np.ones(X.shape[0], dtype=X.dtype)
                        else:
                            raise ValueError(
                                f"quantile_method='{quantile_method}' is not "
                                "supported with numpy < 1.22"
                            )
                    else:
                        percentile_kwargs["method"] = quantile_method

                if sample_weight is None:
                    bin_edges[jj] = np.asarray(np.percentile(column, quantiles))
                    bin_edges[jj] = np.asarray(
                        np.percentile(column, percentile_levels, **percentile_kwargs),
                        dtype=np.float64,
                    )
                else:
                    # TODO: make _weighted_percentile and
                    # _averaged_weighted_percentile accept an array of
                    # quantiles instead of calling it multiple times and
                    # sorting the column multiple times as a result.
                    percentile_func = {
                        "inverted_cdf": _weighted_percentile,
                        "averaged_inverted_cdf": _averaged_weighted_percentile,
                    }[quantile_method]
                    bin_edges[jj] = np.asarray(
                        [
                            _weighted_percentile(column, sample_weight, q)
                            for q in quantiles
                            percentile_func(column, sample_weight, percentile=p)
                            for p in percentile_levels
                        ],
                        dtype=np.float64,
                    )
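
For reference, the two branches above map onto public numpy API roughly as follows. This is a hedged numpy-only sketch: `method` requires numpy >= 1.22, and `weights` requires numpy >= 2.0 (where numpy supports it only with `method="inverted_cdf"`); scikit-learn uses its private helpers instead so it can also cover older numpy and the averaged variant:

```python
# Hedged numpy-only sketch of the unweighted and weighted percentile paths.
import numpy as np

column = np.array([0.0, 1.0, 2.0, 3.0, 100.0])
percentile_levels = np.linspace(0, 100, 4)  # bin edges for n_bins=3

# Unweighted path: np.percentile with an explicit method (numpy >= 1.22).
edges = np.percentile(column, percentile_levels, method="inverted_cdf")

# Weighted path: numpy >= 2.0 exposes `weights`, but only for
# method="inverted_cdf".
weights = np.array([1.0, 1.0, 1.0, 1.0, 3.0])
weighted_edges = np.percentile(
    column, percentile_levels, method="inverted_cdf", weights=weights
)
print(edges, weighted_edges)
```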
> **Review thread on the changelog entries**

**Reviewer:** @snath-xoc I missed that in the review, but we should not mention private API-related names in the changelog. It's confusing: the primary users of the changelog are library users, and they should not be incentivized to use private API in their code, since it can break without deprecation cycles when upgrading scikit-learn. Can you please open a new PR to delete those lines?

**snath-xoc:** O.K. will do, thanks for letting me know!