@@ -470,7 +470,7 @@ def fit(self, X, y=None):
470
470
leaf .centroids_ for leaf in self ._get_leaves ()])
471
471
self .subcluster_centers_ = centroids
472
472
473
- self .global_clustering (X )
473
+ self ._global_clustering (X )
474
474
return self
475
475
476
476
def _get_leaves (self ):
@@ -490,9 +490,14 @@ def _get_leaves(self):
490
490
return leaves
491
491
492
492
def _check_fit (self , X ):
493
- if not hasattr (self , 'subcluster_centers_' ):
493
+ is_fitted = hasattr (self , 'subcluster_centers_' )
494
+ has_partial_fit = hasattr (self , 'partial_fit' )
495
+
496
+ # Should raise an error if one does not fit before predicting.
497
+ if not has_partial_fit and not is_fitted :
494
498
raise ValueError ("Fit training data before predicting" )
495
- if X .shape [1 ] != self .subcluster_centers_ .shape [1 ]:
499
+
500
+ if is_fitted and X .shape [1 ] != self .subcluster_centers_ .shape [1 ]:
496
501
raise ValueError (
497
502
"Training data and predicted data do "
498
503
"not have same no. of features." )
@@ -532,9 +537,10 @@ def partial_fit(self, X=None, y=None):
532
537
"""
533
538
if X is None :
534
539
# Perform just the final global clustering step.
535
- self .global_clustering ()
540
+ self ._global_clustering ()
536
541
return self
537
542
else :
543
+ self ._check_fit (X )
538
544
return self .fit (X )
539
545
540
546
def transform (self , X , y = None ):
@@ -557,7 +563,7 @@ def transform(self, X, y=None):
557
563
raise ValueError ("Fit training data before predicting" )
558
564
return euclidean_distances (X , self .subcluster_centers_ )
559
565
560
- def global_clustering (self , X = None ):
566
+ def _global_clustering (self , X = None ):
561
567
"""
562
568
Global clustering for the subclusters obtained after fitting
563
569
"""
0 commit comments