Merged
6 changes: 5 additions & 1 deletion doc/whats_new/v0.24.rst
@@ -54,7 +54,7 @@ Changelog
:user:`Lucy Liu <lucyleeow>`.

:mod:`sklearn.cluster`
.......................
......................

- |Fix| Fixed a bug in :class:`cluster.MeanShift` with `bin_seeding=True`. When
the estimated bandwidth is 0, the behavior is equivalent to
@@ -66,6 +66,10 @@ Changelog
weighted by the sample weights. :pr:`17848` by
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |API| :class:`cluster.MiniBatchKMeans` attributes, `counts_` and
`init_size_`, are deprecated and will be removed in 0.26. :pr:`17864` by
:user:`Jérémie du Boisberranger <jeremiedbb>`.
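
In practice this means reading either attribute on a fitted estimator still works in 0.24 but emits a FutureWarning. A minimal sketch of what downstream code will see (toy data chosen purely for illustration):

```python
import warnings

import numpy as np
from sklearn.cluster import MiniBatchKMeans

X = np.array([[1.0, 2.0], [1.5, 1.8], [8.0, 8.0], [8.2, 7.9]])
km = MiniBatchKMeans(n_clusters=2, random_state=0).fit(X)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = km.counts_  # deprecated: still returns the weight sums, but warns
assert caught[0].category is FutureWarning
```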

:mod:`sklearn.covariance`
.........................

56 changes: 44 additions & 12 deletions sklearn/cluster/_kmeans.py
@@ -26,6 +26,7 @@
from ..utils import check_array
from ..utils import gen_batches
from ..utils import check_random_state
from ..utils import deprecated
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils._openmp_helpers import _openmp_effective_n_threads
from ..exceptions import ConvergenceWarning
@@ -1531,6 +1532,21 @@ class MiniBatchKMeans(KMeans):
defined as the sum of square distances of samples to their nearest
neighbor.

n_iter_ : int
Number of batches processed.

counts_ : ndarray of shape (n_clusters,)
Weight sum of each cluster.

.. deprecated:: 0.24
This attribute is deprecated in 0.24 and will be removed in 0.26.

init_size_ : int
The effective number of samples used for the initialization.

.. deprecated:: 0.24
This attribute is deprecated in 0.24 and will be removed in 0.26.

See Also
--------
KMeans
@@ -1588,6 +1604,24 @@ def __init__(self, n_clusters=8, *, init='k-means++', max_iter=100,
self.init_size = init_size
self.reassignment_ratio = reassignment_ratio

@deprecated("The attribute 'counts_' is deprecated in 0.24" # type: ignore
" and will be removed in 0.26.")
@property
def counts_(self):
return self._counts

@deprecated("The attribute 'init_size_' is deprecated in " # type: ignore
"0.24 and will be removed in 0.26.")
@property
def init_size_(self):
return self._init_size

@deprecated("The attribute 'random_state_' is deprecated " # type: ignore
"in 0.24 and will be removed in 0.26.")
@property
def random_state_(self):
return getattr(self, "_random_state", None)
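
Conceptually, stacking `@deprecated(...)` on top of `@property` rewraps the getter so that every attribute access warns before returning the private value. A simplified standalone sketch of that mechanism (not sklearn's actual `utils.deprecated` implementation):

```python
import warnings


def deprecated_property(message):
    # Simplified stand-in for sklearn.utils.deprecated applied to a property.
    def decorate(prop):
        def getter(self):
            warnings.warn(message, FutureWarning)
            return prop.fget(self)
        return property(getter, doc=prop.__doc__)
    return decorate


class Toy:
    def __init__(self):
        self._counts = [3, 7]

    @deprecated_property("The attribute 'counts_' is deprecated in 0.24 "
                         "and will be removed in 0.26.")
    @property
    def counts_(self):
        return self._counts


Toy().counts_  # emits a FutureWarning, then returns [3, 7]
```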

def _check_params(self, X):
super()._check_params(X)

@@ -1619,8 +1653,6 @@ def _check_params(self, X):
RuntimeWarning, stacklevel=2)
self._init_size = 3 * self.n_clusters
self._init_size = min(self._init_size, X.shape[0])
# FIXME: init_size_ will be deprecated and this line will be removed
self.init_size_ = self._init_size

# reassignment_ratio
if self.reassignment_ratio < 0:
@@ -1727,7 +1759,7 @@ def fit(self, X, y=None, sample_weight=None):
% (init_idx + 1, self._n_init, inertia))
if best_inertia is None or inertia < best_inertia:
self.cluster_centers_ = cluster_centers
self.counts_ = weight_sums
self._counts = weight_sums
best_inertia = inertia

# Empty context to be used inplace by the convergence check routine
@@ -1744,15 +1776,15 @@ def fit(self, X, y=None, sample_weight=None):
batch_inertia, centers_squared_diff = _mini_batch_step(
X[minibatch_indices], sample_weight[minibatch_indices],
x_squared_norms[minibatch_indices],
self.cluster_centers_, self.counts_,
self.cluster_centers_, self._counts,
old_center_buffer, tol > 0.0, distances=distances,
# Here we randomly choose whether to perform
# random reassignment: the choice is done as a function
# of the iteration index, and the minimum number of
# counts, in order to force this reassignment to happen
# every once in a while
random_reassign=((iteration_idx + 1)
% (10 + int(self.counts_.min())) == 0),
% (10 + int(self._counts.min())) == 0),
random_state=random_state,
reassignment_ratio=self.reassignment_ratio,
verbose=self.verbose)
@@ -1831,7 +1863,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
order='C', accept_large_sparse=False,
reset=is_first_call_to_partial_fit)

self.random_state_ = getattr(self, "random_state_",
self._random_state = getattr(self, "_random_state",
Member

Do we also need to deprecate this one, or was it a typo and random_state_ never existed?

Member Author

I wasn't sure what to do about this one. In the PR description I asked:

    random_state_ is only created and used in partial_fit. I think it's safe to directly privatize it. What do you think?

Member

Missed that, sorry. If it's only a matter of adding a @property, it wouldn't hurt to deprecate it (no need to document it, though). If it's a pain, I agree there's little risk in just ignoring this. Up to you, feel free to self-merge.

Member Author

Let's deprecate it as well.

check_random_state(self.random_state))
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

@@ -1850,26 +1882,26 @@ def partial_fit(self, X, y=None, sample_weight=None):
# initialize the cluster centers
self.cluster_centers_ = _init_centroids(
X, self.n_clusters, init,
random_state=self.random_state_,
random_state=self._random_state,
x_squared_norms=x_squared_norms, init_size=self.init_size)

self.counts_ = np.zeros(self.n_clusters,
self._counts = np.zeros(self.n_clusters,
dtype=sample_weight.dtype)
random_reassign = False
distances = None
else:
# The lower the minimum count is, the more we do random
# reassignment, however, we don't want to do random
# reassignment too often, to allow for building up counts
random_reassign = self.random_state_.randint(
10 * (1 + self.counts_.min())) == 0
random_reassign = self._random_state.randint(
10 * (1 + self._counts.min())) == 0
distances = np.zeros(X.shape[0], dtype=X.dtype)

_mini_batch_step(X, sample_weight, x_squared_norms,
self.cluster_centers_, self.counts_,
self.cluster_centers_, self._counts,
np.zeros(0, dtype=X.dtype), 0,
random_reassign=random_reassign, distances=distances,
random_state=self.random_state_,
random_state=self._random_state,
reassignment_ratio=self.reassignment_ratio,
verbose=self.verbose)
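
To make the reassignment comment above concrete: `randint(10 * (1 + counts.min())) == 0` triggers a random reassignment with probability 1 / (10 * (1 + min_count)), so reassignments become rarer as the per-cluster counts build up. A quick sanity check of that rate, with a hypothetical minimum count of 4:

```python
import numpy as np

rng = np.random.RandomState(0)
min_count = 4  # hypothetical smallest per-cluster weight sum
trials = 100_000
hits = sum(rng.randint(10 * (1 + min_count)) == 0 for _ in range(trials))
print(hits / trials)  # close to 1 / 50 = 0.02
```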

19 changes: 16 additions & 3 deletions sklearn/cluster/tests/test_k_means.py
@@ -434,7 +434,7 @@ def test_minibatch_reassign():
# Turn on verbosity to smoke test the display code
_mini_batch_step(this_X, sample_weight, (X ** 2).sum(axis=1),
mb_k_means.cluster_centers_,
mb_k_means.counts_,
mb_k_means._counts,
np.zeros(X.shape[1], np.double),
False, distances=np.zeros(X.shape[0]),
random_reassign=True, random_state=42,
@@ -454,7 +454,7 @@ def test_minibatch_reassign():
# Turn on verbosity to smoke test the display code
_mini_batch_step(this_X, sample_weight, (X ** 2).sum(axis=1),
mb_k_means.cluster_centers_,
mb_k_means.counts_,
mb_k_means._counts,
np.zeros(X.shape[1], np.double),
False, distances=np.zeros(X.shape[0]),
random_reassign=True, random_state=42,
@@ -529,7 +529,7 @@ def test_minibatch_set_init_size():
init_size=666, random_state=42,
n_init=1).fit(X)
assert mb_k_means.init_size == 666
assert mb_k_means.init_size_ == n_samples
assert mb_k_means._init_size == n_samples
_check_fitted_model(mb_k_means)


@@ -933,6 +933,19 @@ def test_n_jobs_deprecated(n_jobs):
kmeans.fit(X)


@pytest.mark.parametrize("attr", ["counts_", "init_size_", "random_state_"])
def test_minibatch_kmeans_deprecated_attributes(attr):
# check that accessing any of the deprecated attributes raises a FutureWarning
# FIXME: remove in 0.26
depr_msg = (f"The attribute '{attr}' is deprecated in 0.24 and will be "
f"removed in 0.26.")
km = MiniBatchKMeans(n_clusters=2, n_init=1, init='random', random_state=0)
km.fit(X)

with pytest.warns(FutureWarning, match=depr_msg):
getattr(km, attr)


def test_warning_elkan_1_cluster():
X, _ = make_blobs(n_samples=10, n_features=2, centers=1, random_state=0)
kmeans = KMeans(n_clusters=1, n_init=1, init='random', random_state=0,
7 changes: 2 additions & 5 deletions sklearn/tests/test_docstring_parameters.py
@@ -232,13 +232,10 @@ def test_fit_docstring_attributes(name, Estimator):
with ignore_warnings(category=FutureWarning):
assert hasattr(est, attr.name)

IGNORED = {'BayesianRidge', 'Birch', 'CCA', 'CategoricalNB',
'KernelCenterer',
IGNORED = {'BayesianRidge', 'Birch', 'CCA',
'LarsCV', 'Lasso', 'LassoLarsIC',
'MiniBatchKMeans',
'OrthogonalMatchingPursuit',
'PLSCanonical', 'PLSSVD',
'PassiveAggressiveClassifier'}
'PLSCanonical', 'PLSSVD'}

if Estimator.__name__ in IGNORED:
pytest.xfail(
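
Dropping `MiniBatchKMeans` (and the other now-compliant estimators) from `IGNORED` means the generic check above now runs for it: every attribute documented in the class docstring must exist on a fitted instance. In spirit, that check does something like the following simplified sketch (assuming `numpydoc` is available; not the actual test code):

```python
import warnings

from numpydoc import docscrape
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=30, centers=2, random_state=0)
est = MiniBatchKMeans(n_clusters=2, random_state=0).fit(X)

doc = docscrape.ClassDoc(MiniBatchKMeans)
with warnings.catch_warnings():
    # deprecated attributes such as counts_ still exist but warn on access
    warnings.simplefilter("ignore", FutureWarning)
    for attr in doc["Attributes"]:
        assert hasattr(est, attr.name), attr.name
```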