@@ -326,20 +326,21 @@ class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
326
326
subcluster is started.
327
327
328
328
branching_factor : int, default 50
329
- Maximun number of CF subclusters in each node. If a new samples enters
329
+ Maximum number of CF subclusters in each node. If a new sample enters
330
330
such that the number of subclusters exceed the branching_factor then
331
331
the node has to be split. The corresponding parent also has to be
332
332
split and if the number of subclusters in the parent is greater than
333
333
the branching factor, then it has to be split recursively.
334
334
335
- n_clusters : int, instance of sklearn.cluster model or None , default 3
336
- Number of clusters after the final clustring step, which treats the
335
+ n_clusters : int, instance of sklearn.cluster model or None, default 3
336
+ Number of clusters after the final clustering step, which treats the
337
337
subclusters from the leaves as new samples. By default the global
338
338
clustering step is AgglomerativeClustering with n_clusters set to 3.
339
- If set to None , this final clustering step is not performed and the
339
+ If set to None, this final clustering step is not performed and the
340
340
subclusters are returned as they are.
341
- It is advised to set n_clusters=None if ``partial_fit`` is used, to
342
- avoid this final clustering step for every call to ``partial_fit``.
341
+
342
+ compute_labels : bool, default True
343
+ Whether or not to compute labels for each fit.
343
344
344
345
Attributes
345
346
----------
@@ -357,16 +358,19 @@ class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
357
358
they are clustered globally.
358
359
359
360
labels_ : ndarray, shape (n_samples,)
360
- Array of labels assigned to the input data or the last
361
- batch of input data if partial_fit was used.
361
+ Array of labels assigned to the input data.
362
+ If partial_fit is used instead of fit, they are assigned to the
363
+ last batch of data.
362
364
363
365
Examples
364
366
--------
365
367
>>> from sklearn.cluster import Birch
366
368
>>> X = [[0, 1], [0.3, 1], [-0.3, 1], [0, -1], [0.3, -1], [-0.3, -1]]
367
- >>> brc = Birch(threshold=0.5, n_clusters=None)
369
+ >>> brc = Birch(branching_factor=50, n_clusters=None, threshold=0.5,
370
+ ... compute_labels=True)
368
371
>>> brc.fit(X)
369
- Birch(branching_factor=50, n_clusters=None, threshold=0.5)
372
+ Birch(branching_factor=50, compute_labels=True, n_clusters=None,
373
+ threshold=0.5)
370
374
>>> brc.predict(X)
371
375
array([0, 0, 0, 1, 1, 1])
372
376
@@ -381,10 +385,12 @@ class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
381
385
https://code.google.com/p/jbirch/
382
386
"""
383
387
384
- def __init__ (self , threshold = 0.5 , branching_factor = 50 , n_clusters = 3 ):
388
+ def __init__ (self , threshold = 0.5 , branching_factor = 50 , n_clusters = 3 ,
389
+ compute_labels = True ):
385
390
self .threshold = threshold
386
391
self .branching_factor = branching_factor
387
392
self .n_clusters = n_clusters
393
+ self .compute_labels = compute_labels
388
394
self .partial_fit_ = False
389
395
self .root_ = None
390
396
@@ -474,43 +480,9 @@ def fit(self, X, y=None):
474
480
475
481
centroids = np .concatenate ([
476
482
leaf .centroids_ for leaf in self .get_leaves ()])
477
-
478
- # Preprocessing for the global clustering.
479
- not_enough_centroids = False
480
- if hasattr (self .n_clusters , 'fit_predict' ):
481
- global_cluster = self .n_clusters
482
- elif isinstance (self .n_clusters , int ):
483
- global_cluster = AgglomerativeClustering (
484
- n_clusters = self .n_clusters )
485
- # There is no need to perform the global clustering step.
486
- if len (centroids ) < self .n_clusters :
487
- not_enough_centroids = True
488
- elif self .n_clusters is not None :
489
- raise ValueError ("n_clusters should be an instance of "
490
- "ClusterMixin or an int" )
491
-
492
- # To use in predict to avoid recalculation.
493
483
self .subcluster_centers_ = centroids
494
- self ._subcluster_norms = row_norms (
495
- self .subcluster_centers_ , squared = True )
496
-
497
- if self .n_clusters is None or not_enough_centroids :
498
- self .subcluster_labels_ = np .arange (len (centroids ))
499
- self .labels_ = self .predict (X )
500
- if not_enough_centroids :
501
- warnings .warn (
502
- "Number of subclusters found (%d) by Birch is less "
503
- "than (%d). Decrease the threshold."
504
- % (len (centroids ), self .n_clusters ))
505
- return self
506
-
507
- # The global clustering step that clusters the subclusters of
508
- # the leaves. It assumes the centroids of the subclusters as
509
- # samples and finds the final centroids.
510
- self .subcluster_labels_ = global_cluster .fit_predict (
511
- self .subcluster_centers_ )
512
- self .labels_ = self .predict (X )
513
484
485
+ self ._global_clustering (X )
514
486
return self
515
487
516
488
def get_leaves (self ):
@@ -535,7 +507,7 @@ def _check_fit(self, X):
535
507
if X .shape [1 ] != self .subcluster_centers_ .shape [1 ]:
536
508
raise ValueError (
537
509
"Training data and predicted data do "
538
- "not have same features." )
510
+ "not have same no. of features." )
539
511
540
512
def predict (self , X ):
541
513
"""
@@ -560,17 +532,23 @@ def predict(self, X):
560
532
reduced_distance += self ._subcluster_norms
561
533
return self .subcluster_labels_ [np .argmin (reduced_distance , axis = 1 )]
562
534
563
- def partial_fit (self , X , y = None ):
535
+ def partial_fit (self , X = None , y = None ):
564
536
"""
565
537
Online learning. Prevents rebuilding of CFTree from scratch.
566
538
567
539
Parameters
568
540
----------
569
- X : {array-like, sparse matrix}, shape (n_samples, n_features)
570
- Input data.
541
+ X : {array-like, sparse matrix}, shape (n_samples, n_features), None
542
+ Input data. If X is not provided, only the global clustering
543
+ step is done.
571
544
"""
572
545
self .partial_fit_ = True
573
- return self .fit (X )
546
+ if X is None :
547
+ # Perform just the final global clustering step.
548
+ self ._global_clustering ()
549
+ return self
550
+ else :
551
+ return self .fit (X )
574
552
575
553
def transform (self , X , y = None ):
576
554
"""
@@ -591,3 +569,47 @@ def transform(self, X, y=None):
591
569
if not hasattr (self , 'subcluster_centers_' ):
592
570
raise ValueError ("Fit training data before predicting" )
593
571
return euclidean_distances (X , self .subcluster_centers_ )
572
+
573
+ def _global_clustering (self , X = None ):
574
+ """
575
+ Global clustering for the subclusters obtained after fitting
576
+ """
577
+ clusters = self .n_clusters
578
+ centroids = self .subcluster_centers_
579
+ compute_labels = (X is not None ) and self .compute_labels
580
+
581
+ # Preprocessing for the global clustering.
582
+ not_enough_centroids = False
583
+ if hasattr (clusters , 'fit_predict' ):
584
+ global_cluster = clusters
585
+ elif isinstance (clusters , int ):
586
+ global_cluster = AgglomerativeClustering (
587
+ n_clusters = clusters )
588
+ # There is no need to perform the global clustering step.
589
+ if len (centroids ) < clusters :
590
+ not_enough_centroids = True
591
+ elif clusters is not None :
592
+ raise ValueError ("n_clusters should be an instance of "
593
+ "ClusterMixin or an int" )
594
+
595
+ # To use in predict to avoid recalculation.
596
+ if compute_labels :
597
+ self ._subcluster_norms = row_norms (
598
+ self .subcluster_centers_ , squared = True )
599
+
600
+ if self .n_clusters is None or not_enough_centroids :
601
+ self .subcluster_labels_ = np .arange (len (centroids ))
602
+ if not_enough_centroids :
603
+ warnings .warn (
604
+ "Number of subclusters found (%d) by Birch is less "
605
+ "than (%d). Decrease the threshold."
606
+ % (len (centroids ), self .n_clusters ))
607
+ else :
608
+ # The global clustering step that clusters the subclusters of
609
+ # the leaves. It assumes the centroids of the subclusters as
610
+ # samples and finds the final centroids.
611
+ self .subcluster_labels_ = global_cluster .fit_predict (
612
+ self .subcluster_centers_ )
613
+
614
+ if compute_labels :
615
+ self .labels_ = self .predict (X )
0 commit comments