Commit fb169cd

rth authored and thomasjpfan committed
MAINT Common sample_weight validation (#14307)
1 parent 13b4bef commit fb169cd
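
This commit routes the sample_weight handling of several estimators through a single helper, _check_sample_weight(sample_weight, X, dtype=None) in sklearn.utils.validation. The helper's own definition (sklearn/utils/validation.py) is not part of this excerpt; the snippet below is only a rough sketch of its behaviour, inferred from the call sites and the new tests in sklearn/utils/tests/test_validation.py further down, and the name _check_sample_weight_sketch is made up here to avoid confusing it with the real function.

# Rough sketch only -- inferred from call sites and tests in this commit,
# not the actual implementation added in sklearn/utils/validation.py.
import numbers

import numpy as np

from sklearn.utils import check_array
from sklearn.utils.validation import _num_samples


def _check_sample_weight_sketch(sample_weight, X, dtype=None):
    """Return a validated 1D float array of per-sample weights for X."""
    n_samples = _num_samples(X)

    # Weights are always floating point; other dtypes fall back to float64.
    if dtype is not None and dtype not in [np.float32, np.float64]:
        dtype = np.float64

    if sample_weight is None or isinstance(sample_weight, numbers.Number):
        if sample_weight is None:
            # No weights given: every sample gets weight 1.
            sample_weight = np.ones(n_samples, dtype=dtype)
        else:
            # A scalar is broadcast to a constant weight for every sample.
            sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
    else:
        if dtype is None:
            dtype = [np.float64, np.float32]
        # Validate as a C-contiguous 1D float array.
        sample_weight = check_array(sample_weight, accept_sparse=False,
                                    ensure_2d=False, dtype=dtype, order="C")
        if sample_weight.ndim != 1:
            raise ValueError("Sample weights must be 1D array or scalar")
        if sample_weight.shape != (n_samples,):
            raise ValueError("sample_weight.shape == {}, expected {}!"
                             .format(sample_weight.shape, (n_samples,)))
    return sample_weight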

File tree

11 files changed: +137 -71 lines changed

sklearn/cluster/k_means_.py

+17 -17

@@ -27,7 +27,7 @@
 from ..utils import check_array
 from ..utils import gen_batches
 from ..utils import check_random_state
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, _check_sample_weight
 from ..utils.validation import FLOAT_DTYPES
 from ..exceptions import ConvergenceWarning
 from . import _k_means
@@ -167,19 +167,19 @@ def _tolerance(X, tol):
     return np.mean(variances) * tol
 
 
-def _check_sample_weight(X, sample_weight):
+def _check_normalize_sample_weight(sample_weight, X):
     """Set sample_weight if None, and check for correct dtype"""
-    n_samples = X.shape[0]
-    if sample_weight is None:
-        return np.ones(n_samples, dtype=X.dtype)
-    else:
-        sample_weight = np.asarray(sample_weight)
-        if n_samples != len(sample_weight):
-            raise ValueError("n_samples=%d should be == len(sample_weight)=%d"
-                             % (n_samples, len(sample_weight)))
+
+    sample_weight_was_none = sample_weight is None
+
+    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+    if not sample_weight_was_none:
         # normalize the weights to sum up to n_samples
+        # an array of 1 (i.e. samples_weight is None) is already normalized
+        n_samples = len(sample_weight)
         scale = n_samples / sample_weight.sum()
-        return (sample_weight * scale).astype(X.dtype, copy=False)
+        sample_weight *= scale
+    return sample_weight
 
 
 def k_means(X, n_clusters, sample_weight=None, init='k-means++',
@@ -437,7 +437,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
     if verbose:
         print('Initialization complete')
 
-    checked_sample_weight = _check_sample_weight(X, sample_weight)
+    checked_sample_weight = _check_normalize_sample_weight(sample_weight, X)
     centers, labels, n_iter = k_means_elkan(X, checked_sample_weight,
                                             n_clusters, centers, tol=tol,
                                             max_iter=max_iter, verbose=verbose)
@@ -522,7 +522,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
     """
     random_state = check_random_state(random_state)
 
-    sample_weight = _check_sample_weight(X, sample_weight)
+    sample_weight = _check_normalize_sample_weight(sample_weight, X)
 
     best_labels, best_inertia, best_centers = None, None, None
     # init
@@ -665,7 +665,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
         Sum of squared distances of samples to their closest cluster center.
     """
     n_samples = X.shape[0]
-    sample_weight = _check_sample_weight(X, sample_weight)
+    sample_weight = _check_normalize_sample_weight(sample_weight, X)
     # set the default value of centers to -1 to be able to detect any anomaly
     # easily
     labels = np.full(n_samples, -1, np.int32)
@@ -1495,7 +1495,7 @@ def fit(self, X, y=None, sample_weight=None):
             raise ValueError("n_samples=%d should be >= n_clusters=%d"
                              % (n_samples, self.n_clusters))
 
-        sample_weight = _check_sample_weight(X, sample_weight)
+        sample_weight = _check_normalize_sample_weight(sample_weight, X)
 
         n_init = self.n_init
         if hasattr(self.init, '__array__'):
@@ -1644,7 +1644,7 @@ def _labels_inertia_minibatch(self, X, sample_weight):
         """
         if self.verbose:
             print('Computing label assignment and total inertia')
-        sample_weight = _check_sample_weight(X, sample_weight)
+        sample_weight = _check_normalize_sample_weight(sample_weight, X)
        x_squared_norms = row_norms(X, squared=True)
        slices = gen_batches(X.shape[0], self.batch_size)
        results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
@@ -1679,7 +1679,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
         if n_samples == 0:
             return self
 
-        sample_weight = _check_sample_weight(X, sample_weight)
+        sample_weight = _check_normalize_sample_weight(sample_weight, X)
 
         x_squared_norms = row_norms(X, squared=True)
         self.random_state_ = getattr(self, "random_state_",
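
The renamed wrapper above keeps the k-means-specific normalization step: after the common validation, user-provided weights are rescaled so they sum to n_samples. A small usage sketch, assuming the private sklearn.cluster.k_means_ import path shown in the test diff below:

import numpy as np

from sklearn.cluster.k_means_ import _check_normalize_sample_weight

X = np.ones((3, 2))
sw = _check_normalize_sample_weight(np.array([1.0, 2.0, 3.0]), X)
# scale = n_samples / weights.sum() = 3 / 6 = 0.5
print(sw)        # [0.5 1.  1.5]
print(sw.sum())  # 3.0
# sample_weight=None yields np.ones(n_samples), which is already normalized.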

sklearn/cluster/tests/test_k_means.py

+6 -5

@@ -909,14 +909,15 @@ def test_sample_weight_length():
     # check that an error is raised when passing sample weights
     # with an incompatible shape
     km = KMeans(n_clusters=n_clusters, random_state=42)
-    assert_raises_regex(ValueError, r'len\(sample_weight\)', km.fit, X,
-                        sample_weight=np.ones(2))
+    msg = r'sample_weight.shape == \(2,\), expected \(100,\)'
+    with pytest.raises(ValueError, match=msg):
+        km.fit(X, sample_weight=np.ones(2))
 
 
-def test_check_sample_weight():
-    from sklearn.cluster.k_means_ import _check_sample_weight
+def test_check_normalize_sample_weight():
+    from sklearn.cluster.k_means_ import _check_normalize_sample_weight
     sample_weight = None
-    checked_sample_weight = _check_sample_weight(X, sample_weight)
+    checked_sample_weight = _check_normalize_sample_weight(sample_weight, X)
     assert _num_samples(X) == _num_samples(checked_sample_weight)
     assert_almost_equal(checked_sample_weight.sum(), _num_samples(X))
     assert X.dtype == checked_sample_weight.dtype

sklearn/linear_model/huber.py

+3 -6

@@ -8,8 +8,8 @@
 from ..base import BaseEstimator, RegressorMixin
 from .base import LinearModel
 from ..utils import check_X_y
-from ..utils import check_consistent_length
 from ..utils import axis0_safe_slice
+from ..utils.validation import _check_sample_weight
 from ..utils.extmath import safe_sparse_dot
 from ..utils.optimize import _check_optimize_result
 
@@ -255,11 +255,8 @@ def fit(self, X, y, sample_weight=None):
         X, y = check_X_y(
             X, y, copy=False, accept_sparse=['csr'], y_numeric=True,
             dtype=[np.float64, np.float32])
-        if sample_weight is not None:
-            sample_weight = np.array(sample_weight)
-            check_consistent_length(y, sample_weight)
-        else:
-            sample_weight = np.ones_like(y)
+
+        sample_weight = _check_sample_weight(sample_weight, X)
 
         if self.epsilon < 1.0:
             raise ValueError(

sklearn/linear_model/logistic.py

+4 -9

@@ -30,7 +30,7 @@
 from ..utils.fixes import logsumexp
 from ..utils.optimize import newton_cg, _check_optimize_result
 from ..utils.validation import check_X_y
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, _check_sample_weight
 from ..utils import deprecated
 from ..exceptions import ChangedBehaviorWarning
 from ..utils.multiclass import check_classification_targets
@@ -826,11 +826,8 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     # If sample weights exist, convert them to array (support for lists)
     # and check length
     # Otherwise set them to 1 for all examples
-    if sample_weight is not None:
-        sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
-        check_consistent_length(y, sample_weight)
-    else:
-        sample_weight = np.ones(X.shape[0], dtype=X.dtype)
+    sample_weight = _check_sample_weight(sample_weight, X,
+                                         dtype=X.dtype)
 
     # If class_weights is a dict (provided by the user), the weights
     # are assigned to the original labels. If it is "balanced", then
@@ -1133,9 +1130,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
     y_test = y[test]
 
     if sample_weight is not None:
-        sample_weight = check_array(sample_weight, ensure_2d=False)
-        check_consistent_length(y, sample_weight)
-
+        sample_weight = _check_sample_weight(sample_weight, X)
         sample_weight = sample_weight[train]
 
     coefs, Cs, n_iter = _logistic_regression_path(
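
Passing dtype=X.dtype here keeps the validated weights in the same floating-point precision as X, while non-float dtypes fall back to float64, as the new tests further down exercise. A brief illustration, again assuming the private helper from sklearn.utils.validation:

import numpy as np

from sklearn.utils.validation import _check_sample_weight

X32 = np.ones((4, 3), dtype=np.float32)
sw = _check_sample_weight([1, 2, 3, 4], X32, dtype=X32.dtype)
print(sw.dtype)  # float32 -- matches X

X_int = np.ones((4, 3), dtype=int)
sw = _check_sample_weight(None, X_int, dtype=X_int.dtype)
print(sw.dtype)  # float64 -- integer dtypes are not kept for weights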

sklearn/linear_model/ransac.py

+2 -3

@@ -11,7 +11,7 @@
 from ..base import MultiOutputMixin
 from ..utils import check_random_state, check_array, check_consistent_length
 from ..utils.random import sample_without_replacement
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, _check_sample_weight
 from .base import LinearRegression
 from ..utils.validation import has_fit_parameter
 from ..exceptions import ConvergenceWarning
@@ -324,8 +324,7 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("%s does not support sample_weight. Samples"
                              " weights are only used for the calibration"
                              " itself." % estimator_name)
-        if sample_weight is not None:
-            sample_weight = np.asarray(sample_weight)
+        sample_weight = _check_sample_weight(sample_weight, X)
 
         n_inliers_best = 1
         score_best = -np.inf

sklearn/linear_model/ridge.py

+4 -5

@@ -27,6 +27,7 @@
 from ..utils import check_consistent_length
 from ..utils import compute_sample_weight
 from ..utils import column_or_1d
+from ..utils.validation import _check_sample_weight
 from ..preprocessing import LabelBinarizer
 from ..model_selection import GridSearchCV
 from ..metrics.scorer import check_scoring
@@ -428,8 +429,7 @@ def _ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
                          " %d != %d" % (n_samples, n_samples_))
 
     if has_sw:
-        if np.atleast_1d(sample_weight).ndim > 1:
-            raise ValueError("Sample weights must be 1D array or scalar")
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
 
         if solver not in ['sag', 'saga']:
             # SAG supports sample_weight directly. For other solvers,
@@ -1424,9 +1424,8 @@ def fit(self, X, y, sample_weight=None):
                 "alphas must be positive. Got {} containing some "
                 "negative or null value instead.".format(self.alphas))
 
-        if sample_weight is not None and not isinstance(sample_weight, float):
-            sample_weight = check_array(sample_weight, ensure_2d=False,
-                                        dtype=X.dtype)
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+
         n_samples, n_features = X.shape
 
         X, y, X_offset, y_offset, X_scale = LinearModel._preprocess_data(

sklearn/linear_model/sag.py

+2 -2

@@ -12,6 +12,7 @@
 from .sag_fast import sag32, sag64
 from ..exceptions import ConvergenceWarning
 from ..utils import check_array
+from ..utils.validation import _check_sample_weight
 from ..utils.extmath import row_norms
 
 
@@ -251,8 +252,7 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,
     n_classes = int(y.max()) + 1 if loss == 'multinomial' else 1
 
     # initialization
-    if sample_weight is None:
-        sample_weight = np.ones(n_samples, dtype=X.dtype, order='C')
+    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
 
     if 'coef' in warm_start_mem.keys():
         coef_init = warm_start_mem['coef']

sklearn/linear_model/stochastic_gradient.py

+4 -17

@@ -18,7 +18,7 @@
 from ..utils import check_array, check_random_state, check_X_y
 from ..utils.extmath import safe_sparse_dot
 from ..utils.multiclass import _check_partial_fit_first_call
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, _check_sample_weight
 from ..exceptions import ConvergenceWarning
 from ..model_selection import StratifiedShuffleSplit, ShuffleSplit
 
@@ -169,19 +169,6 @@ def _get_penalty_type(self, penalty):
         except KeyError:
             raise ValueError("Penalty %s is not supported. " % penalty)
 
-    def _validate_sample_weight(self, sample_weight, n_samples):
-        """Set the sample weight array."""
-        if sample_weight is None:
-            # uniform sample weights
-            sample_weight = np.ones(n_samples, dtype=np.float64, order='C')
-        else:
-            # user-provided array
-            sample_weight = np.asarray(sample_weight, dtype=np.float64,
-                                       order="C")
-        if sample_weight.shape[0] != n_samples:
-            raise ValueError("Shapes of X and sample_weight do not match.")
-        return sample_weight
-
     def _allocate_parameter_mem(self, n_classes, n_features, coef_init=None,
                                 intercept_init=None):
         """Allocate mem for parameters; initialize if provided."""
@@ -488,7 +475,7 @@ def _partial_fit(self, X, y, alpha, C,
         # Allocate datastructures from input arguments
         self._expanded_class_weight = compute_class_weight(self.class_weight,
                                                            self.classes_, y)
-        sample_weight = self._validate_sample_weight(sample_weight, n_samples)
+        sample_weight = _check_sample_weight(sample_weight, X)
 
         if getattr(self, "coef_", None) is None or coef_init is not None:
             self._allocate_parameter_mem(n_classes, n_features,
@@ -1095,9 +1082,9 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
 
         n_samples, n_features = X.shape
 
-        # Allocate datastructures from input arguments
-        sample_weight = self._validate_sample_weight(sample_weight, n_samples)
+        sample_weight = _check_sample_weight(sample_weight, X)
 
+        # Allocate datastructures from input arguments
         if getattr(self, "coef_", None) is None:
             self._allocate_parameter_mem(1, n_features, coef_init,
                                          intercept_init)

sklearn/svm/base.py

+5 -6

@@ -8,11 +8,12 @@
 from ..base import BaseEstimator, ClassifierMixin
 from ..preprocessing import LabelEncoder
 from ..utils.multiclass import _ovr_decision_function
-from ..utils import check_array, check_consistent_length, check_random_state
+from ..utils import check_array, check_random_state
 from ..utils import column_or_1d, check_X_y
 from ..utils import compute_class_weight
 from ..utils.extmath import safe_sparse_dot
 from ..utils.validation import check_is_fitted, _check_large_sparse
+from ..utils.validation import _check_sample_weight
 from ..utils.multiclass import check_classification_targets
 from ..exceptions import ConvergenceWarning
 from ..exceptions import NotFittedError
@@ -906,11 +907,9 @@ def _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight,
     # LibLinear wants targets as doubles, even for classification
     y_ind = np.asarray(y_ind, dtype=np.float64).ravel()
     y_ind = np.require(y_ind, requirements="W")
-    if sample_weight is None:
-        sample_weight = np.ones(X.shape[0])
-    else:
-        sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
-        check_consistent_length(sample_weight, X)
+
+    sample_weight = _check_sample_weight(sample_weight, X,
+                                         dtype=np.float64)
 
     solver_type = _get_liblinear_solver_type(multi_class, penalty, loss, dual)
     raw_coef_, n_iter_ = liblinear.train_wrap(

sklearn/utils/tests/test_validation.py

+40 -1

@@ -20,6 +20,7 @@
 from sklearn.utils.testing import SkipTest
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_allclose_dense_sparse
+from sklearn.utils.testing import assert_allclose
 from sklearn.utils import as_float_array, check_array, check_symmetric
 from sklearn.utils import check_X_y
 from sklearn.utils import deprecated
@@ -39,7 +40,8 @@
                                    check_memory,
                                    check_non_negative,
                                    _num_samples,
-                                   check_scalar)
+                                   check_scalar,
+                                   _check_sample_weight)
 import sklearn
 
 from sklearn.exceptions import NotFittedError
@@ -853,3 +855,40 @@ def test_check_scalar_invalid(x, target_name, target_type, min_val, max_val,
                               min_val=min_val, max_val=max_val)
     assert str(raised_error.value) == str(err_msg)
     assert type(raised_error.value) == type(err_msg)
+
+
+def test_check_sample_weight():
+    # check array order
+    sample_weight = np.ones(10)[::2]
+    assert not sample_weight.flags["C_CONTIGUOUS"]
+    sample_weight = _check_sample_weight(sample_weight, X=np.ones((5, 1)))
+    assert sample_weight.flags["C_CONTIGUOUS"]
+
+    # check None input
+    sample_weight = _check_sample_weight(None, X=np.ones((5, 2)))
+    assert_allclose(sample_weight, np.ones(5))
+
+    # check numbers input
+    sample_weight = _check_sample_weight(2.0, X=np.ones((5, 2)))
+    assert_allclose(sample_weight, 2 * np.ones(5))
+
+    # check wrong number of dimensions
+    with pytest.raises(ValueError,
+                       match="Sample weights must be 1D array or scalar"):
+        _check_sample_weight(np.ones((2, 4)), X=np.ones((2, 2)))
+
+    # check incorrect n_samples
+    msg = r"sample_weight.shape == \(4,\), expected \(2,\)!"
+    with pytest.raises(ValueError, match=msg):
+        _check_sample_weight(np.ones(4), X=np.ones((2, 2)))
+
+    # float32 dtype is preserved
+    X = np.ones((5, 2))
+    sample_weight = np.ones(5, dtype=np.float32)
+    sample_weight = _check_sample_weight(sample_weight, X)
+    assert sample_weight.dtype == np.float32
+
+    # int dtype will be converted to float64 instead
+    X = np.ones((5, 2), dtype=np.int)
+    sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
+    assert sample_weight.dtype == np.float64
