From 62695fd0abad7f6bc52ea6b0992a3f588a864726 Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 1 Aug 2019 17:18:20 +0200 Subject: [PATCH 01/14] Initial version of the faster implementation of scipy.stats.mode --- sklearn/neighbors/classification.py | 4 +-- sklearn/utils/extmath.py | 39 +++++++++++++++++++++++++++++ sklearn/utils/tests/test_extmath.py | 30 ++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index f0fd0b084365a..53a472086198b 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -10,7 +10,7 @@ import numpy as np from scipy import stats -from ..utils.extmath import weighted_mode +from ..utils.extmath import weighted_mode, _fast_mode from .base import \ _check_weights, _get_weights, \ @@ -180,7 +180,7 @@ def predict(self, X): y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - mode, _ = stats.mode(_y[neigh_ind, k], axis=1) + mode = _fast_mode(_y[neigh_ind, k], axis=1) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index fcb03b0cecddd..13b467aef2a8c 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -357,6 +357,45 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto', return U[:, :n_components], s[:n_components], V[:n_components, :] +def _fast_mode(x, axis=1): + """Returns a faster equivalent for scipy.mode + + This is only implemented for positive integer data. + + Parameters + ---------- + x : array_like, shape (n_samples, n_components) + n-dimensional array of which to find mode(s). + axis : int, optional + Axis along which to operate. Default is 1. + Only axis=1 is supported. + + Returns + ------- + mode: ndarray, shape=(n_samples) + index of the mode + + Examples + -------- + >>> x = np.array([[0, 1, 1], [2, 0, 2]]) + >>> _fast_mode(x, axis=1) + array([1, 2]) + """ + if not hasattr(x, "__array__") or x.dtype.kind != 'i' or x.ndim != 2: + raise ValueError('_fast_mode is only implemented for 2D integer ' + 'arrays!') + data = np.ones(x.shape, dtype=np.int).ravel() + indices = x.ravel() + indptr = np.arange(x.shape[0]+1)*x.shape[1] + # we use the fact that data for repeated indices is summed when + # creating sparse arrays. 
The index with highest value is then the mode + if axis != 1: + raise ValueError('Only axis=1 is supported.') + z = sparse.csr_matrix((data, indices, indptr), + shape=(x.shape[0], x.max() + 1)) + return np.asarray(np.argmax(z, axis=1)) + + def weighted_mode(a, w, axis=0): """Returns an array of the weighted modal (most common) value in a diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 2da6e5f5e9943..3ebe02c9b8bc4 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -31,6 +31,7 @@ from sklearn.utils.extmath import _deterministic_vector_sign_flip from sklearn.utils.extmath import softmax from sklearn.utils.extmath import stable_cumsum +from sklearn.utils.extmath import _fast_mode from sklearn.datasets.samples_generator import make_low_rank_matrix @@ -640,3 +641,32 @@ def test_stable_cumsum(): assert_array_equal(stable_cumsum(A, axis=0), np.cumsum(A, axis=0)) assert_array_equal(stable_cumsum(A, axis=1), np.cumsum(A, axis=1)) assert_array_equal(stable_cumsum(A, axis=2), np.cumsum(A, axis=2)) + + +class TestFastMode(): + + def test_scipy_stats_axis_1(self): + rng = np.random.RandomState(0) + + X = rng.randint(10, size=(100, 20)) + mode_ref, _ = stats.mode(X, axis=1) + mode = _fast_mode(X, axis=1) + assert_array_equal(mode, mode_ref) + + @pytest.mark.parametrize( + 'x', + [np.ones((10, 10), dtype=np.float), 1, np.ones(5, dtype=np.int)], + ids=['array_float64', 'int', '1D-array']) + def test_input_validation(self, x): + with pytest.raises(ValueError, + match='only implemented for 2D integer arrays'): + _fast_mode(x) + + def test_ties(self): + # Check that ties are resolved in the same way as in stats.mode + X = np.ones((6, 9), dtype=np.int) + X[:, 3:] = 2 + X[:, 6:] = 3 + mode_ref, _ = stats.mode(X, axis=1) + mode = _fast_mode(X, axis=1) + assert_array_equal(mode, mode_ref) From 6bd2a08568d01f6853374f696365fd08ae7d700c Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 1 Aug 2019 17:59:32 +0200 Subject: [PATCH 02/14] Also support weights --- sklearn/neighbors/classification.py | 6 +---- sklearn/utils/extmath.py | 37 +++++++++++++++++++++++++---- sklearn/utils/tests/test_extmath.py | 34 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 9 deletions(-) diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index 53a472086198b..990b08e08bd1d 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -179,11 +179,7 @@ def predict(self, X): y_pred = np.empty((n_samples, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): - if weights is None: - mode = _fast_mode(_y[neigh_ind, k], axis=1) - else: - mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) - + mode = _fast_mode(_y[neigh_ind, k], weights=weights, axis=1) mode = np.asarray(mode.ravel(), dtype=np.intp) y_pred[:, k] = classes_k.take(mode) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 13b467aef2a8c..184454c44382a 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -357,7 +357,7 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter='auto', return U[:, :n_components], s[:n_components], V[:n_components, :] -def _fast_mode(x, axis=1): +def _fast_mode(x, weights=None, axis=1): """Returns a faster equivalent for scipy.mode This is only implemented for positive integer data. 
@@ -366,6 +366,8 @@ def _fast_mode(x, axis=1): ---------- x : array_like, shape (n_samples, n_components) n-dimensional array of which to find mode(s). + w : array_like, shape (n_samples, n_components) + n-dimensional array of weights for each value axis : int, optional Axis along which to operate. Default is 1. Only axis=1 is supported. @@ -379,13 +381,40 @@ def _fast_mode(x, axis=1): -------- >>> x = np.array([[0, 1, 1], [2, 0, 2]]) >>> _fast_mode(x, axis=1) - array([1, 2]) + array([[1], [2]]) + + Next we illustrate weighted mode calculations + + >>> x = np.array([[4, 1, 4, 2, 4, 2]]) + >>> weights = np.array([[1, 1, 1, 1, 1, 1]]) + >>> _fast_mode(x, weights) + array([[4]]) + + The value 4 appears three times: with uniform weights, the result is + simply the mode of the distribution. + + >>> weights = np.array([[1, 3, 0.5, 1.5, 1, 2]]) # deweight the 4's + >>> _fast_mode(x, weights) + array([[2]]) + + The value 2 has the highest score: it appears twice with weights of + 1.5 and 2: the sum of these is 3.5. + """ if not hasattr(x, "__array__") or x.dtype.kind != 'i' or x.ndim != 2: raise ValueError('_fast_mode is only implemented for 2D integer ' 'arrays!') - data = np.ones(x.shape, dtype=np.int).ravel() - indices = x.ravel() + if x.min() < 0: + raise ValueError('only positive data is supported.') + + if weights is None: + data = np.ones(x.shape, dtype=np.int).ravel() + else: + if x.shape != weights.shape: + raise ValueError("x.shape {} != weights.shape {}" + .format(x.shape, weights.shape)) + data = np.ascontiguousarray(weights).ravel() + indices = np.ascontiguousarray(x).ravel() indptr = np.arange(x.shape[0]+1)*x.shape[1] # we use the fact that data for repeated indices is summed when # creating sparse arrays. The index with highest value is then the mode diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index 3ebe02c9b8bc4..90de0933ddcaf 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -662,6 +662,12 @@ def test_input_validation(self, x): match='only implemented for 2D integer arrays'): _fast_mode(x) + def test_negative_values(self): + x = - np.ones((10, 10), dtype=np.int) + with pytest.raises(ValueError, + match="only positive data is supported"): + _fast_mode(x) + def test_ties(self): # Check that ties are resolved in the same way as in stats.mode X = np.ones((6, 9), dtype=np.int) @@ -670,3 +676,31 @@ def test_ties(self): mode_ref, _ = stats.mode(X, axis=1) mode = _fast_mode(X, axis=1) assert_array_equal(mode, mode_ref) + + def test_uniform_weights(self): + # with uniform weights, results should be identical to + # stats.mode + rng = np.random.RandomState(0) + x = rng.randint(10, size=(10, 5)) + weights = np.ones(x.shape) + + mode, _ = stats.mode(x, axis=1) + mode2 = _fast_mode(x, weights, axis=1) + + assert_array_equal(mode, mode2) + + def test_random_weights(self): + # set this up so that each row should have a weighted mode of 6, + # with a score that is easily reproduced + mode_result = 6 + + rng = np.random.RandomState(0) + x = rng.randint(mode_result, size=(100, 10)) + w = rng.random_sample(x.shape) + + x[:, :5] = mode_result + w[:, :5] += 1 + + mode = _fast_mode(x, w, axis=1) + + assert_array_equal(mode, mode_result) From c7d3286f8dc6a1b3234bed2d972458e9e2d4a99a Mon Sep 17 00:00:00 2001 From: Roman Yurchak Date: Thu, 1 Aug 2019 18:17:05 +0200 Subject: [PATCH 03/14] Fixes to RadiusNeigboursClassifier --- sklearn/neighbors/classification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 
deletions(-) diff --git a/sklearn/neighbors/classification.py b/sklearn/neighbors/classification.py index 990b08e08bd1d..8f66ae57cba9a 100644 --- a/sklearn/neighbors/classification.py +++ b/sklearn/neighbors/classification.py @@ -9,8 +9,7 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np -from scipy import stats -from ..utils.extmath import weighted_mode, _fast_mode +from ..utils.extmath import _fast_mode from .base import \ _check_weights, _get_weights, \ @@ -416,11 +415,12 @@ def predict(self, X): pred_labels = np.zeros(len(neigh_ind), dtype=object) pred_labels[:] = [_y[ind, k] for ind in neigh_ind] if weights is None: - mode = np.array([stats.mode(pl)[0] + mode = np.array([_fast_mode(np.atleast_2d(pl)).ravel() for pl in pred_labels[inliers]], dtype=np.int) else: mode = np.array( - [weighted_mode(pl, w)[0] + [_fast_mode(np.atleast_2d(pl), np.atleast_2d(w), + axis=1).ravel() for (pl, w) in zip(pred_labels[inliers], weights[inliers]) ], dtype=np.int) From c391beed0a5dc868b8eb7346bbf776c01297eded Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Jun 2022 16:54:16 -0400 Subject: [PATCH 04/14] Simplified implementation and fixed critical bugs --- sklearn/neighbors/_classification.py | 18 ++++++-- sklearn/utils/extmath.py | 67 ---------------------------- sklearn/utils/tests/test_extmath.py | 67 +--------------------------- 3 files changed, 16 insertions(+), 136 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 7a86199fabd47..e7d4d884b6f5a 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -9,8 +9,8 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np -from ..utils.extmath import _fast_mode from ..utils.validation import _is_arraylike, _num_samples +from scipy import sparse import warnings from ._base import _check_weights, _get_weights @@ -198,6 +198,15 @@ def fit(self, X, y): return self._fit(X, y) + def _build_sparse_matrix(self, neighbor_labels, weights): + data = weights.ravel() + indices = neighbor_labels.ravel() + indptr = np.arange(neighbor_labels.shape[0] + 1) * self.n_neighbors + return sparse.csr_matrix( + (data, indices, indptr), + shape=(neighbor_labels.shape[0], neighbor_labels.max() + 1), + ) + def predict(self, X): """Predict the class labels for the provided data. @@ -228,11 +237,14 @@ def predict(self, X): n_outputs = len(classes_) n_queries = _num_samples(X) - weights = _get_weights(neigh_dist, self.weights) y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): - mode = _fast_mode(_y[neigh_ind, k], weights=weights, axis=1) + + weights = _get_weights(neigh_dist, self.weights) + if weights is None: + weights = np.ones(neigh_ind.shape, dtype=np.int) + mode = self._build_sparse_matrix(_y[neigh_ind, k], weights).argmax(axis=1) mode = np.asarray(mode.ravel(), dtype=np.intp) y_pred[:, k] = classes_k.take(mode) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 2c5a4e18081ec..e4513a62bf07e 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -434,73 +434,6 @@ def randomized_svd( return U[:, :n_components], s[:n_components], Vt[:n_components, :] -def _fast_mode(x, weights=None, *, axis=1): - """Returns a faster equivalent for scipy.mode - - This is only implemented for positive integer data. - - Parameters - ---------- - x : array_like, shape (n_samples, n_components) - n-dimensional array of which to find mode(s). 
- w : array_like, shape (n_samples, n_components) - n-dimensional array of weights for each value - axis : int, optional - Axis along which to operate. Default is 1. - Only axis=1 is supported. - - Returns - ------- - mode: ndarray, shape=(n_samples) - index of the mode - - Examples - -------- - >>> x = np.array([[0, 1, 1], [2, 0, 2]]) - >>> _fast_mode(x, axis=1) - array([[1], [2]]) - - Next we illustrate weighted mode calculations - - >>> x = np.array([[4, 1, 4, 2, 4, 2]]) - >>> weights = np.array([[1, 1, 1, 1, 1, 1]]) - >>> _fast_mode(x, weights) - array([[4]]) - - The value 4 appears three times: with uniform weights, the result is - simply the mode of the distribution. - - >>> weights = np.array([[1, 3, 0.5, 1.5, 1, 2]]) # deweight the 4's - >>> _fast_mode(x, weights) - array([[2]]) - - The value 2 has the highest score: it appears twice with weights of - 1.5 and 2: the sum of these is 3.5. - - """ - if not hasattr(x, "__array__") or x.dtype.kind != "i" or x.ndim != 2: - raise ValueError("_fast_mode is only implemented for 2D integer arrays!") - if x.min() < 0: - raise ValueError("only positive data is supported.") - - if weights is None: - data = np.ones(x.shape, dtype=np.int).ravel() - else: - if x.shape != weights.shape: - raise ValueError( - "x.shape {} != weights.shape {}".format(x.shape, weights.shape) - ) - data = np.ascontiguousarray(weights).ravel() - indices = np.ascontiguousarray(x).ravel() - indptr = np.arange(x.shape[0] + 1) * x.shape[1] - # we use the fact that data for repeated indices is summed when - # creating sparse arrays. The index with highest value is then the mode - if axis != 1: - raise ValueError("Only axis=1 is supported.") - z = sparse.csr_matrix((data, indices, indptr), shape=(x.shape[0], x.max() + 1)) - return np.asarray(np.argmax(z, axis=1)) - - def _randomized_eigsh( M, n_components, diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index b4e3f983c02b3..07a553c8cf09d 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -31,7 +31,7 @@ from sklearn.utils.extmath import _deterministic_vector_sign_flip from sklearn.utils.extmath import softmax from sklearn.utils.extmath import stable_cumsum -from sklearn.utils.extmath import safe_sparse_dot, _fast_mode +from sklearn.utils.extmath import safe_sparse_dot from sklearn.datasets import make_low_rank_matrix, make_sparse_spd_matrix @@ -927,71 +927,6 @@ def test_stable_cumsum(): assert_array_equal(stable_cumsum(A, axis=2), np.cumsum(A, axis=2)) -def test_fast_mode_scipy_stats_axis_1(): - rng = np.random.RandomState(0) - - X = rng.randint(10, size=(100, 20)) - mode_ref, _ = stats.mode(X, axis=1) - mode = _fast_mode(X, axis=1) - assert_array_equal(mode, mode_ref) - - -@pytest.mark.parametrize( - "x", - [np.ones((10, 10), dtype=np.float), 1, np.ones(5, dtype=np.int)], - ids=["array_float64", "int", "1D-array"], -) -def test_fast_mode_input_validation(x): - with pytest.raises(ValueError, match="only implemented for 2D integer arrays"): - _fast_mode(x) - - -def test_fast_mode_negative_values(): - x = -np.ones((10, 10), dtype=np.int) - with pytest.raises(ValueError, match="only positive data is supported"): - _fast_mode(x) - - -def test_fast_mode_ties(): - # Check that ties are resolved in the same way as in stats.mode - X = np.ones((6, 9), dtype=np.int) - X[:, 3:] = 2 - X[:, 6:] = 3 - mode_ref, _ = stats.mode(X, axis=1) - mode = _fast_mode(X, axis=1) - assert_array_equal(mode, mode_ref) - - -def test_fast_mode_uniform_weights(): - # with 
uniform weights, results should be identical to - # stats.mode - rng = np.random.RandomState(0) - x = rng.randint(10, size=(10, 5)) - weights = np.ones(x.shape) - - mode, _ = stats.mode(x, axis=1) - mode2 = _fast_mode(x, weights, axis=1) - - assert_array_equal(mode, mode2) - - -def test_fast_mode_random_weights(): - # set this up so that each row should have a weighted mode of 6, - # with a score that is easily reproduced - mode_result = 6 - - rng = np.random.RandomState(0) - x = rng.randint(mode_result, size=(100, 10)) - w = rng.random_sample(x.shape) - - x[:, :5] = mode_result - w[:, :5] += 1 - - mode = _fast_mode(x, w, axis=1) - - assert_array_equal(mode, mode_result) - - @pytest.mark.parametrize( "A_array_constr", [np.array, sparse.csr_matrix], ids=["dense", "sparse"] ) From 77d6346f6c0a45081972adf8732fb1117165badf Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Jun 2022 18:46:30 -0400 Subject: [PATCH 05/14] Cleaned up implementation --- sklearn/neighbors/_classification.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index e7d4d884b6f5a..51b81bcba8745 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -9,6 +9,7 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np +from ..utils.extmath import weighted_mode from ..utils.validation import _is_arraylike, _num_samples from scipy import sparse @@ -239,12 +240,15 @@ def predict(self, X): n_queries = _num_samples(X) y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) + weights = _get_weights(neigh_dist, self.weights) for k, classes_k in enumerate(classes_): - - weights = _get_weights(neigh_dist, self.weights) if weights is None: - weights = np.ones(neigh_ind.shape, dtype=np.int) - mode = self._build_sparse_matrix(_y[neigh_ind, k], weights).argmax(axis=1) + _weights = np.ones(neigh_ind.shape, dtype=np.int) + mode = self._build_sparse_matrix(_y[neigh_ind, k], _weights).argmax( + axis=1 + ) + else: + mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) mode = np.asarray(mode.ravel(), dtype=np.intp) y_pred[:, k] = classes_k.take(mode) From 9c7abb4a1d31a20a2cf813e79d5a647e04522782 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Jun 2022 18:53:40 -0400 Subject: [PATCH 06/14] Formatting --- sklearn/neighbors/_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 51b81bcba8745..bdbdf61df7784 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -249,6 +249,7 @@ def predict(self, X): ) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) + mode = np.asarray(mode.ravel(), dtype=np.intp) y_pred[:, k] = classes_k.take(mode) From f4ec893824c72b6bfcc0eb32b0e5c3bb6e8c43a2 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Jun 2022 18:54:16 -0400 Subject: [PATCH 07/14] Formatting --- sklearn/neighbors/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index bdbdf61df7784..f69cdea378539 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -238,9 +238,9 @@ def predict(self, X): n_outputs = len(classes_) n_queries = _num_samples(X) + weights = _get_weights(neigh_dist, self.weights) y_pred = np.empty((n_queries, n_outputs), 
dtype=classes_[0].dtype) - weights = _get_weights(neigh_dist, self.weights) for k, classes_k in enumerate(classes_): if weights is None: _weights = np.ones(neigh_ind.shape, dtype=np.int) From 4106805b67195601e49ed6dfe8a4efa400cb6edb Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Jun 2022 19:47:52 -0400 Subject: [PATCH 08/14] Updated dtype --- sklearn/neighbors/_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index f69cdea378539..264d4ed34d9e3 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -243,14 +243,14 @@ def predict(self, X): y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - _weights = np.ones(neigh_ind.shape, dtype=np.int) + _weights = np.ones(neigh_ind.shape, dtype=int) mode = self._build_sparse_matrix(_y[neigh_ind, k], _weights).argmax( axis=1 ) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) - mode = np.asarray(mode.ravel(), dtype=np.intp) + mode = np.asarray(mode.ravel(), dtype=int) y_pred[:, k] = classes_k.take(mode) if not self.outputs_2d_: From ce3ad2c10036e08e0100cc0196f9bbc7e0a896eb Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 21 Jun 2022 19:55:06 -0400 Subject: [PATCH 09/14] Added changelog entry --- doc/whats_new/v1.2.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index efda80e75e04a..3a5c6ab508abd 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -171,6 +171,10 @@ Changelog matrices in a variety of estimators and avoid an `EfficiencyWarning`. :pr:`23139` by `Tom Dupre la Tour`_. +- |Enhancement| :func:`neighbors.KNeighborsClassifier.predict` is now much faster + by leveraging `scipy.sparse.csr_matrix` format for mode calculation via + `csr_matrix.argmax`. :pr:`23721` by :user:`Meekail Zain ` + :mod:`sklearn.svm` .................. 
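
The counting trick that patches 04-09 above put in place is, at bottom, a per-row scatter-add: every neighbor contributes a unit count to its class column, repeated (row, class) entries are summed by the CSR machinery, and the per-row argmax of the counts is the mode. Below is a minimal standalone sketch of that idea for non-negative integer labels; the helper name sparse_mode is illustrative only (in the series itself the logic lives in _build_sparse_matrix and, in later patches, in a module-level _sparse_class_counts).

import numpy as np
from scipy import sparse

def sparse_mode(class_indices):
    # Row-wise mode of a 2D array of non-negative integer class indices.
    # Duplicate (row, class) entries in the CSR triplet are accumulated by
    # sum_duplicates(), so each row of `counts` holds per-class counts and
    # its argmax is the most common class in that row.
    n_rows, n_cols = class_indices.shape
    data = np.ones(n_rows * n_cols, dtype=np.int64)
    indices = class_indices.ravel()
    indptr = np.arange(n_rows + 1) * n_cols
    counts = sparse.csr_matrix(
        (data, indices, indptr), shape=(n_rows, class_indices.max() + 1)
    )
    counts.sum_duplicates()
    return np.asarray(counts.argmax(axis=1)).ravel()

rng = np.random.RandomState(0)
X = rng.randint(10, size=(100, 20))
# Dense reference: per-row bincount + argmax (same tie-breaking: lowest index).
reference = np.array([np.bincount(row).argmax() for row in X])
assert np.array_equal(sparse_mode(X), reference)

For a single row this is equivalent to np.bincount(row).argmax(); building the CSR counts handles all rows in one pass, which is what the changelog entry above credits for the speedup over scipy.stats.mode.
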
From 46c262304830b5945a2865bf8f58a3d746c009ee Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 22 Jun 2022 13:01:59 -0400 Subject: [PATCH 10/14] Updated dtype to intp --- sklearn/neighbors/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 264d4ed34d9e3..c74eefa4d1444 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -250,7 +250,7 @@ def predict(self, X): else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) - mode = np.asarray(mode.ravel(), dtype=int) + mode = np.asarray(mode.ravel(), dtype=np.intp) y_pred[:, k] = classes_k.take(mode) if not self.outputs_2d_: From e79136266a3439f25e3ddaf8e8e9ced7df2596d0 Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Thu, 30 Jun 2022 12:25:37 -0400 Subject: [PATCH 11/14] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/neighbors/_classification.py | 30 +++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 62601321aeee0..5c8eef0bb9062 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -206,13 +206,22 @@ def fit(self, X, y): return self._fit(X, y) - def _build_sparse_matrix(self, neighbor_labels, weights): - data = weights.ravel() - indices = neighbor_labels.ravel() - indptr = np.arange(neighbor_labels.shape[0] + 1) * self.n_neighbors + def _sparse_class_counts(self, neighbors_class_indices): + """Sparse class count encoding of neighbors classes + + Convert a dense numpy array of class integer indices for the results of + neighbors queries into a sparse CSR matrix with class counts. + + The sparse.csr_matrix constructor automatically sums repeated count + values in case a given query has several neighbors of the same class. + """ + n_queries, n_neighbors = neighbors_class_indices.shape + data = np.ones(shape=n_queries * n_neighbors, dtype=np.uint32) + indices = neighbors_class_indices.ravel() + indptr = np.arange(n_queries + 1) * n_neighbors return sparse.csr_matrix( (data, indices, indptr), - shape=(neighbor_labels.shape[0], neighbor_labels.max() + 1), + shape=(n_queries, neighbors_class_indices.max() + 1), ) def predict(self, X): @@ -250,10 +259,13 @@ def predict(self, X): y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - _weights = np.ones(neigh_ind.shape, dtype=int) - mode = self._build_sparse_matrix(_y[neigh_ind, k], _weights).argmax( - axis=1 - ) + # compute stats.mode(_y[neigh_ind, k], axis=1) more efficiently + # by using the argmax of a sparse (CSR) count representation. + # _y[neigh_ind, k] has shape (n_queries, n_neighbors) with + # integer values representing class indices. The sparse count + # representation has shape (n_queries, n_classes_k) with integer + # count values. 
+ mode = self._sparse_class_counts(_y[neigh_ind, k]).argmax(axis=1) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) From 12abe8f3bc6b7fa6f37cba51631d30f9877b9958 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 30 Jun 2022 12:32:01 -0400 Subject: [PATCH 12/14] Moved method-->function --- sklearn/neighbors/_classification.py | 41 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 5c8eef0bb9062..af4871e986582 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -20,6 +20,25 @@ from ..utils._param_validation import StrOptions +def _sparse_class_counts(neighbors_class_indices): + """Sparse class count encoding of neighbors classes + + Convert a dense numpy array of class integer indices for the results of + neighbors queries into a sparse CSR matrix with class counts. + + The sparse.csr_matrix constructor automatically sums repeated count + values in case a given query has several neighbors of the same class. + """ + n_queries, n_neighbors = neighbors_class_indices.shape + data = np.ones(shape=n_queries * n_neighbors, dtype=np.uint32) + indices = neighbors_class_indices.ravel() + indptr = np.arange(n_queries + 1) * n_neighbors + return sparse.csr_matrix( + (data, indices, indptr), + shape=(n_queries, neighbors_class_indices.max() + 1), + ) + + class KNeighborsClassifier(KNeighborsMixin, ClassifierMixin, NeighborsBase): """Classifier implementing the k-nearest neighbors vote. @@ -206,24 +225,6 @@ def fit(self, X, y): return self._fit(X, y) - def _sparse_class_counts(self, neighbors_class_indices): - """Sparse class count encoding of neighbors classes - - Convert a dense numpy array of class integer indices for the results of - neighbors queries into a sparse CSR matrix with class counts. - - The sparse.csr_matrix constructor automatically sums repeated count - values in case a given query has several neighbors of the same class. - """ - n_queries, n_neighbors = neighbors_class_indices.shape - data = np.ones(shape=n_queries * n_neighbors, dtype=np.uint32) - indices = neighbors_class_indices.ravel() - indptr = np.arange(n_queries + 1) * n_neighbors - return sparse.csr_matrix( - (data, indices, indptr), - shape=(n_queries, neighbors_class_indices.max() + 1), - ) - def predict(self, X): """Predict the class labels for the provided data. @@ -259,13 +260,13 @@ def predict(self, X): y_pred = np.empty((n_queries, n_outputs), dtype=classes_[0].dtype) for k, classes_k in enumerate(classes_): if weights is None: - # compute stats.mode(_y[neigh_ind, k], axis=1) more efficiently + # Compute stats.mode(_y[neigh_ind, k], axis=1) more efficiently # by using the argmax of a sparse (CSR) count representation. # _y[neigh_ind, k] has shape (n_queries, n_neighbors) with # integer values representing class indices. The sparse count # representation has shape (n_queries, n_classes_k) with integer # count values. 
- mode = self._sparse_class_counts(_y[neigh_ind, k]).argmax(axis=1) + mode = _sparse_class_counts(_y[neigh_ind, k]).argmax(axis=1) else: mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1) From e506477d050aeb332f9ec8c0637c9d5d8f4c2897 Mon Sep 17 00:00:00 2001 From: Meekail Zain <34613774+Micky774@users.noreply.github.com> Date: Thu, 7 Jul 2022 12:07:45 -0400 Subject: [PATCH 13/14] Update doc/whats_new/v1.2.rst Co-authored-by: Julien Jerphanion --- doc/whats_new/v1.2.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index b16fb8631d70f..f7843c3ec70f9 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -214,9 +214,10 @@ Changelog matrices in a variety of estimators and avoid an `EfficiencyWarning`. :pr:`23139` by `Tom Dupre la Tour`_. -- |Enhancement| :func:`neighbors.KNeighborsClassifier.predict` is now much faster - by leveraging `scipy.sparse.csr_matrix` format for mode calculation via - `csr_matrix.argmax`. :pr:`23721` by :user:`Meekail Zain ` +- |Enhancement| :func:`neighbors.KNeighborsClassifier.predict` is up to + three times faster by leveraging `scipy.sparse.csr_matrix` format + for mode calculation via `csr_matrix.argmax`. + :pr:`23721` by :user:`Meekail Zain ` :mod:`sklearn.svm` .................. From ba80904343c09796017a87ea6a745a6f8f45143c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 7 Jul 2022 12:14:53 -0400 Subject: [PATCH 14/14] Moved import statement --- sklearn/neighbors/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index af4871e986582..faf82d2d69a88 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -9,9 +9,9 @@ # License: BSD 3 clause (C) INRIA, University of Amsterdam import numpy as np +from scipy import sparse from ..utils.extmath import weighted_mode from ..utils.validation import _is_arraylike, _num_samples -from scipy import sparse import warnings from ._base import _check_weights, _get_weights
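
The merged predict() only takes the sparse path when weights is None and still defers to weighted_mode otherwise, but the early patches in this series (01-03) show that the same construction generalizes to the weighted case: pass the neighbor weights as the CSR data array instead of unit counts, so each class accumulates its total weight and the per-row argmax becomes the weighted mode. Here is a sketch of that variant, mirroring the dropped _fast_mode(x, weights=...) helper rather than anything in the final diff (the name sparse_weighted_mode is illustrative):

import numpy as np
from scipy import sparse

def sparse_weighted_mode(class_indices, weights):
    # Weighted row-wise mode: scatter-add each neighbor's weight into its
    # class column, then take the per-row argmax of the accumulated scores.
    if class_indices.min() < 0:
        raise ValueError("only non-negative class indices are supported")
    n_rows, n_cols = class_indices.shape
    data = np.ascontiguousarray(weights, dtype=np.float64).ravel()
    indices = np.ascontiguousarray(class_indices).ravel()
    indptr = np.arange(n_rows + 1) * n_cols
    scores = sparse.csr_matrix(
        (data, indices, indptr), shape=(n_rows, class_indices.max() + 1)
    )
    scores.sum_duplicates()
    return np.asarray(scores.argmax(axis=1)).ravel()

# The example from the patch 02 docstring: de-weighting the 4's makes 2 win
# (total weight 1.5 + 2.0 = 3.5, versus 3.0 for class 1 and 2.5 for class 4).
x = np.array([[4, 1, 4, 2, 4, 2]])
w = np.array([[1, 3, 0.5, 1.5, 1, 2]])
print(sparse_weighted_mode(x, w))  # [2]
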