Skip to content

Commit 13ab4d5

Browse files
committed
Explicit error when n_dim changes during partial_fit
1 parent 815844d commit 13ab4d5

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

sklearn/cluster/birch.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def fit(self, X, y=None):
470470
leaf.centroids_ for leaf in self._get_leaves()])
471471
self.subcluster_centers_ = centroids
472472

473-
self.global_clustering(X)
473+
self._global_clustering(X)
474474
return self
475475

476476
def _get_leaves(self):
@@ -490,9 +490,14 @@ def _get_leaves(self):
490490
return leaves
491491

492492
def _check_fit(self, X):
493-
if not hasattr(self, 'subcluster_centers_'):
493+
is_fitted = hasattr(self, 'subcluster_centers_')
494+
has_partial_fit = hasattr(self, 'partial_fit')
495+
496+
# Should raise an error if one does not fit before predicting.
497+
if not has_partial_fit and not is_fitted:
494498
raise ValueError("Fit training data before predicting")
495-
if X.shape[1] != self.subcluster_centers_.shape[1]:
499+
500+
if is_fitted and X.shape[1] != self.subcluster_centers_.shape[1]:
496501
raise ValueError(
497502
"Training data and predicted data do "
498503
"not have same no. of features.")
@@ -532,9 +537,10 @@ def partial_fit(self, X=None, y=None):
532537
"""
533538
if X is None:
534539
# Perform just the final global clustering step.
535-
self.global_clustering()
540+
self._global_clustering()
536541
return self
537542
else:
543+
self._check_fit(X)
538544
return self.fit(X)
539545

540546
def transform(self, X, y=None):
@@ -557,7 +563,7 @@ def transform(self, X, y=None):
557563
raise ValueError("Fit training data before predicting")
558564
return euclidean_distances(X, self.subcluster_centers_)
559565

560-
def global_clustering(self, X=None):
566+
def _global_clustering(self, X=None):
561567
"""
562568
Global clustering for the subclusters obtained after fitting
563569
"""

0 commit comments

Comments
 (0)