Skip to content

Commit aa1e69a

Browse files
authored
API Removes the use of fit_ and partial_fit_ in Birch (scikit-learn#19297)
* API Removes the use of fit_ and partial_fit_ in Birch * DOC Adds whats new * ENH Adjust names * CLN Uses a verbose name
1 parent b943324 commit aa1e69a

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

doc/whats_new/v1.0.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ Changelog
5656
in multicore settings. :pr:`19052` by
5757
:user:`Yusuke Nagasaka <YusukeNagasaka>`.
5858

59+
- |API| :class:`cluster.Birch` attributes, `fit_` and `partial_fit_`, are
60+
deprecated and will be removed in 1.2. :pr:`19297` by `Thomas Fan`_.
61+
5962
- |Fix| Fixes incorrect multiple data-conversion warnings when clustering
6063
boolean data. :pr:`19046` by :user:`Surya Prakash <jdsurya>`.
6164

sklearn/cluster/_birch.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..metrics.pairwise import euclidean_distances
1414
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
1515
from ..utils.extmath import row_norms
16+
from ..utils import deprecated
1617
from ..utils.validation import check_is_fitted, _deprecate_positional_args
1718
from ..exceptions import ConvergenceWarning
1819
from . import AgglomerativeClustering
@@ -440,6 +441,24 @@ def __init__(self, *, threshold=0.5, branching_factor=50, n_clusters=3,
440441
self.compute_labels = compute_labels
441442
self.copy = copy
442443

444+
# TODO: Remove in 1.2
445+
# mypy error: Decorated property not supported
446+
@deprecated( # type: ignore
447+
"fit_ is deprecated in 1.0 and will be removed in 1.2"
448+
)
449+
@property
450+
def fit_(self):
451+
return self._deprecated_fit
452+
453+
# TODO: Remove in 1.2
454+
# mypy error: Decorated property not supported
455+
@deprecated( # type: ignore
456+
"partial_fit_ is deprecated in 1.0 and will be removed in 1.2"
457+
)
458+
@property
459+
def partial_fit_(self):
460+
return self._deprecated_partial_fit
461+
443462
def fit(self, X, y=None):
444463
"""
445464
Build a CF Tree for the input data.
@@ -457,12 +476,13 @@ def fit(self, X, y=None):
457476
self
458477
Fitted estimator.
459478
"""
460-
self.fit_, self.partial_fit_ = True, False
461-
return self._fit(X)
479+
# TODO: Remove deprected flags in 1.2
480+
self._deprecated_fit, self._deprecated_partial_fit = True, False
481+
return self._fit(X, partial=False)
462482

463-
def _fit(self, X):
483+
def _fit(self, X, partial):
464484
has_root = getattr(self, 'root_', None)
465-
first_call = self.fit_ or (self.partial_fit_ and not has_root)
485+
first_call = not (partial and has_root)
466486

467487
X = self._validate_data(X, accept_sparse='csr', copy=self.copy,
468488
reset=first_call)
@@ -552,13 +572,14 @@ def partial_fit(self, X=None, y=None):
552572
self
553573
Fitted estimator.
554574
"""
555-
self.partial_fit_, self.fit_ = True, False
575+
# TODO: Remove deprected flags in 1.2
576+
self._deprecated_partial_fit, self._deprecated_fit = True, False
556577
if X is None:
557578
# Perform just the final global clustering step.
558579
self._global_clustering()
559580
return self
560581
else:
561-
return self._fit(X)
582+
return self._fit(X, partial=True)
562583

563584
def _check_fit(self, X):
564585
check_is_fitted(self)

sklearn/cluster/tests/test_birch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,15 @@ def test_birch_n_clusters_long_int():
179179
X, _ = make_blobs(random_state=0)
180180
n_clusters = np.int64(5)
181181
Birch(n_clusters=n_clusters).fit(X)
182+
183+
184+
# TODO: Remove in 1.2
185+
@pytest.mark.parametrize("attribute", ["fit_", "partial_fit_"])
186+
def test_birch_fit_attributes_deprecated(attribute):
187+
"""Test that fit_ and partial_fit_ attributes are deprecated."""
188+
msg = f"{attribute} is deprecated in 1.0 and will be removed in 1.2"
189+
X, y = make_blobs(n_samples=10)
190+
brc = Birch().fit(X, y)
191+
192+
with pytest.warns(FutureWarning, match=msg):
193+
getattr(brc, attribute)

0 commit comments

Comments
 (0)