Skip to content
Merged
8 changes: 5 additions & 3 deletions sklearn/cluster/_affinity_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

from ..exceptions import ConvergenceWarning
from ..base import BaseEstimator, ClusterMixin
from ..utils import as_float_array, check_array, check_random_state
from ..utils import as_float_array, check_random_state
from ..utils.deprecation import deprecated
from ..utils.validation import check_is_fitted, _deprecate_positional_args
from ..metrics import euclidean_distances
from ..metrics import pairwise_distances_argmin
from .._config import config_context


def _equal_similarities_and_preferences(S, preference):
Expand Down Expand Up @@ -446,13 +447,14 @@ def predict(self, X):
Cluster labels.
"""
check_is_fitted(self)
X = check_array(X)
X = self._validate_data(X, reset=False)
if not hasattr(self, "cluster_centers_"):
raise ValueError("Predict method is not supported when "
"affinity='precomputed'.")

if self.cluster_centers_.shape[0] > 0:
return pairwise_distances_argmin(X, self.cluster_centers_)
with config_context(assume_finite=True):
return pairwise_distances_argmin(X, self.cluster_centers_)
Comment on lines +456 to +457
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a quick benchmark:

from sklearn.cluster import AffinityPropagation
from sklearn.datasets import make_classification

X, _ = make_classification(n_features=10_000, n_samples=5_000, random_state=42)
aff_prop = AffinityPropagation(random_state=42)
aff_prop.fit(X)

# this PR
%timeit aff_prop.predict(X)
# 182 ms ± 2.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# master
%timeit aff_prop.predict(X)
# 254 ms ± 5.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

else:
warnings.warn("This model does not have any cluster centers "
"because affinity propagation did not converge. "
Expand Down
20 changes: 11 additions & 9 deletions sklearn/cluster/_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from ..metrics import pairwise_distances_argmin
from ..metrics.pairwise import euclidean_distances
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
from ..utils import check_array
from ..utils.extmath import row_norms
from ..utils.validation import check_is_fitted, _deprecate_positional_args
from ..exceptions import ConvergenceWarning
from . import AgglomerativeClustering
from .._config import config_context


def _iterate_sparse_X(X):
Expand Down Expand Up @@ -585,14 +585,14 @@ def predict(self, X):
labels : ndarray of shape(n_samples,)
Labelled data.
"""
X = check_array(X, accept_sparse='csr')
self._check_fit(X)
check_is_fitted(self)
X = self._validate_data(X, accept_sparse='csr', reset=False)
kwargs = {'Y_norm_squared': self._subcluster_norms}
return self.subcluster_labels_[
pairwise_distances_argmin(X,
self.subcluster_centers_,
metric_kwargs=kwargs)
]

with config_context(assume_finite=True):
argmin = pairwise_distances_argmin(X, self.subcluster_centers_,
metric_kwargs=kwargs)
return self.subcluster_labels_[argmin]

def transform(self, X):
"""
Expand All @@ -612,7 +612,9 @@ def transform(self, X):
Transformed data.
"""
check_is_fitted(self)
return euclidean_distances(X, self.subcluster_centers_)
self._validate_data(X, accept_sparse='csr', reset=False)
with config_context(assume_finite=True):
return euclidean_distances(X, self.subcluster_centers_)

def _global_clustering(self, X=None):
"""
Expand Down
6 changes: 1 addition & 5 deletions sklearn/cluster/_feature_agglomeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np

from ..base import TransformerMixin
from ..utils import check_array
from ..utils.validation import check_is_fitted
from scipy.sparse import issparse

Expand Down Expand Up @@ -38,10 +37,7 @@ def transform(self, X):
"""
check_is_fitted(self)

X = check_array(X)
if len(self.labels_) != X.shape[1]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're assuming that the invariance of len(self.labels_) == X_train.shape[1] is enforced elsewhere, right? I guess this was only a workaround for not having n_features_in_ anyway.

raise ValueError("X has a different number of features than "
"during fitting.")
X = self._validate_data(X, reset=False)
if self.pooling_func == np.mean and not issparse(X):
size = np.bincount(self.labels_)
n_samples = X.shape[0]
Expand Down
12 changes: 3 additions & 9 deletions sklearn/cluster/_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,15 +854,9 @@ def _validate_center_shape(self, X, centers):
f"match the number of features of the data {X.shape[1]}.")

def _check_test_data(self, X):
X = check_array(X, accept_sparse='csr', dtype=[np.float64, np.float32],
order='C', accept_large_sparse=False)
n_samples, n_features = X.shape
expected_n_features = self.cluster_centers_.shape[1]
if not n_features == expected_n_features:
raise ValueError(
f"Incorrect number of features. Got {n_features} features, "
f"expected {expected_n_features}.")

X = self._validate_data(X, accept_sparse='csr', reset=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we still need the function now? but fine with me.

dtype=[np.float64, np.float32],
order='C', accept_large_sparse=False)
return X

def _check_mkl_vcomp(self, X, n_samples):
Expand Down
6 changes: 4 additions & 2 deletions sklearn/cluster/_mean_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ..base import BaseEstimator, ClusterMixin
from ..neighbors import NearestNeighbors
from ..metrics.pairwise import pairwise_distances_argmin
from .._config import config_context


@_deprecate_positional_args
Expand Down Expand Up @@ -462,5 +463,6 @@ def predict(self, X):
Index of the cluster each sample belongs to.
"""
check_is_fitted(self)

return pairwise_distances_argmin(X, self.cluster_centers_)
X = self._validate_data(X, reset=False)
with config_context(assume_finite=True):
return pairwise_distances_argmin(X, self.cluster_centers_)
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ def test_search_cv(estimator, check, request):
# check_classifiers_train would need to be updated with the error message
N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = {
'calibration',
'cluster',
'compose',
'covariance',
'cross_decomposition',
Expand Down