@@ -27,7 +27,7 @@
 from ..utils import check_array
 from ..utils import gen_batches
 from ..utils import check_random_state
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, _check_sample_weight
 from ..utils.validation import FLOAT_DTYPES
 from ..exceptions import ConvergenceWarning
 from . import _k_means
@@ -167,19 +167,19 @@ def _tolerance(X, tol):
 return np.mean(variances) * tol


-def _check_sample_weight(X, sample_weight):
+def _check_normalize_sample_weight(sample_weight, X):
     """Set sample_weight if None, and check for correct dtype"""
-    n_samples = X.shape[0]
-    if sample_weight is None:
-        return np.ones(n_samples, dtype=X.dtype)
-    else:
-        sample_weight = np.asarray(sample_weight)
-        if n_samples != len(sample_weight):
-            raise ValueError("n_samples=%d should be == len(sample_weight)=%d"
-                             % (n_samples, len(sample_weight)))
+
+    sample_weight_was_none = sample_weight is None
+
+    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+    if not sample_weight_was_none:
         # normalize the weights to sum up to n_samples
+        # an array of ones (i.e. sample_weight is None) is already normalized
+        n_samples = len(sample_weight)
         scale = n_samples / sample_weight.sum()
-        return (sample_weight * scale).astype(X.dtype, copy=False)
+        sample_weight *= scale
+    return sample_weight


 def k_means(X, n_clusters, sample_weight=None, init='k-means++',
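
For context: the new helper keeps the old behavior but delegates validation to the shared `_check_sample_weight` utility, then rescales in place. Below is a minimal standalone sketch of the intended semantics, written against plain numpy with a hypothetical function name rather than the private scikit-learn helpers.

import numpy as np

def normalize_sample_weight_sketch(sample_weight, X):
    # No weights given: every sample counts once. An all-ones vector
    # already sums to n_samples, so no rescaling is needed.
    n_samples = X.shape[0]
    if sample_weight is None:
        return np.ones(n_samples, dtype=X.dtype)
    sample_weight = np.asarray(sample_weight, dtype=X.dtype)
    if sample_weight.shape != (n_samples,):
        raise ValueError("sample_weight has shape %s, expected (%d,)"
                         % (sample_weight.shape, n_samples))
    # Rescale so the weights sum to n_samples; this keeps inertia and
    # tolerance on the same scale as the unweighted case.
    return sample_weight * (n_samples / sample_weight.sum())

X = np.arange(10, dtype=np.float64).reshape(5, 2)
w = normalize_sample_weight_sketch([1, 1, 2, 2, 4], X)
print(w)        # [0.5 0.5 1.  1.  2. ]
print(w.sum())  # 5.0 == n_samples
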
@@ -437,7 +437,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
     if verbose:
         print('Initialization complete')

-    checked_sample_weight = _check_sample_weight(X, sample_weight)
+    checked_sample_weight = _check_normalize_sample_weight(sample_weight, X)
     centers, labels, n_iter = k_means_elkan(X, checked_sample_weight,
                                             n_clusters, centers, tol=tol,
                                             max_iter=max_iter, verbose=verbose)
@@ -522,7 +522,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
     """
     random_state = check_random_state(random_state)

-    sample_weight = _check_sample_weight(X, sample_weight)
+    sample_weight = _check_normalize_sample_weight(sample_weight, X)

     best_labels, best_inertia, best_centers = None, None, None
     # init
@@ -665,7 +665,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
         Sum of squared distances of samples to their closest cluster center.
     """
     n_samples = X.shape[0]
-    sample_weight = _check_sample_weight(X, sample_weight)
+    sample_weight = _check_normalize_sample_weight(sample_weight, X)
     # set the default value of centers to -1 to be able to detect any anomaly
     # easily
     labels = np.full(n_samples, -1, np.int32)
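
`_labels_inertia` consumes the normalized weights when scoring an assignment: inertia becomes a weighted sum of squared distances. A rough numpy sketch of that quantity follows, with hypothetical names; the module itself computes it in Cython helpers, not like this.

import numpy as np

def weighted_inertia_sketch(X, sample_weight, centers, labels):
    # Squared distance from each sample to its assigned center,
    # weighted by the (normalized) sample weight.
    diffs = X - centers[labels]
    return float(np.sum(sample_weight * np.einsum('ij,ij->i', diffs, diffs)))
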
@@ -1495,7 +1495,7 @@ def fit(self, X, y=None, sample_weight=None):
             raise ValueError("n_samples=%d should be >= n_clusters=%d"
                              % (n_samples, self.n_clusters))

-        sample_weight = _check_sample_weight(X, sample_weight)
+        sample_weight = _check_normalize_sample_weight(sample_weight, X)

         n_init = self.n_init
         if hasattr(self.init, '__array__'):
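
One observable consequence of normalizing inside `fit`: multiplying every weight by the same positive constant should leave the fitted model unchanged, because the rescaled weights normalize to the same vector. A small check along those lines, assuming a scikit-learn build that includes this change:

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.RandomState(0)
X = rng.rand(100, 2)
w = rng.rand(100)

km_a = KMeans(n_clusters=3, random_state=42).fit(X, sample_weight=w)
km_b = KMeans(n_clusters=3, random_state=42).fit(X, sample_weight=10 * w)

# Both runs see the same normalized weights, so the centers agree.
print(np.allclose(km_a.cluster_centers_, km_b.cluster_centers_))  # True
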
@@ -1644,7 +1644,7 @@ def _labels_inertia_minibatch(self, X, sample_weight):
         """
         if self.verbose:
             print('Computing label assignment and total inertia')
-        sample_weight = _check_sample_weight(X, sample_weight)
+        sample_weight = _check_normalize_sample_weight(sample_weight, X)
         x_squared_norms = row_norms(X, squared=True)
         slices = gen_batches(X.shape[0], self.batch_size)
         results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
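
In the minibatch path the weights are checked once up front and then sliced with the same `gen_batches` slices as `X`, so each batch keeps its rows and weights aligned. A quick illustration of that slicing pattern:

import numpy as np
from sklearn.utils import gen_batches

X = np.arange(20, dtype=np.float64).reshape(10, 2)
sample_weight = np.ones(10)

for s in gen_batches(X.shape[0], batch_size=4):
    # Each slice object applies to both arrays, keeping them aligned.
    print(X[s].shape, sample_weight[s].shape)
# (4, 2) (4,) / (4, 2) (4,) / (2, 2) (2,)
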
@@ -1679,7 +1679,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
         if n_samples == 0:
             return self

-        sample_weight = _check_sample_weight(X, sample_weight)
+        sample_weight = _check_normalize_sample_weight(sample_weight, X)

         x_squared_norms = row_norms(X, squared=True)
         self.random_state_ = getattr(self, "random_state_",