diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 2e174eb5044b2..72155587ea853 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -220,6 +220,11 @@ Changelog matrices in a variety of estimators and avoid an `EfficiencyWarning`. :pr:`23139` by `Tom Dupre la Tour`_. +- |Enhancement| :func:`neighbors.KNeighborsClassifier.predict` is up to + three times faster by leveraging `scipy.sparse.csr_matrix` format + for mode calculation via `csr_matrix.argmax`. + :pr:`23721` by :user:`Meekail Zain ` + :mod:`sklearn.svm` .................. diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 68f38a5c3cfbc..faf82d2d69a88 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -9,7 +9,7 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np -from scipy import stats +from scipy import sparse from ..utils.extmath import weighted_mode from ..utils.validation import _is_arraylike, _num_samples @@ -20,6 +20,25 @@ from ..utils._param_validation import StrOptions +def _sparse_class_counts(neighbors_class_indices): + """Sparse class count encoding of neighbors classes + + Convert a dense numpy array of class integer indices for the results of + neighbors queries into a sparse CSR matrix with class counts. + + The sparse.csr_matrix constructor automatically sums repeated count + values in case a given query has several neighbors of the same class. + """ + n_queries, n_neighbors = neighbors_class_indices.shape + data = np.ones(shape=n_queries * n_neighbors, dtype=np.uint32) + indices = neighbors_class_indices.ravel() + indptr = np.arange(n_queries + 1) * n_neighbors + return sparse.csr_matrix( + (data, indices, indptr), + shape=(n_queries, neighbors_class_indices.max() + 1), + ) + + class KNeighborsClassifier(KNeighborsMixin, ClassifierMixin, NeighborsBase): """Classifier implementing the k-nearest neighbors vote. @@ -241,7 +260,13 @@ def predict(self, X): y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - mode, _ = stats.mode(_y[neigh_ind, k], axis=1) + # Compute stats.mode(_y[neigh_ind, k], axis=1) more efficiently + # by using the argmax of a sparse (CSR) count representation. + # _y[neigh_ind, k] has shape (n_queries, n_neighbors) with + # integer values representing class indices. The sparse count + # representation has shape (n_queries, n_classes_k) with integer + # count values. + mode = _sparse_class_counts(_y[neigh_ind, k]).argmax(axis=1) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1)