
ENH: refactored utils/validation._check_sample_weights() and added stronger sample_weights checks for all estimators #14653


Closed · wants to merge 5 commits
8 changes: 4 additions & 4 deletions sklearn/calibration.py
@@ -23,6 +23,7 @@
from .preprocessing import label_binarize, LabelBinarizer
from .utils import check_X_y, check_array, indexable, column_or_1d
from .utils.validation import check_is_fitted, check_consistent_length
from .utils.validation import _check_sample_weight
from .isotonic import IsotonicRegression
from .svm import LinearSVC
from .model_selection import check_cv
@@ -155,6 +156,9 @@ def fit(self, X, y, sample_weight=None):
else:
base_estimator = self.base_estimator

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

if self.cv == "prefit":
calibrated_classifier = _CalibratedClassifier(
base_estimator, method=self.method)
@@ -172,12 +176,8 @@ def fit(self, X, y, sample_weight=None):
warnings.warn("%s does not support sample_weight. Samples"
" weights are only used for the calibration"
" itself." % estimator_name)
sample_weight = check_array(sample_weight, ensure_2d=False)
base_estimator_sample_weight = None
else:
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
check_consistent_length(y, sample_weight)
base_estimator_sample_weight = sample_weight
for train, test in cv.split(X, y):
this_estimator = clone(base_estimator)
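The two removed calls (check_array plus check_consistent_length) are folded into the single _check_sample_weight call added above. A minimal sketch of what that call does for the list/array case, based on the helper shown further down in this diff (the toy data is made up):

import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((4, 2))

# replaces check_array(sample_weight, ensure_2d=False): coerce to a 1-D float array
sw = _check_sample_weight([1, 2, 3, 4], X)      # array([1., 2., 3., 4.])

# replaces check_consistent_length(y, sample_weight): a length mismatch now raises
# ValueError: sample_weight.shape == (3,), expected (4,)!
# _check_sample_weight([1, 2, 3], X)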
4 changes: 2 additions & 2 deletions sklearn/ensemble/forest.py
@@ -60,7 +60,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from .base import BaseEnsemble, _partition_estimators
from ..utils.fixes import parallel_helper, _joblib_parallel_args
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, _check_sample_weight


__all__ = ["RandomForestClassifier",
@@ -243,7 +243,7 @@ def fit(self, X, y, sample_weight=None):
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
sample_weight = _check_sample_weight(sample_weight, X)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
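One subtlety worth noting: forest.fit casts X to float32 (dtype=DTYPE above), while the helper only enforces a float dtype on the weights, keeping float32 weights as float32 and defaulting everything else to float64. An illustrative snippet, assuming the dtype handling shown further down in this diff:

import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.random.rand(3, 2).astype(np.float32)   # forests validate X as float32

_check_sample_weight([1, 2, 3], X).dtype                      # float64 (the default)
_check_sample_weight(np.ones(3, dtype=np.float32), X).dtype   # float32 is preserved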
11 changes: 5 additions & 6 deletions sklearn/linear_model/base.py
@@ -14,7 +14,6 @@
# License: BSD 3 clause

from abc import ABCMeta, abstractmethod
import numbers
import warnings

import numpy as np
@@ -34,7 +33,7 @@
from ..utils.fixes import sparse_lsqr
from ..utils.seq_dataset import ArrayDataset32, CSRDataset32
from ..utils.seq_dataset import ArrayDataset64, CSRDataset64
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..preprocessing.data import normalize as f_normalize

# TODO: bayesian_ridge_regression and bayesian_regression_ard
@@ -118,8 +117,8 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
centered. This function also systematically makes y consistent with X.dtype
"""

if isinstance(sample_weight, numbers.Number):
sample_weight = None
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)

if check_input:
X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
@@ -467,8 +466,8 @@ def fit(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
y_numeric=True, multi_output=True)

if sample_weight is not None and np.atleast_1d(sample_weight).ndim > 1:
raise ValueError("Sample weights must be 1D array or scalar")
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, X.dtype)

X, y, X_offset, y_offset, X_scale = self._preprocess_data(
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
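For LinearRegression.fit, the old guard only rejected weights with more than one dimension; the helper keeps that check and also verifies the number of weights against n_samples. An illustrative comparison (values are made up):

import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.random.rand(5, 3)

# rejected by both the old ndim check and the helper:
# _check_sample_weight(np.ones((5, 2)), X)   # ValueError: Sample weights must be 1D array or scalar

# passed the old ndim-only check and failed only later downstream; the helper raises up front:
# _check_sample_weight(np.ones(4), X)        # ValueError: sample_weight.shape == (4,), expected (5,)!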
6 changes: 6 additions & 0 deletions sklearn/utils/tests/test_validation.py
@@ -912,6 +912,12 @@ def test_check_sample_weight():
sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
assert sample_weight.dtype == np.float64

# wrongly formated sample_weight
sample_weight = np.array(["1", "pi", "e"])
err_msg = "could not convert string to float: 'pi'"
with pytest.raises(ValueError, match=err_msg):
_check_sample_weight(sample_weight, X)


@pytest.mark.parametrize("toarray", [
np.array, sp.csr_matrix, sp.csc_matrix])
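For context, the helper's accepted inputs look roughly like this (an illustrative session; return values are shown as comments):

import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.ones((3, 2))

_check_sample_weight(None, X)              # array([1., 1., 1.]), default uniform weights
_check_sample_weight(2, X)                 # array([2., 2., 2.]), a scalar is broadcast
_check_sample_weight([0.5, 1.0, 1.5], X)   # array([0.5, 1. , 1.5]), list coerced to float
# _check_sample_weight(np.array(["1", "pi", "e"]), X)   # ValueError, as asserted in the new test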
46 changes: 31 additions & 15 deletions sklearn/utils/validation.py
@@ -1025,28 +1025,44 @@ def _check_sample_weight(sample_weight, X, dtype=None):
"""
n_samples = _num_samples(X)

if dtype is not None and dtype not in [np.float32, np.float64]:
# this check is needed to ensure that we don't change the dtype of
# of sample_weight if it's already np.float32.
# since sample_weight can be a list or an array, we first
# need to verify that it has a dtype attribute before the check.
# if dtype is None or any other type besides np.float32, np.float64
# is given.

if hasattr(sample_weight, "dtype"):
Member: This is useless I think. When this is an array we will return it directly.

maxwell-aladago (Contributor Author), Aug 23, 2019: It could be an array of strings; returning it immediately can lead to problems later.

Edit: or it could have the wrong number of elements or dimensions.

dtype = sample_weight.dtype

if dtype not in [np.float32, np.float64]:
dtype = np.float64

if sample_weight is None or isinstance(sample_weight, numbers.Number):
if sample_weight is None:
sample_weight = np.ones(n_samples, dtype=dtype)
else:
elif isinstance(sample_weight, numbers.Number):
Member: the else statement was fine, wasn't it?

Member: I would even write

if sample_weight is None:
    sample_weight = np.ones(...)
elif isinstance(..., Number):
    ...
else:
    return sample_weight

Basically remove the first check.

Contributor Author: The sample_weight may have the wrong number of dimensions or elements if it's already an array. Thus, the further checks are necessary. We can only ignore the checks below if sample_weight is created within the function (i.e., when sample_weight is None or it's an integer).
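To make the concern concrete (a hypothetical illustration, not code from this PR): if the helper returned any array input unchanged, inputs like the following would only fail later, deep inside an estimator, instead of at the validation boundary.

import numpy as np

# array inputs that would slip through an early return but still need validation:
np.array(["1", "pi", "e"])   # string dtype, cannot be cast to float
np.ones((4, 2))              # 2-D, not a 1-D weight vector
np.ones(3)                   # wrong length for, say, a 4-sample X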

sample_weight = np.full(n_samples, sample_weight,
dtype=dtype)
else:
if dtype is None:
dtype = [np.float64, np.float32]
sample_weight = check_array(
sample_weight, accept_sparse=False,
ensure_2d=False, dtype=dtype, order="C"
)
if sample_weight.ndim != 1:
raise ValueError("Sample weights must be 1D array or scalar")

if sample_weight.shape != (n_samples,):
raise ValueError("sample_weight.shape == {}, expected {}!"
.format(sample_weight.shape, (n_samples,)))
return sample_weight

# at this point, sample_weight is either a list or
# an array. These checks will validate that the dtype
# of the returned sample_weight is either np.float32 or
# np.float64. If sample weight contained elements which
# cannot be passed safely to the above types, the
# following line will raise a ValueError
sample_weight = np.array(sample_weight, dtype=dtype)

# sample_weights must be 1-D arrays
if sample_weight.ndim != 1:
Member: We already return sample_weight if it was an array. Shall we make the check in this case as well?

Contributor Author: We don't return sample_weight if it's an array. We need to check that its dtype is one of np.float32 or np.float64.

raise ValueError("Sample weights must be 1D array or scalar")

# and must have the same number of elements
# as X
if sample_weight.shape[0] != n_samples:
raise ValueError("sample_weight.shape == {}, expected {}!"
.format(sample_weight.shape, (n_samples, )))
return sample_weight
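Putting the added lines and comments together, the helper after this change would read roughly as follows. This is a reconstruction (the rendered diff above does not mark additions and deletions), so indentation and comment wording are approximate; _num_samples is the private helper defined earlier in sklearn/utils/validation.py.

import numbers

import numpy as np


def _check_sample_weight(sample_weight, X, dtype=None):
    n_samples = _num_samples(X)  # defined earlier in sklearn/utils/validation.py

    # keep an existing np.float32 dtype; anything else falls back to np.float64
    if hasattr(sample_weight, "dtype"):
        dtype = sample_weight.dtype
    if dtype not in [np.float32, np.float64]:
        dtype = np.float64

    if sample_weight is None:
        sample_weight = np.ones(n_samples, dtype=dtype)
    elif isinstance(sample_weight, numbers.Number):
        sample_weight = np.full(n_samples, sample_weight, dtype=dtype)

    # lists and arrays end up here; entries that cannot be cast safely
    # to the float dtype make np.array raise a ValueError
    sample_weight = np.array(sample_weight, dtype=dtype)

    if sample_weight.ndim != 1:
        raise ValueError("Sample weights must be 1D array or scalar")
    if sample_weight.shape[0] != n_samples:
        raise ValueError("sample_weight.shape == {}, expected {}!"
                         .format(sample_weight.shape, (n_samples,)))
    return sample_weight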

