-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH: refactored utils/validation._check_sample_weights() and added stronger sample_weights checks for all estimators #14653
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ac19cd6
4c25185
7b80778
e679df8
bdeb2c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1025,28 +1025,44 @@ def _check_sample_weight(sample_weight, X, dtype=None): | |
""" | ||
n_samples = _num_samples(X) | ||
|
||
if dtype is not None and dtype not in [np.float32, np.float64]: | ||
# this check is needed to ensure that we don't change the dtype of | ||
# of sample_weight if it's already np.float32. | ||
# since sample_weight can be a list or an array, we first | ||
# need to verify that it has a dtype attribute before the check. | ||
# if dtype is None or any other type besides np.float32, np.float64 | ||
# is given. | ||
|
||
if hasattr(sample_weight, "dtype"): | ||
dtype = sample_weight.dtype | ||
|
||
if dtype not in [np.float32, np.float64]: | ||
dtype = np.float64 | ||
|
||
if sample_weight is None or isinstance(sample_weight, numbers.Number): | ||
if sample_weight is None: | ||
sample_weight = np.ones(n_samples, dtype=dtype) | ||
else: | ||
elif isinstance(sample_weight, numbers.Number): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the else statement was fine, wasn't it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would even write if sample_weigt is None:
sample_weight = np.ones(...)
elif isinstance(..., Number):
...
else:
return sample_weight Basically remove the first There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
sample_weight = np.full(n_samples, sample_weight, | ||
dtype=dtype) | ||
else: | ||
if dtype is None: | ||
dtype = [np.float64, np.float32] | ||
sample_weight = check_array( | ||
sample_weight, accept_sparse=False, | ||
ensure_2d=False, dtype=dtype, order="C" | ||
) | ||
if sample_weight.ndim != 1: | ||
raise ValueError("Sample weights must be 1D array or scalar") | ||
|
||
if sample_weight.shape != (n_samples,): | ||
raise ValueError("sample_weight.shape == {}, expected {}!" | ||
.format(sample_weight.shape, (n_samples,))) | ||
return sample_weight | ||
|
||
# at this point, sample_weight is either a list or | ||
# an array. These checks will validate that the dtype | ||
# of the returned sample_weight is either np.float32 or | ||
# np.float64. If sample weight contained elements which | ||
# cannot be passed safely to the above types, the | ||
# following line will raise a ValueError | ||
sample_weight = np.array(sample_weight, dtype=dtype) | ||
|
||
# sample_weights must be 1-D arrays | ||
if sample_weight.ndim != 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already return sample_weight if it was an array. Shall make the check as well in this case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't return sample weight if it's an array. We need to check that it's |
||
raise ValueError("Sample weights must be 1D array or scalar") | ||
|
||
# and must have the same number of elements | ||
# as X | ||
if sample_weight.shape[0] != n_samples: | ||
raise ValueError("sample_weight.shape == {}, expected {}!" | ||
.format(sample_weight.shape, (n_samples, ))) | ||
return sample_weight | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is useless I think. When this is an array we will return it directly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could be an array of strings, returning it immediately can lead to problems later.
Edit: or it could have the wrong number of elements or dimensions.