ENH Improve performance of KNeighborsClassifier.predict #23721


Closed
wants to merge 22 commits

Conversation

Micky774
Contributor

@Micky774 Micky774 commented Jun 21, 2022

Reference Issues/PRs

Fixes #13783
Resolves #14543 (stalled/draft)

What does this implement/fix? Explain your changes.

Leverages csr_matrix to quickly compute the mode in KNeighborsClassifier.predict (replacing scipy.stats.mode) for uniform weights.

Any other comments?

Theoretically, this is a faster operation even in the weighted case. However, csr_matrix.argmax sums duplicate entries (the very behavior we aim to exploit) in place, mutating the underlying data array. This is very problematic since it leads to incorrect results in the multi-output loop. We could "fix" this by passing copies of the weights when creating the csr_matrix, but that defeats the whole point.

Hence, this is currently only a meaningful speedup for weights in {None, "uniform"}, since we can cheaply compute an ndarray of ones on each loop iteration to feed to the csr_matrix without worrying about it being mutated.
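
For illustration, a minimal sketch of the trick (the helper name and signature are hypothetical, not the PR's exact code), assuming integer-encoded labels:

import numpy as np
from scipy.sparse import csr_matrix

def csr_mode(neigh_labels, n_classes):
    """Row-wise mode of integer-encoded neighbor labels via CSR summation."""
    n_queries, n_neighbors = neigh_labels.shape
    # A fresh array of ones per call, so in-place duplicate summation
    # cannot corrupt any shared weight data.
    data = np.ones(neigh_labels.size)
    rows = np.repeat(np.arange(n_queries), n_neighbors)
    # Duplicate (row, class) entries are summed during CSR construction,
    # leaving one vote count per (query, class) pair.
    votes = csr_matrix(
        (data, (rows, neigh_labels.ravel())),
        shape=(n_queries, n_classes),
    )
    # The column index of the largest vote count in each row is the mode.
    return np.asarray(votes.argmax(axis=1)).ravel()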

To Do

  • Memory benchmarks

@Micky774
Contributor Author

Initial benchmarks generated w/ this script. Note that the implementation labeled "PR" uses the sparse-matrix argmax for all methods, while "hybrid" uses it only for uniform weights. This PR currently implements the so-called "hybrid" option.

[Plot: benchmark results comparing the "main", "PR", and "hybrid" implementations]

@Micky774
Contributor Author

No significant differences in memory profile either. Tested in a Jupyter notebook with:

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

rng = np.random.RandomState(0)
X = rng.random(size=(10_000, 200))
y = rng.randint(1, 6, size=(10_000, 3))
neigh = KNeighborsClassifier(n_neighbors=10, weights="uniform")
neigh.fit(X, y)

%load_ext memory_profiler
%memit neigh.predict(X)

Both implementations came in at ~23-25 MiB for weights in {"uniform", "distance"}

@jjerphan
Member

Thank you for the follow-up, @Micky774.

Note that we might then introduce dedicated PairwiseDistancesReductions as back-ends for KNeighbors{Classifier,Regressor}.{predict,predict_proba}. For more details, see: #22587.

If you are interested and want to implement some of them, feel free to! 🙂

@ogrisel
Member

ogrisel commented Jun 24, 2022

@Micky774 it would be great if you could review #23604 in particular which is the first step to accelerate the k-NN queries on sparse data using the new Cython infrastructure.

@jjerphan
Member

Everybody's welcome on the PairwiseDistancesReductions-:boat:.

@Micky774
Contributor Author

@ogrisel @jjerphan Even though it'll be replaced w/ the new Cython back-end, since this is simple enough and demonstrably better than what we have right now, may we move forward w/ a review just to provide a performance gain in the meantime?

@jjerphan
Member

Yes, sure. I do not see those two tasks as mutually exclusive.

@ogrisel
Member

ogrisel commented Jun 29, 2022

Indeed, I had not realized that this was happening after the neighbors computation. This code can probably be further Cythonized to use OpenMP / prange parallelism, but in the meantime this looks like a quick way to improve the code. Let me do a proper review now.

@Micky774 Micky774 added the Quick Review label Jul 6, 2022
jjerphan previously approved these changes Jul 7, 2022
Member

@jjerphan jjerphan left a comment


LGTM. Thank you @Micky774 and @ogrisel.

Member

@thomasjpfan thomasjpfan left a comment


There is a significant memory overhead when n_samples gets bigger:

from sklearn.neighbors import KNeighborsClassifier
import numpy as np

rng = np.random.RandomState(0)
n_samples = 100_000
X = rng.random(size=(n_samples, 200))
y = rng.randint(1, 6, size=(n_samples, 3))
neigh = KNeighborsClassifier(n_neighbors=10, weights="uniform")
neigh.fit(X, y)

On main, I get:

%memit neigh.predict(X)
# peak memory: 319.08 MiB, increment: 33.17 MiB

and with this PR:

%memit neigh.predict(X)
# peak memory: 546.52 MiB, increment: 37.00 MiB

I think most of the memory overhead is from constructing the sparse matrix itself.
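
(Illustrative sketch, not from the PR: a rough way to isolate the construction cost in the same notebook setup, with memory_profiler loaded as above. The sizes mirror the benchmark, and the class count is assumed.)

from scipy.sparse import csr_matrix
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_neighbors, n_classes = 100_000, 10, 5
# COO-style inputs of the same shape predict would build per output.
rows = np.repeat(np.arange(n_samples), n_neighbors)
cols = rng.randint(0, n_classes, size=n_samples * n_neighbors)
data = np.ones(n_samples * n_neighbors)

%memit csr_matrix((data, (rows, cols)), shape=(n_samples, n_classes))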

@Micky774
Contributor Author

There is a significant memory overhead when n_samples gets bigger:
...
I think most of the memory overhead is from constructing the sparse matrix itself.

Yeah, that's actually very significant. In that case, it may be better to just go for the PairwiseDistancesReductions back-end instead of using this as an intermediate speed-up. We could potentially repurpose some of the machinery that scipy.sparse uses, but I honestly don't think it's worth the effort.

@thomasjpfan
Member

With the memory overhead, I am -1 overall on this PR with CSR matrices. We likely need another approach altogether to resolve #13783, either through Cython (as suggested in #23721 (comment)) or something more efficient in Python.
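
(Illustrative sketch, not proposed in this thread: one "more efficient in Python" direction could accumulate votes into a dense count array with np.add.at, skipping sparse-matrix construction entirely. The helper name is hypothetical and integer-encoded labels are assumed.)

import numpy as np

def mode_via_add_at(neigh_labels, n_classes):
    """Row-wise mode via unbuffered vote accumulation; no sparse matrix."""
    n_queries = neigh_labels.shape[0]
    counts = np.zeros((n_queries, n_classes), dtype=np.intp)
    rows = np.arange(n_queries)[:, None]  # broadcasts against neigh_labels
    # np.add.at is unbuffered, so repeated (row, class) pairs all count.
    np.add.at(counts, (rows, neigh_labels), 1)
    return counts.argmax(axis=1)

The dense counts array is only n_queries x n_classes, which stays small when the number of classes is modest.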

@jjerphan jjerphan dismissed their stale review July 21, 2022 13:54

Disapproval due to performance regressions

@Micky774 Micky774 closed this Jul 21, 2022
@Micky774 Micky774 deleted the knn_predict_performance branch July 25, 2022 21:31
Labels: module:neighbors, Performance, Quick Review
Development

Successfully merging this pull request may close these issues.

knn predict unreasonably slow b/c of use of scipy.stats.mode
5 participants