Skip to content

Commit c5440bf

Browse files
committed
Made the following changes:
1. Make clear that partial_fit(None) enables global clustering. 2. Add compute_labels argument.
1 parent b01e21e commit c5440bf

File tree

4 files changed

+101
-57
lines changed

4 files changed

+101
-57
lines changed

doc/modules/clustering.rst

+17-2
Original file line numberDiff line numberDiff line change
@@ -824,8 +824,23 @@ samples are mapped to the global label of the nearest subcluster.
824824
then this node is again split into two and the process is continued
825825
recursively, till it reaches the root.
826826

827-
Birch is generally slightly faster than MiniBatchKMeans on large clusters and
828-
is slightly slower than MiniBatchKMeans on large features.
827+
**Birch or MiniBatchKMeans?**
828+
829+
- Birch does not scale very well to high dimensional data. If ``n_features``
830+
is greater than twenty, it is generally better to use MiniBatchKMeans.
831+
- If the number of instances of data needs to be reduced, or if ``n_clusters``
832+
is really large, it is generally better to use Birch.
833+
834+
**How to use partial_fit?**
835+
836+
To avoid the computation of global clustering, for every call of ``partial_fit``
837+
the user is advised
838+
1. To set ``n_clusters=None`` initially
839+
2. Train all data by multiple calls to partial_fit.
840+
3. Set ``n_clusters`` to a required value using
841+
``brc.set_params(n_clusters=n_clusters)``
842+
4. Call ``partial_fit`` finally with no arguments, i.e. ``brc.partial_fit()``
843+
which performs the global clustering.
829844

830845
.. image:: ../auto_examples/cluster/images/plot_birch_vs_minibatchkmeans_001.png
831846
:target: ../auto_examples/cluster/plot_birch_vs_minibatchkmeans.html

doc/whats_new.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ New features
5959
By `Gael Varoquaux`_ and `Florian Wilhelm`_.
6060

6161
- Add :class:`cluster.Birch`, an online clustering algorithm. By
62-
`Manoj Kumar`_ and `Alexandre Gramfort`_.
62+
`Manoj Kumar`_, `Alexandre Gramfort`_ and `Joel Nothman`_.
6363

6464
Enhancements
6565
............

sklearn/cluster/birch.py

+73-51
Original file line numberDiff line numberDiff line change
@@ -326,20 +326,21 @@ class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
326326
subcluster is started.
327327
328328
branching_factor : int, default 50
329-
Maximun number of CF subclusters in each node. If a new samples enters
329+
Maximum number of CF subclusters in each node. If a new sample enters
330330
such that the number of subclusters exceed the branching_factor then
331331
the node has to be split. The corresponding parent also has to be
332332
split and if the number of subclusters in the parent is greater than
333333
the branching factor, then it has to be split recursively.
334334
335-
n_clusters : int, instance of sklearn.cluster model or None, default 3
336-
Number of clusters after the final clustring step, which treats the
335+
n_clusters : int, instance of sklearn.cluster model, default None
336+
Number of clusters after the final clustering step, which treats the
337337
subclusters from the leaves as new samples. By default the global
338338
clustering step is AgglomerativeClustering with n_clusters set to 3.
339-
If set to None, this final clustering step is not performed and the
339+
By default, this final clustering step is not performed and the
340340
subclusters are returned as they are.
341-
It is advised to set n_clusters=None if ``partial_fit`` is used, to
342-
avoid this final clustering step for every call to ``partial_fit``.
341+
342+
compute_labels : bool, default True
343+
Whether or not to compute labels for each fit.
343344
344345
Attributes
345346
----------
@@ -357,16 +358,19 @@ class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
357358
they are clustered globally.
358359
359360
labels_ : ndarray, shape (n_samples,)
360-
Array of labels assigned to the input data or the last
361-
batch of input data if partial_fit was used.
361+
Array of labels assigned to the input data.
362+
If partial_fit is used instead of fit, they are assigned to the
363+
last batch of data.
362364
363365
Examples
364366
--------
365367
>>> from sklearn.cluster import Birch
366368
>>> X = [[0, 1], [0.3, 1], [-0.3, 1], [0, -1], [0.3, -1], [-0.3, -1]]
367-
>>> brc = Birch(threshold=0.5, n_clusters=None)
369+
>>> brc = Birch(branching_factor=50, n_clusters=None, threshold=0.5,
370+
... compute_labels=True)
368371
>>> brc.fit(X)
369-
Birch(branching_factor=50, n_clusters=None, threshold=0.5)
372+
Birch(branching_factor=50, compute_labels=True, n_clusters=None,
373+
threshold=0.5)
370374
>>> brc.predict(X)
371375
array([0, 0, 0, 1, 1, 1])
372376
@@ -381,10 +385,12 @@ class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
381385
https://code.google.com/p/jbirch/
382386
"""
383387

384-
def __init__(self, threshold=0.5, branching_factor=50, n_clusters=3):
388+
def __init__(self, threshold=0.5, branching_factor=50, n_clusters=3,
389+
compute_labels=True):
385390
self.threshold = threshold
386391
self.branching_factor = branching_factor
387392
self.n_clusters = n_clusters
393+
self.compute_labels = compute_labels
388394
self.partial_fit_ = False
389395
self.root_ = None
390396

@@ -474,43 +480,9 @@ def fit(self, X, y=None):
474480

475481
centroids = np.concatenate([
476482
leaf.centroids_ for leaf in self.get_leaves()])
477-
478-
# Preprocessing for the global clustering.
479-
not_enough_centroids = False
480-
if hasattr(self.n_clusters, 'fit_predict'):
481-
global_cluster = self.n_clusters
482-
elif isinstance(self.n_clusters, int):
483-
global_cluster = AgglomerativeClustering(
484-
n_clusters=self.n_clusters)
485-
# There is no need to perform the global clustering step.
486-
if len(centroids) < self.n_clusters:
487-
not_enough_centroids = True
488-
elif self.n_clusters is not None:
489-
raise ValueError("n_clusters should be an instance of "
490-
"ClusterMixin or an int")
491-
492-
# To use in predict to avoid recalculation.
493483
self.subcluster_centers_ = centroids
494-
self._subcluster_norms = row_norms(
495-
self.subcluster_centers_, squared=True)
496-
497-
if self.n_clusters is None or not_enough_centroids:
498-
self.subcluster_labels_ = np.arange(len(centroids))
499-
self.labels_ = self.predict(X)
500-
if not_enough_centroids:
501-
warnings.warn(
502-
"Number of subclusters found (%d) by Birch is less "
503-
"than (%d). Decrease the threshold."
504-
% (len(centroids), self.n_clusters))
505-
return self
506-
507-
# The global clustering step that clusters the subclusters of
508-
# the leaves. It assumes the centroids of the subclusters as
509-
# samples and finds the final centroids.
510-
self.subcluster_labels_ = global_cluster.fit_predict(
511-
self.subcluster_centers_)
512-
self.labels_ = self.predict(X)
513484

485+
self._global_clustering(X)
514486
return self
515487

516488
def get_leaves(self):
@@ -535,7 +507,7 @@ def _check_fit(self, X):
535507
if X.shape[1] != self.subcluster_centers_.shape[1]:
536508
raise ValueError(
537509
"Training data and predicted data do "
538-
"not have same features.")
510+
"not have same no. of features.")
539511

540512
def predict(self, X):
541513
"""
@@ -560,17 +532,23 @@ def predict(self, X):
560532
reduced_distance += self._subcluster_norms
561533
return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
562534

563-
def partial_fit(self, X, y=None):
535+
def partial_fit(self, X=None, y=None):
564536
"""
565537
Online learning. Prevents rebuilding of CFTree from scratch.
566538
567539
Parameters
568540
----------
569-
X : {array-like, sparse matrix}, shape (n_samples, n_features)
570-
Input data.
541+
X : {array-like, sparse matrix}, shape (n_samples, n_features), None
542+
Input data. If X is not provided, only the global clustering
543+
step is done.
571544
"""
572545
self.partial_fit_ = True
573-
return self.fit(X)
546+
if X is None:
547+
# Perform just the final global clustering step.
548+
self._global_clustering()
549+
return self
550+
else:
551+
return self.fit(X)
574552

575553
def transform(self, X, y=None):
576554
"""
@@ -591,3 +569,47 @@ def transform(self, X, y=None):
591569
if not hasattr(self, 'subcluster_centers_'):
592570
raise ValueError("Fit training data before predicting")
593571
return euclidean_distances(X, self.subcluster_centers_)
572+
573+
def _global_clustering(self, X=None):
574+
"""
575+
Global clustering for the subclusters obtained after fitting
576+
"""
577+
clusters = self.n_clusters
578+
centroids = self.subcluster_centers_
579+
compute_labels = (X is not None) and self.compute_labels
580+
581+
# Preprocessing for the global clustering.
582+
not_enough_centroids = False
583+
if hasattr(clusters, 'fit_predict'):
584+
global_cluster = clusters
585+
elif isinstance(clusters, int):
586+
global_cluster = AgglomerativeClustering(
587+
n_clusters=clusters)
588+
# There is no need to perform the global clustering step.
589+
if len(centroids) < clusters:
590+
not_enough_centroids = True
591+
elif clusters is not None:
592+
raise ValueError("n_clusters should be an instance of "
593+
"ClusterMixin or an int")
594+
595+
# To use in predict to avoid recalculation.
596+
if compute_labels:
597+
self._subcluster_norms = row_norms(
598+
self.subcluster_centers_, squared=True)
599+
600+
if self.n_clusters is None or not_enough_centroids:
601+
self.subcluster_labels_ = np.arange(len(centroids))
602+
if not_enough_centroids:
603+
warnings.warn(
604+
"Number of subclusters found (%d) by Birch is less "
605+
"than (%d). Decrease the threshold."
606+
% (len(centroids), self.n_clusters))
607+
else:
608+
# The global clustering step that clusters the subclusters of
609+
# the leaves. It assumes the centroids of the subclusters as
610+
# samples and finds the final centroids.
611+
self.subcluster_labels_ = global_cluster.fit_predict(
612+
self.subcluster_centers_)
613+
614+
if compute_labels:
615+
self.labels_ = self.predict(X)

sklearn/cluster/tests/test_birch.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,22 @@ def test_n_samples_leaves_roots():
3333

3434

3535
def test_partial_fit():
36-
"""Test that fit is equivalent to calling partial_fit multiple times"""
36+
"""Test that fit is equivalent to calling partial_fit multiple times"""
37+
# Test that same subcluster centres are obtained after calling partial
38+
# fit twice
3739
X, y = make_blobs(n_samples=100)
3840
brc = Birch(n_clusters=3)
3941
brc.fit(X)
40-
brc_partial = Birch()
42+
brc_partial = Birch(n_clusters=None)
4143
brc_partial.partial_fit(X[:50])
4244
brc_partial.partial_fit(X[50:])
4345
assert_array_equal(brc_partial.subcluster_centers_, brc.subcluster_centers_)
44-
assert_equal(len(np.unique(brc.labels_)), 3)
46+
47+
# Test that same global labels are obtained after calling partial_fit
48+
# with None
49+
brc_partial.set_params(n_clusters=3)
50+
brc_partial.partial_fit(None)
51+
assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_)
4552

4653

4754
def test_birch_predict():

0 commit comments

Comments
 (0)