
MAINT Common sample_weight validation #14307


Merged
merged 35 commits into from Jul 19, 2019
Changes from all commits
Commits (35)
9800b4b
Common sample_weight validation in huber
rth Jul 12, 2019
95df187
Use _check_sample_weight in linear_model/stochastic_gradient
rth Jul 12, 2019
c889db3
Use check_sample weight in logistic
rth Jul 12, 2019
bd52cfc
Better order parameter
rth Jul 12, 2019
9e108a4
Sample weight checks in svm
rth Jul 12, 2019
4ff292d
Tests
rth Jul 12, 2019
bb64a9b
Minor fix
rth Jul 12, 2019
29e4ff6
Better handle numeric sample_weight
rth Jul 12, 2019
4d7bb15
Update sklearn/utils/validation.py
rth Jul 12, 2019
1c0f6a7
Update sklearn/utils/validation.py
rth Jul 12, 2019
59abc05
Update sklearn/utils/validation.py
rth Jul 12, 2019
84b0ac0
Update sklearn/utils/validation.py
rth Jul 12, 2019
4ea0694
Update sklearn/utils/validation.py
rth Jul 12, 2019
cfc7a97
Update sklearn/utils/validation.py
rth Jul 12, 2019
908bbfc
Update sklearn/utils/validation.py
rth Jul 12, 2019
c6280b6
Update sklearn/utils/validation.py
rth Jul 12, 2019
2b84f90
Update sklearn/utils/validation.py
rth Jul 12, 2019
d81fec1
Update sklearn/utils/validation.py
rth Jul 12, 2019
b2b1773
Remove kwargs
rth Jul 12, 2019
22f9275
Use _check_sample_weight(sample_weight, X)
rth Jul 12, 2019
c28226a
Update kmeans to use _check_sample_weight
rth Jul 12, 2019
ed2dc69
Non float dtype should not be supported
rth Jul 12, 2019
0ce20bc
Merge remote-tracking branch 'upstream/master' into check-sample-weight
rth Jul 15, 2019
380d9eb
Drop order parameter + address review comments
rth Jul 15, 2019
3fa5f73
Fix typo
rth Jul 15, 2019
08e204f
Fix tests
rth Jul 15, 2019
44d99c1
Address review
rth Jul 15, 2019
3fc9d1a
Attempt to fix 32bit / 64bit comparison test
rth Jul 16, 2019
22e1070
Check that float32 dtype is preserved
rth Jul 16, 2019
561bb6a
Update sklearn/utils/validation.py
rth Jul 18, 2019
71ecf65
Lint
rth Jul 18, 2019
e244ad5
Update sklearn/utils/validation.py
rth Jul 19, 2019
13f9dec
Improve docstring wording
rth Jul 19, 2019
9cccaf6
Fix tests
rth Jul 19, 2019
fb22cfc
Nicolas's comments
rth Jul 19, 2019
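Before the diffs, a minimal sketch of the usage pattern these commits converge on. The behaviour shown here mirrors the new tests at the bottom of this page; the signature `_check_sample_weight(sample_weight, X, dtype=None)` is inferred from the call sites below rather than quoted from the final docstring.

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.ones((5, 2))

# None becomes a uniform array of ones, one weight per row of X
w = _check_sample_weight(None, X)
print(w)                       # [1. 1. 1. 1. 1.]

# a scalar is broadcast to every sample
w = _check_sample_weight(2.0, X)
print(w)                       # [2. 2. 2. 2. 2.]

# a list (or integer array) is validated and converted to a float array
w = _check_sample_weight([1, 2, 3, 4, 5], X)
print(w.dtype)                 # float64

# a length mismatch raises, e.g.:
# _check_sample_weight(np.ones(4), X)
# ValueError: sample_weight.shape == (4,), expected (5,)!
```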
34 changes: 17 additions & 17 deletions sklearn/cluster/k_means_.py
@@ -27,7 +27,7 @@
from ..utils import check_array
from ..utils import gen_batches
from ..utils import check_random_state
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils.validation import FLOAT_DTYPES
from ..exceptions import ConvergenceWarning
from . import _k_means
@@ -164,19 +164,19 @@ def _tolerance(X, tol):
return np.mean(variances) * tol


-def _check_sample_weight(X, sample_weight):
+def _check_normalize_sample_weight(sample_weight, X):
     """Set sample_weight if None, and check for correct dtype"""
-    n_samples = X.shape[0]
-    if sample_weight is None:
-        return np.ones(n_samples, dtype=X.dtype)
-    else:
-        sample_weight = np.asarray(sample_weight)
-        if n_samples != len(sample_weight):
-            raise ValueError("n_samples=%d should be == len(sample_weight)=%d"
-                             % (n_samples, len(sample_weight)))
+
+    sample_weight_was_none = sample_weight is None
+
+    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+    if not sample_weight_was_none:
         # normalize the weights to sum up to n_samples
+        # an array of 1 (i.e. samples_weight is None) is already normalized
+        n_samples = len(sample_weight)
         scale = n_samples / sample_weight.sum()
-        return (sample_weight * scale).astype(X.dtype, copy=False)
+        sample_weight *= scale
+    return sample_weight


def k_means(X, n_clusters, sample_weight=None, init='k-means++',
@@ -434,7 +434,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
if verbose:
print('Initialization complete')

checked_sample_weight = _check_sample_weight(X, sample_weight)
checked_sample_weight = _check_normalize_sample_weight(sample_weight, X)
centers, labels, n_iter = k_means_elkan(X, checked_sample_weight,
n_clusters, centers, tol=tol,
max_iter=max_iter, verbose=verbose)
@@ -519,7 +519,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
"""
random_state = check_random_state(random_state)

sample_weight = _check_sample_weight(X, sample_weight)
sample_weight = _check_normalize_sample_weight(sample_weight, X)

best_labels, best_inertia, best_centers = None, None, None
# init
@@ -662,7 +662,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
Sum of squared distances of samples to their closest cluster center.
"""
n_samples = X.shape[0]
sample_weight = _check_sample_weight(X, sample_weight)
sample_weight = _check_normalize_sample_weight(sample_weight, X)
# set the default value of centers to -1 to be able to detect any anomaly
# easily
labels = np.full(n_samples, -1, np.int32)
@@ -1492,7 +1492,7 @@ def fit(self, X, y=None, sample_weight=None):
raise ValueError("n_samples=%d should be >= n_clusters=%d"
% (n_samples, self.n_clusters))

sample_weight = _check_sample_weight(X, sample_weight)
sample_weight = _check_normalize_sample_weight(sample_weight, X)

n_init = self.n_init
if hasattr(self.init, '__array__'):
@@ -1641,7 +1641,7 @@ def _labels_inertia_minibatch(self, X, sample_weight):
"""
if self.verbose:
print('Computing label assignment and total inertia')
sample_weight = _check_sample_weight(X, sample_weight)
sample_weight = _check_normalize_sample_weight(sample_weight, X)
x_squared_norms = row_norms(X, squared=True)
slices = gen_batches(X.shape[0], self.batch_size)
results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
@@ -1676,7 +1676,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
if n_samples == 0:
return self

sample_weight = _check_sample_weight(X, sample_weight)
sample_weight = _check_normalize_sample_weight(sample_weight, X)

x_squared_norms = row_norms(X, squared=True)
self.random_state_ = getattr(self, "random_state_",
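The rewritten helper above delegates validation to the shared `_check_sample_weight` and keeps only the k-means-specific rescaling, so that explicit weights sum to `n_samples`. A small sketch of that behaviour, using the same private import path as the updated test below:

```python
import numpy as np
from sklearn.cluster.k_means_ import _check_normalize_sample_weight

X = np.zeros((4, 2), dtype=np.float64)

# explicit weights are rescaled so that they sum to n_samples (4 here)
w = _check_normalize_sample_weight(np.array([1.0, 1.0, 2.0, 4.0]), X)
print(w, w.sum())      # [0.5 0.5 1.  2. ] 4.0

# None becomes an array of ones, which already sums to n_samples,
# so no rescaling is applied
w = _check_normalize_sample_weight(None, X)
print(w, w.dtype)      # [1. 1. 1. 1.] float64
```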
11 changes: 6 additions & 5 deletions sklearn/cluster/tests/test_k_means.py
@@ -909,14 +909,15 @@ def test_sample_weight_length():
# check that an error is raised when passing sample weights
# with an incompatible shape
km = KMeans(n_clusters=n_clusters, random_state=42)
assert_raises_regex(ValueError, r'len\(sample_weight\)', km.fit, X,
sample_weight=np.ones(2))
msg = r'sample_weight.shape == \(2,\), expected \(100,\)'
with pytest.raises(ValueError, match=msg):
km.fit(X, sample_weight=np.ones(2))


def test_check_sample_weight():
from sklearn.cluster.k_means_ import _check_sample_weight
def test_check_normalize_sample_weight():
from sklearn.cluster.k_means_ import _check_normalize_sample_weight
sample_weight = None
checked_sample_weight = _check_sample_weight(X, sample_weight)
checked_sample_weight = _check_normalize_sample_weight(sample_weight, X)
assert _num_samples(X) == _num_samples(checked_sample_weight)
assert_almost_equal(checked_sample_weight.sum(), _num_samples(X))
assert X.dtype == checked_sample_weight.dtype
9 changes: 3 additions & 6 deletions sklearn/linear_model/huber.py
@@ -8,8 +8,8 @@
from ..base import BaseEstimator, RegressorMixin
from .base import LinearModel
from ..utils import check_X_y
from ..utils import check_consistent_length
from ..utils import axis0_safe_slice
from ..utils.validation import _check_sample_weight
from ..utils.extmath import safe_sparse_dot
from ..utils.optimize import _check_optimize_result

@@ -255,11 +255,8 @@ def fit(self, X, y, sample_weight=None):
X, y = check_X_y(
X, y, copy=False, accept_sparse=['csr'], y_numeric=True,
dtype=[np.float64, np.float32])
if sample_weight is not None:
sample_weight = np.array(sample_weight)
check_consistent_length(y, sample_weight)
else:
sample_weight = np.ones_like(y)
rth (Member, PR author) commented:

I'm pretty sure this didn't produce sample weights with the expected dtype when y was an integer array.


sample_weight = _check_sample_weight(sample_weight, X)

if self.epsilon < 1.0:
raise ValueError(
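The author's comment above can be seen in a short, purely illustrative snippet: `np.ones_like(y)` inherits the dtype of integer targets, while the shared helper always returns floating-point weights sized against `X`.

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.ones((3, 2))
y = np.array([0, 1, 2])                      # integer targets

old_style = np.ones_like(y)                  # inherits y's integer dtype
new_style = _check_sample_weight(None, X)    # always a float array

print(old_style.dtype, new_style.dtype)      # e.g. int64 float64
```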
13 changes: 4 additions & 9 deletions sklearn/linear_model/logistic.py
@@ -30,7 +30,7 @@
from ..utils.fixes import logsumexp
from ..utils.optimize import newton_cg, _check_optimize_result
from ..utils.validation import check_X_y
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils import deprecated
from ..exceptions import ChangedBehaviorWarning
from ..utils.multiclass import check_classification_targets
@@ -826,11 +826,8 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
# If sample weights exist, convert them to array (support for lists)
# and check length
# Otherwise set them to 1 for all examples
if sample_weight is not None:
sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
check_consistent_length(y, sample_weight)
else:
sample_weight = np.ones(X.shape[0], dtype=X.dtype)
sample_weight = _check_sample_weight(sample_weight, X,
dtype=X.dtype)

# If class_weights is a dict (provided by the user), the weights
# are assigned to the original labels. If it is "balanced", then
@@ -1133,9 +1130,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
y_test = y[test]

if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
check_consistent_length(y, sample_weight)

sample_weight = _check_sample_weight(sample_weight, X)
sample_weight = sample_weight[train]

coefs, Cs, n_iter = _logistic_regression_path(
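Passing `dtype=X.dtype` in `_logistic_regression_path` keeps the weights in step with the input data, so float32 problems are not silently upcast to float64. A sketch of that assumption, in line with the float32 test added further down:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X32 = np.ones((4, 2), dtype=np.float32)

# a list is converted to X's dtype when dtype=X.dtype is passed
w = _check_sample_weight([1, 2, 3, 4], X32, dtype=X32.dtype)
print(w.dtype)     # float32

# an array that is already float32 stays float32 even without dtype
w = _check_sample_weight(np.ones(4, dtype=np.float32), X32)
print(w.dtype)     # float32
```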
5 changes: 2 additions & 3 deletions sklearn/linear_model/ransac.py
@@ -11,7 +11,7 @@
from ..base import MultiOutputMixin
from ..utils import check_random_state, check_array, check_consistent_length
from ..utils.random import sample_without_replacement
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, _check_sample_weight
from .base import LinearRegression
from ..utils.validation import has_fit_parameter
from ..exceptions import ConvergenceWarning
@@ -324,8 +324,7 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("%s does not support sample_weight. Samples"
" weights are only used for the calibration"
" itself." % estimator_name)
if sample_weight is not None:
sample_weight = np.asarray(sample_weight)
sample_weight = _check_sample_weight(sample_weight, X)

n_inliers_best = 1
score_best = -np.inf
9 changes: 4 additions & 5 deletions sklearn/linear_model/ridge.py
@@ -27,6 +27,7 @@
from ..utils import check_consistent_length
from ..utils import compute_sample_weight
from ..utils import column_or_1d
from ..utils.validation import _check_sample_weight
from ..preprocessing import LabelBinarizer
from ..model_selection import GridSearchCV
from ..metrics.scorer import check_scoring
@@ -428,8 +429,7 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
" %d != %d" % (n_samples, n_samples_))

if has_sw:
if np.atleast_1d(sample_weight).ndim > 1:
raise ValueError("Sample weights must be 1D array or scalar")
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

if solver not in ['sag', 'saga']:
# SAG supports sample_weight directly. For other solvers,
@@ -1406,9 +1406,8 @@ def fit(self, X, y, sample_weight=None):
"alphas must be positive. Got {} containing some "
"negative or null value instead.".format(self.alphas))

if sample_weight is not None and not isinstance(sample_weight, float):
sample_weight = check_array(sample_weight, ensure_2d=False,
dtype=X.dtype)
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

n_samples, n_features = X.shape

X, y, X_offset, y_offset, X_scale = LinearModel._preprocess_data(
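In `_ridge_regression` the explicit dimensionality check is dropped because the shared helper performs the same validation; the "Sample weights must be 1D array or scalar" message is asserted in the new test at the bottom of the page. A minimal sketch:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.ones((3, 2))

try:
    _check_sample_weight(np.ones((3, 1)), X)   # 2D weights are rejected
except ValueError as exc:
    print(exc)   # Sample weights must be 1D array or scalar
```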
4 changes: 2 additions & 2 deletions sklearn/linear_model/sag.py
@@ -12,6 +12,7 @@
from .sag_fast import sag32, sag64
from ..exceptions import ConvergenceWarning
from ..utils import check_array
from ..utils.validation import _check_sample_weight
from ..utils.extmath import row_norms


@@ -251,8 +252,7 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,
n_classes = int(y.max()) + 1 if loss == 'multinomial' else 1

# initialization
if sample_weight is None:
sample_weight = np.ones(n_samples, dtype=X.dtype, order='C')
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

if 'coef' in warm_start_mem.keys():
coef_init = warm_start_mem['coef']
21 changes: 4 additions & 17 deletions sklearn/linear_model/stochastic_gradient.py
@@ -18,7 +18,7 @@
from ..utils import check_array, check_random_state, check_X_y
from ..utils.extmath import safe_sparse_dot
from ..utils.multiclass import _check_partial_fit_first_call
from ..utils.validation import check_is_fitted
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..exceptions import ConvergenceWarning
from ..model_selection import StratifiedShuffleSplit, ShuffleSplit

@@ -169,19 +169,6 @@ def _get_penalty_type(self, penalty):
except KeyError:
raise ValueError("Penalty %s is not supported. " % penalty)

def _validate_sample_weight(self, sample_weight, n_samples):
"""Set the sample weight array."""
if sample_weight is None:
# uniform sample weights
sample_weight = np.ones(n_samples, dtype=np.float64, order='C')
else:
# user-provided array
sample_weight = np.asarray(sample_weight, dtype=np.float64,
order="C")
if sample_weight.shape[0] != n_samples:
raise ValueError("Shapes of X and sample_weight do not match.")
return sample_weight

def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
intercept_init=None):
"""Allocate mem for parameters; initialize if provided."""
@@ -488,7 +475,7 @@ def _partial_fit(self, X, y, alpha, C,
# Allocate datastructures from input arguments
self._expanded_class_weight = compute_class_weight(self.class_weight,
self.classes_, y)
sample_weight = self._validate_sample_weight(sample_weight, n_samples)
sample_weight = _check_sample_weight(sample_weight, X)

if getattr(self, "coef_", None) is None or coef_init is not None:
self._allocate_parameter_mem(n_classes, n_features,
@@ -1095,9 +1082,9 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,

n_samples, n_features = X.shape

# Allocate datastructures from input arguments
sample_weight = self._validate_sample_weight(sample_weight, n_samples)
sample_weight = _check_sample_weight(sample_weight, X)

# Allocate datastructures from input arguments
if getattr(self, "coef_", None) is None:
self._allocate_parameter_mem(1, n_features, coef_init,
intercept_init)
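Removing `_validate_sample_weight` also changes what is reported for a length mismatch: the old method raised a generic "Shapes of X and sample_weight do not match.", while the shared helper spells out both shapes. The wording below is taken from the new tests; the estimator setup is only an assumed example.

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

X = np.random.RandomState(0).rand(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])

clf = SGDClassifier(max_iter=5, tol=None, random_state=0)
try:
    clf.fit(X, y, sample_weight=np.ones(4))    # wrong length on purpose
except ValueError as exc:
    print(exc)   # e.g. sample_weight.shape == (4,), expected (6,)!
```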
11 changes: 5 additions & 6 deletions sklearn/svm/base.py
@@ -8,11 +8,12 @@
from ..base import BaseEstimator, ClassifierMixin
from ..preprocessing import LabelEncoder
from ..utils.multiclass import _ovr_decision_function
from ..utils import check_array, check_consistent_length, check_random_state
from ..utils import check_array, check_random_state
from ..utils import column_or_1d, check_X_y
from ..utils import compute_class_weight
from ..utils.extmath import safe_sparse_dot
from ..utils.validation import check_is_fitted, _check_large_sparse
from ..utils.validation import _check_sample_weight
from ..utils.multiclass import check_classification_targets
from ..exceptions import ConvergenceWarning
from ..exceptions import NotFittedError
@@ -906,11 +907,9 @@ def _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight,
# LibLinear wants targets as doubles, even for classification
y_ind = np.asarray(y_ind, dtype=np.float64).ravel()
y_ind = np.require(y_ind, requirements="W")
if sample_weight is None:
sample_weight = np.ones(X.shape[0])
else:
sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
check_consistent_length(sample_weight, X)

sample_weight = _check_sample_weight(sample_weight, X,
dtype=np.float64)

solver_type = _get_liblinear_solver_type(multi_class, penalty, loss, dual)
raw_coef_, n_iter_ = liblinear.train_wrap(
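The liblinear path pins `dtype=np.float64` because the solver works in double precision, so, unlike call sites that pass `dtype=X.dtype`, float32 weights are deliberately upcast here. A brief sketch of that assumption:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X32 = np.ones((3, 2), dtype=np.float32)

# forcing dtype=np.float64 upcasts float32 weights for liblinear
w = _check_sample_weight(np.ones(3, dtype=np.float32), X32, dtype=np.float64)
print(w.dtype)    # float64
```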
41 changes: 40 additions & 1 deletion sklearn/utils/tests/test_validation.py
@@ -20,6 +20,7 @@
from sklearn.utils.testing import SkipTest
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_allclose_dense_sparse
from sklearn.utils.testing import assert_allclose
from sklearn.utils import as_float_array, check_array, check_symmetric
from sklearn.utils import check_X_y
from sklearn.utils import deprecated
@@ -39,7 +40,8 @@
check_memory,
check_non_negative,
_num_samples,
check_scalar)
check_scalar,
_check_sample_weight)
import sklearn

from sklearn.exceptions import NotFittedError
@@ -853,3 +855,40 @@ def test_check_scalar_invalid(x, target_name, target_type, min_val, max_val,
min_val=min_val, max_val=max_val)
assert str(raised_error.value) == str(err_msg)
assert type(raised_error.value) == type(err_msg)


def test_check_sample_weight():
# check array order
sample_weight = np.ones(10)[::2]
assert not sample_weight.flags["C_CONTIGUOUS"]
sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
assert sample_weight.flags["C_CONTIGUOUS"]

# check None input
sample_weight = _check_sample_weight(None, X=np.ones((5, 2)))
assert_allclose(sample_weight, np.ones(5))

# check numbers input
sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2)))
assert_allclose(sample_weight, 2 * np.ones(5))

# check wrong number of dimensions
with pytest.raises(ValueError,
match="Sample weights must be 1D array or scalar"):
_check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))

# check incorrect n_samples
msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
with pytest.raises(ValueError, match=msg):
_check_sample_weight(np.ones(4), X=np.ones((2, 2)))

# float32 dtype is preserved
X = np.ones((5, 2))
sample_weight = np.ones(5, dtype=np.float32)
sample_weight = _check_sample_weight(sample_weight, X)
assert sample_weight.dtype == np.float32

# int dtype will be converted to float64 instead
X = np.ones((5, 2), dtype=np.int)
sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
assert sample_weight.dtype == np.float64
A reviewer (Member) commented:

Maybe add a test for the shape difference error? (it's implicitly tested in kmeans)

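One possible shape of the suggested test, sketched in the style of `test_sample_weight_length` in test_k_means.py above; the PR may already cover this through the "incorrect n_samples" case in `test_check_sample_weight`, so this is only an illustration of the idea.

```python
import numpy as np
import pytest

from sklearn.utils.validation import _check_sample_weight


def test_check_sample_weight_shape_mismatch():
    # weights whose length does not match X.shape[0] should raise a
    # ValueError that mentions both shapes
    X = np.ones((100, 3))
    msg = r"sample_weight.shape == \(2,\), expected \(100,\)!"
    with pytest.raises(ValueError, match=msg):
        _check_sample_weight(np.ones(2), X)
```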