MNT: Remove duplicated data validation done in internally used BinaryTrees #19418

Closed

Conversation

jjerphan
Member

@jjerphan jjerphan commented Feb 9, 2021

Reference Issues/PRs

Fixes #18749.

What does this implement/fix? Explain your changes.

BinaryTrees validate the data in some of their methods.
Those interfaces are used internally in some algorithms which already validate the data.
Hence, this PR proposes to remove the extra and duplicated data validation.

Any other comments?

This uses the global context manager for configuration, as suggested by @cmarmo in #18749.

I am not entirely sure I've used it in the intended way, especially as this context manager gets
called inside joblib.delayed calls, which themselves get serialized.

@jjerphan
Member Author

jjerphan commented Apr 9, 2021

I get a strange behaviour when running the tests.

Starting from this PR's base commit and after applying some of the first commit's patches (1828e00), if I run all the tests at once using:

pytest

some tests, which aren't related to those changes, fail.

Those tests pass if I run the suite module by module, for instance:

find sklearn -maxdepth 1 -type d -wholename "*/*" -exec pytest {} \;

Still, the failing tests are related to input checks.
I suspect something unexpected is happening with the config_context.

Might it be linked to its thread-unsafety?

@jnothman: do you have any idea? 🙂

Edit: tests are passing using a module-wise run as of 32a5cff.

@thomasjpfan
Member

Can you see what happens when you include #18736 ? That PR should make the configuration threadsafe.

@jjerphan
Member Author

jjerphan commented Apr 9, 2021

Can you see what happens when you include #18736 ? That PR should make the configuration threadsafe.

It does, thanks!

I'll wait for #18736 to be merged to merge main back in this PR.

@jjerphan jjerphan marked this pull request as ready for review May 5, 2021 06:43
@jjerphan jjerphan changed the title [WIP] MNT: Remove duplicated data validation done in internally used BinaryTrees MNT: Remove duplicated data validation done in internally used BinaryTrees May 5, 2021
Member

@thomasjpfan thomasjpfan left a comment

At a glance, this should speed up the computation. Do we have benchmarks that demonstrate the improvement?

jjerphan added 2 commits May 9, 2021 21:54
We remove the validation of c done at the beginning of
KDTree.__init__ and KDTree.query_radius as c already got validated at
the beginning of feature_selection._estimate_mi.
@jjerphan
Member Author

jjerphan commented May 10, 2021

Do we have benchmarks that demonstrate the improvement?

I am currently performing some benchmarks on the public interfaces impacted by those changes.

A CI test is failing, but it seems unrelated to the changes here. Is that the case? 🤔

@thomasjpfan
Member

It is unrelated, I opened #20071 to adjust the atol based on 32bitness

Member

@rth rth left a comment

I am currently performing some benchmarks on the public interfaces impacted by those changes.

Could you post a summary here? I wasn't sure which commits are compared.

I was under the impression that checking for finite values is expensive, but I'm no longer so sure:

In [1]: from sklearn.utils import check_array

In [2]: import numpy as np

In [3]: x = np.random.RandomState(0).rand(10000, 500)

In [4]: %timeit x + 1
6.98 ms ± 22 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [5]: %timeit x.copy()
6.33 ms ± 8.24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [6]: %timeit check_array(x)
2.16 ms ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

@jjerphan
Member Author

Here is an extract of the benchmarks' report. My machines are currently busy running others, but I'd like to get back to it soon with a better benchmark plan on internal APIs.

This PR's changes may give no improvement.

· Creating environments
· Discovering benchmarks
· Running 12 total benchmarks (2 commits * 1 environments * 6 benchmarks)
[  0.00%] · For scikit-learn commit ff32212c <remove_validation_internal_trees> (round 1/1):
[  0.00%] ·· Benchmarking conda-py3.8-cython-joblib-numpy-scipy-threadpoolctl
[  8.33%] ··· knn.KernelDensityRemovedCheck.time_query                        ok
[  8.33%] ··· ======= =========== ===========
              --                 d           
              ------- -----------------------
                 n         10         100    
              ======= =========== ===========
                1000    130±6ms     349±20ms 
               10000   10.5±0.5s   38.4±0.6s 
              ======= =========== ===========

[ 16.67%] ··· knn.MutualInfoCCRemovedCheck.time_mi_cc                         ok
[ 16.67%] ··· ======= =========== ===========
              --                 d           
              ------- -----------------------
                 n         10         100    
              ======= =========== ===========
                1000   106±0.9ms    593±20ms 
               10000    712±10ms   7.74±0.2s 
              ======= =========== ===========

[ 25.00%] ··· knn.MutualInfoCDRemovedCheck.time_mi_cd                         ok
[ 25.00%] ··· ======= ========== ============
              --                 d           
              ------- -----------------------
                 n        10         100     
              ======= ========== ============
                1000   47.4±2ms    465±20ms  
               10000   451±20ms   4.35±0.09s 
              ======= ========== ============

[ 50.00%] · For scikit-learn commit de1262c3 <main> (round 1/1):
[ 50.00%] ·· Building for conda-py3.8-cython-joblib-numpy-scipy-threadpoolctl
[ 50.00%] ·· Benchmarking conda-py3.8-cython-joblib-numpy-scipy-threadpoolctl
[ 58.33%] ··· knn.KernelDensityRemovedCheck.time_query                        ok
[ 58.33%] ··· ======= =========== ===========
              --                 d           
              ------- -----------------------
                 n         10         100    
              ======= =========== ===========
                1000    133±3ms     352±20ms 
               10000   9.44±0.2s   37.1±0.7s 
              ======= =========== ===========

[ 66.67%] ··· knn.MutualInfoCCRemovedCheck.time_mi_cc                         ok
[ 66.67%] ··· ======= ========= ============
              --                d           
              ------- ----------------------
                 n        10        100     
              ======= ========= ============
                1000   106±1ms    583±4ms   
               10000   664±5ms   6.07±0.01s 
              ======= ========= ============

[ 75.00%] ··· knn.MutualInfoCDRemovedCheck.time_mi_cd                         ok
[ 75.00%] ··· ======= ============ =========
              --                d           
              ------- ----------------------
                 n         10         100   
              ======= ============ =========
                1000   42.1±0.6ms   424±2ms 
               10000    408±3ms     4.03±0s 
              ======= ============ =========

       before           after         ratio
     [de1262c3]       [ff32212c]
     <main>           <remove_validation_internal_trees>
+      42.1±0.6ms         47.4±2ms     1.12  knn.MutualInfoCDRemovedCheck.time_mi_cd(1000, 10)
+         408±3ms         451±20ms     1.11  knn.MutualInfoCDRemovedCheck.time_mi_cd(10000, 10)

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE DECREASED.
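
As a reading aid for the summary table above, the `ratio` column is the "after" time divided by the "before" time. The helper below is a small sketch that recomputes it from the printed strings; `parse_timing`, `ratio` and `is_flagged` are hypothetical names, and the `factor=1.1` default is an assumption based on asv's documented default threshold (asv also applies a statistical test that this sketch ignores). Because the printed timings are rounded, the recomputed ratio can differ from the printed one in the last digit.

```python
import re

# Units as they appear in the asv output (both micro signs are accepted).
_UNIT_TO_SECONDS = {
    "ns": 1e-9, "μs": 1e-6, "µs": 1e-6, "us": 1e-6, "ms": 1e-3, "s": 1.0,
}
_TIMING = re.compile(r"^\s*([0-9.]+)(?:±[0-9.]+)?\s*(ns|μs|µs|us|ms|s)\s*$")


def parse_timing(text):
    """Convert a string such as '42.1±0.6ms' to seconds, ignoring the ± part."""
    match = _TIMING.match(text)
    if match is None:
        raise ValueError(f"Unrecognised timing: {text!r}")
    return float(match.group(1)) * _UNIT_TO_SECONDS[match.group(2)]


def ratio(before, after):
    """after / before: values above 1 mean the 'after' commit is slower."""
    return parse_timing(after) / parse_timing(before)


def is_flagged(value, factor=1.1):
    """Assumed threshold rule: flag changes beyond `factor` in either direction."""
    return value > factor or value < 1.0 / factor


r = ratio("408±3ms", "451±20ms")
print(f"{r:.2f}", is_flagged(r))
```

For the `time_mi_cd(10000, 10)` row this gives roughly 1.11, matching the printed ratio.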

@thomasjpfan
Member

The overhead of the finite check appears with slightly bigger datasets:

from sklearn.utils import check_array
import numpy as np

x = np.random.RandomState(0).rand(100_000, 500)
%timeit check_array(x, force_all_finite=False)
# 14.3 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit check_array(x, force_all_finite=True)
# 31.5 ms ± 427 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
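
One way to read these two timings: with `force_all_finite=False` the call takes microseconds, which suggests it only inspects metadata, whereas the finite check has to look at every value. The sketch below reproduces that contrast using only the standard library; `check_shape_only` and `check_all_finite` are hypothetical stand-ins and do not reflect how `check_array` is implemented.

```python
import math
import timeit
from array import array


def check_shape_only(values, n_rows, n_cols):
    """Metadata-only check: its cost does not depend on the number of values."""
    return len(values) == n_rows * n_cols


def check_all_finite(values):
    """Full scan: every value is inspected, so the cost grows with the data size."""
    return all(map(math.isfinite, values))


def time_checks(n_rows, n_cols, repeat=3):
    """Return the best-of-`repeat` wall time in seconds for each check."""
    values = array("d", [0.5]) * (n_rows * n_cols)
    t_shape = min(timeit.repeat(
        lambda: check_shape_only(values, n_rows, n_cols), number=1, repeat=repeat))
    t_finite = min(timeit.repeat(
        lambda: check_all_finite(values), number=1, repeat=repeat))
    return t_shape, t_finite


for n_rows in (1_000, 10_000):
    t_shape, t_finite = time_checks(n_rows, 50)
    print(f"{n_rows:>6} rows: shape check {t_shape:.2e}s, finite check {t_finite:.2e}s")
```

Absolute numbers will differ from the NumPy-based timings above; only the scaling behaviour is meant to carry over.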

@jjerphan
Member Author

I reran those benchmarks. It looks like performance generally decreased.

       before           after         ratio
     [c6751835]       [903125f2]
     <main>           <remove_validation_internal_trees>
+        667±30μs      1.41±0.07ms     2.12  knn.NeighborsBaseQueryRemovedCheck.time_kneighbors(100, 10)
+        571±60ms        863±100ms     1.51  knn.MutualInfoCDRemovedCheck.time_mi_cd(1000, 100)
+      4.81±0.03s        6.93±0.5s     1.44  knn.MutualInfoCDRemovedCheck.time_mi_cd(10000, 100)
+         187±3ms         270±20ms     1.44  knn.MutualInfoCDRemovedCheck.time_mi_cd(100, 100)
+      18.7±0.6ms         26.4±1ms     1.41  knn.MutualInfoCDRemovedCheck.time_mi_cd(100, 10)
+      47.2±0.8ms         65.3±8ms     1.38  knn.MutualInfoCDRemovedCheck.time_mi_cd(1000, 10)
+        744±20ms        1.02±0.1s     1.37  knn.MutualInfoCCRemovedCheck.time_mi_cc(10000, 10)
+        512±60ms         594±40ms     1.16  knn.MutualInfoCDRemovedCheck.time_mi_cd(10000, 10)
+      57.2±0.9ms         65.0±5ms     1.14  knn.NeighborsBaseQueryRemovedCheck.time_kneighbors(1000, 10)
+        51.4±3ms         56.7±2ms     1.10  knn.NeighborsBaseQueryRemovedCheck.time_radius_neighbors(1000, 10)
-        28.6±5ms       25.4±0.7ms     0.89  knn.NeighborsBaseCreationRemovedCheck.time_creation(10000, 10)
-      14.5±0.5ms       12.7±0.4ms     0.88  knn.MutualInfoCCRemovedCheck.time_mi_cc(100, 10)
-      6.60±0.4ms       5.35±0.2ms     0.81  knn.KernelDensityRemovedCheck.time_query(100, 100)

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE DECREASED.
Full report
· Creating environments
· Discovering benchmarks
· Running 12 total benchmarks (2 commits * 1 environments * 6 benchmarks)
[  0.00%] · For scikit-learn commit 903125f2 <remove_validation_internal_trees> (round 1/1):
[  0.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[  8.33%] ··· knn.KernelDensityRemovedCheck.time_query                        ok
[  8.33%] ··· ======= ============ ============
              --                  d            
              ------- -------------------------
                 n         10          100     
              ======= ============ ============
                100    1.05±0.2ms   5.35±0.2ms 
                1000    146±10ms     344±2ms   
               10000   10.9±0.6s     48.7±2s   
              ======= ============ ============

[ 16.67%] ··· knn.MutualInfoCCRemovedCheck.time_mi_cc ok
[ 16.67%] ··· ======= ============ ===========
              --                 d
              ------- ------------------------
                 n         10          100
              ======= ============ ===========
                100    12.7±0.4ms    174±2ms
                1000    110±5ms     683±30ms
               10000   1.02±0.1s    8.92±0.3s
              ======= ============ ===========

[ 25.00%] ··· knn.MutualInfoCDRemovedCheck.time_mi_cd ok
[ 25.00%] ··· ======= ========== ===========
              --                d
              ------- ----------------------
                 n        10         100
              ======= ========== ===========
                100    26.4±1ms   270±20ms
                1000   65.3±8ms   863±100ms
               10000   594±40ms   6.93±0.5s
              ======= ========== ===========

[ 33.33%] ··· ...ghborsBaseCreationRemovedCheck.time_creation ok
[ 33.33%] ··· ======= ============ ==========
              --                 d
              ------- -----------------------
                 n         10         100
              ======= ============ ==========
                100     179±10μs    735±30μs
                1000   1.98±0.1ms   15.6±3ms
               10000   25.4±0.7ms   180±30ms
              ======= ============ ==========

[ 41.67%] ··· ...ighborsBaseQueryRemovedCheck.time_kneighbors ok
[ 41.67%] ··· ======= ============= ============
              --                  d
              ------- --------------------------
                 n         10           100
              ======= ============= ============
                100    1.41±0.07ms   4.21±0.3ms
                1000    65.0±5ms      269±5ms
               10000   1.79±0.06s    35.3±0.2s
              ======= ============= ============

[ 50.00%] ··· ...sBaseQueryRemovedCheck.time_radius_neighbors ok
[ 50.00%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100     772±40μs    5.84±0.4ms
                1000    56.7±2ms     184±5ms
               10000   1.16±0.03s   6.00±0.1s
              ======= ============ ============

[ 50.00%] · For scikit-learn commit c6751835 <main> (round 1/1):
[ 50.00%] ·· Building for conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[ 50.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[ 58.33%] ··· knn.KernelDensityRemovedCheck.time_query ok
[ 58.33%] ··· ======= ============= ============
              --                  d
              ------- --------------------------
                 n         10           100
              ======= ============= ============
                100    1.10±0.07ms   6.60±0.4ms
                1000     151±9ms      376±10ms
               10000    11.5±0.6s    47.9±0.9s
              ======= ============= ============

[ 66.67%] ··· knn.MutualInfoCCRemovedCheck.time_mi_cc ok
[ 66.67%] ··· ======= ============ ===========
              --                 d
              ------- ------------------------
                 n         10          100
              ======= ============ ===========
                100    14.5±0.5ms    187±2ms
                1000    110±1ms      638±8ms
               10000    744±20ms    7.05±0.1s
              ======= ============ ===========

[ 75.00%] ··· knn.MutualInfoCDRemovedCheck.time_mi_cd ok
[ 75.00%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100    18.7±0.6ms    187±3ms
                1000   47.2±0.8ms    571±60ms
               10000    512±60ms    4.81±0.03s
              ======= ============ ============

[ 83.33%] ··· ...ghborsBaseCreationRemovedCheck.time_creation ok
[ 83.33%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100     172±7μs      724±70μs
                1000   2.07±0.1ms   14.9±0.6ms
               10000    28.6±5ms     166±4ms
              ======= ============ ============

[ 91.67%] ··· ...ighborsBaseQueryRemovedCheck.time_kneighbors ok
[ 91.67%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100     667±30μs    4.21±0.2ms
                1000   57.2±0.9ms    264±2ms
               10000   1.67±0.01s   34.2±0.2s
              ======= ============ ============

[100.00%] ··· ...sBaseQueryRemovedCheck.time_radius_neighbors ok
[100.00%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100     771±60μs    5.33±0.2ms
                1000    51.4±3ms     181±2ms
               10000   1.11±0.01s   5.85±0.04s
              ======= ============ ============

       before           after         ratio
     [c6751835]       [903125f2]
     <main>           <remove_validation_internal_trees>
+        667±30μs      1.41±0.07ms     2.12  knn.NeighborsBaseQueryRemovedCheck.time_kneighbors(100, 10)
+        571±60ms        863±100ms     1.51  knn.MutualInfoCDRemovedCheck.time_mi_cd(1000, 100)
+      4.81±0.03s        6.93±0.5s     1.44  knn.MutualInfoCDRemovedCheck.time_mi_cd(10000, 100)
+         187±3ms         270±20ms     1.44  knn.MutualInfoCDRemovedCheck.time_mi_cd(100, 100)
+      18.7±0.6ms         26.4±1ms     1.41  knn.MutualInfoCDRemovedCheck.time_mi_cd(100, 10)
+      47.2±0.8ms         65.3±8ms     1.38  knn.MutualInfoCDRemovedCheck.time_mi_cd(1000, 10)
+        744±20ms        1.02±0.1s     1.37  knn.MutualInfoCCRemovedCheck.time_mi_cc(10000, 10)
+        512±60ms         594±40ms     1.16  knn.MutualInfoCDRemovedCheck.time_mi_cd(10000, 10)
+      57.2±0.9ms         65.0±5ms     1.14  knn.NeighborsBaseQueryRemovedCheck.time_kneighbors(1000, 10)
+        51.4±3ms         56.7±2ms     1.10  knn.NeighborsBaseQueryRemovedCheck.time_radius_neighbors(1000, 10)
-        28.6±5ms       25.4±0.7ms     0.89  knn.NeighborsBaseCreationRemovedCheck.time_creation(10000, 10)
-      14.5±0.5ms       12.7±0.4ms     0.88  knn.MutualInfoCCRemovedCheck.time_mi_cc(100, 10)
-      6.60±0.4ms       5.35±0.2ms     0.81  knn.KernelDensityRemovedCheck.time_query(100, 100)

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE DECREASED.

@jjerphan
Member Author

What do you think @thomasjpfan and @rth? Should we explore it differently?

@jjerphan
Member Author

I reran them on a machine with many more cores and got a different summary:

       before           after         ratio
     [1cd282d6]       [903125f2]
     <main>           <remove_validation_internal_trees>
-       134±0.2ms          121±2ms     0.90  19418.MutualInfoCCRemovedCheck.time_mi_cc(100, 100)
-     9.88±0.01ms         8.02±0ms     0.81  19418.NeighborsBaseCreationRemovedCheck.time_creation(1000, 100)

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE INCREASED.
Full report
 → asv continuous -b RemovedCheck main remove_validation_internal_trees
· Creating environments
· Discovering benchmarks.
·· Uninstalling from conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl...
·· Installing 903125f2  into conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl..
· Running 12 total benchmarks (2 commits * 1 environments * 6 benchmarks)
[  0.00%] · For scikit-learn commit 903125f2 <remove_validation_internal_trees> (round 1/1):
[  0.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[  8.33%] ··· 19418.KernelDensityRemovedCheck.time_query                                                                                                                                                                                                                                                                    ok
[  8.33%] ··· ======= ============= ============
              --                  d
              ------- --------------------------
                 n          10          100
              ======= ============= ============
                100    1.10±0.01ms   2.58±0.1ms
                1000    99.8±0.6ms    269±2ms
               10000    10.6±0.04s   22.3±0.01s
              ======= ============= ============

[ 16.67%] ··· 19418.MutualInfoCCRemovedCheck.time_mi_cc ok
[ 16.67%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100    13.3±0.2ms    121±2ms
                1000   47.5±0.1ms    498±5ms
               10000    635±60ms    5.23±0.01s
              ======= ============ ============

[ 25.00%] ··· 19418.MutualInfoCDRemovedCheck.time_mi_cd ok
[ 25.00%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100    18.5±0.3ms    178±1ms
                1000   43.4±0.6ms    430±2ms
               10000    436±5ms     4.24±0.09s
              ======= ============ ============

[ 33.33%] ··· 19418.NeighborsBaseCreationRemovedCheck.time_creation ok
[ 33.33%] ··· ======= ============= ============
              --                  d
              ------- --------------------------
                 n         10           100
              ======= ============= ============
                100     195±0.6μs     390±6μs
                1000    934±10μs      8.02±0ms
               10000   13.6±0.03ms   120±0.07ms
              ======= ============= ============

[ 41.67%] ··· 19418.NeighborsBaseQueryRemovedCheck.time_kneighbors ok
[ 41.67%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100     648±20μs     1.70±0ms
                1000   22.3±0.4ms    206±2ms
               10000   1.44±0.01s   16.5±0.01s
              ======= ============ ============

[ 50.00%] ··· 19418.NeighborsBaseQueryRemovedCheck.time_radius_neighbors ok
[ 50.00%] ··· ======= ============ =============
              --                  d
              ------- --------------------------
                 n         10           100
              ======= ============ =============
                100     664±20μs    1.78±0.01ms
                1000   17.1±0.2ms   81.6±0.8ms
               10000    781±4ms     2.89±0.01s
              ======= ============ =============

[ 50.00%] · For scikit-learn commit 1cd282d6 <main> (round 1/1):
[ 50.00%] ·· Building for conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl......
[ 50.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[ 58.33%] ··· 19418.KernelDensityRemovedCheck.time_query ok
[ 58.33%] ··· ======= ============= =============
              --                   d
              ------- ---------------------------
                 n         10            100
              ======= ============= =============
                100    1.12±0.01ms   2.39±0.01ms
                1000   100.0±0.2ms    265±0.4ms
               10000   10.5±0.01s      22.4±0s
              ======= ============= =============

[ 66.67%] ··· 19418.MutualInfoCCRemovedCheck.time_mi_cc ok
[ 66.67%] ··· ======= ============ ============
              --                  d
              ------- -------------------------
                 n         10          100
              ======= ============ ============
                100    14.4±0.2ms   134±0.2ms
                1000   49.2±0.2ms    514±10ms
               10000    606±30ms    5.24±0.01s
              ======= ============ ============

[ 75.00%] ··· 19418.MutualInfoCDRemovedCheck.time_mi_cd ok
[ 75.00%] ··· ======= ============= ============
              --                  d
              ------- --------------------------
                 n         10           100
              ======= ============= ============
                100    19.7±0.07ms   189±0.2ms
                1000   44.4±0.09ms    443±1ms
               10000     438±6ms     4.08±0.01s
              ======= ============= ============

[ 83.33%] ··· 19418.NeighborsBaseCreationRemovedCheck.time_creation ok
[ 83.33%] ··· ======= ============= =============
              --                   d
              ------- ---------------------------
                 n         10            100
              ======= ============= =============
                100      203±3μs      372±0.9μs
                1000     975±4μs     9.88±0.01ms
               10000   13.7±0.04ms    121±0.1ms
              ======= ============= =============

[ 91.67%] ··· 19418.NeighborsBaseQueryRemovedCheck.time_kneighbors ok
[ 91.67%] ··· ======= ============ ===========
              --                 d
              ------- ------------------------
                 n         10          100
              ======= ============ ===========
                100     617±2μs     1.74±0ms
                1000   22.2±0.1ms   206±0.3ms
               10000   1.45±0.02s   16.9±0.4s
              ======= ============ ===========

[100.00%] ··· 19418.NeighborsBaseQueryRemovedCheck.time_radius_neighbors ok
[100.00%] ··· ======= ============ =============
              --                  d
              ------- --------------------------
                 n         10           100
              ======= ============ =============
                100     674±2μs     1.77±0.02ms
                1000   17.6±0.2ms   82.0±0.8ms
               10000    830±20ms    2.88±0.01s
              ======= ============ =============

       before           after         ratio
     [1cd282d6]       [903125f2]
     <main>           <remove_validation_internal_trees>
-       134±0.2ms          121±2ms     0.90  19418.MutualInfoCCRemovedCheck.time_mi_cc(100, 100)
-     9.88±0.01ms         8.02±0ms     0.81  19418.NeighborsBaseCreationRemovedCheck.time_creation(1000, 100)

SOME BENCHMARKS HAVE CHANGED SIGNIFICANTLY.
PERFORMANCE INCREASED.

Member

@thomasjpfan thomasjpfan left a comment

Thank you for providing the benchmarking updates!

@jjerphan
Member Author

I reran this on my machine for c882688, and the benchmarks have not changed significantly.

Full report
 → asv continuous -b RemovedCheck main remove_validation_internal_trees
· Creating environments
· Discovering benchmarks
·· Uninstalling from conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
·· Installing c8826882 <remove_validation_internal_trees> into conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl.
· Running 12 total benchmarks (2 commits * 1 environments * 6 benchmarks)
[  0.00%] · For scikit-learn commit c8826882 <remove_validation_internal_trees> (round 1/1):
[  0.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[  8.33%] ··· bench.KernelDensityRemovedCheck.time_query                                                                                                                                                                                                    ok
[  8.33%] ··· ======= ============ =============
              --                  d             
              ------- --------------------------
                 n         10           100     
              ======= ============ =============
                100     857±3μs     2.73±0.02ms 
                1000   77.3±0.3ms     273±1ms   
               10000   8.23±0.05s    32.0±0.1s  
              ======= ============ =============

[ 16.67%] ··· bench.MutualInfoCCRemovedCheck.time_mi_cc                                                                                                                                                                                                     ok
[ 16.67%] ··· ======= ============= ============
              --                  d             
              ------- --------------------------
                 n          10          100     
              ======= ============= ============
                100    11.6±0.08ms   110±0.6ms  
                1000    47.6±0.3ms    468±2ms   
               10000     550±4ms     5.42±0.07s 
              ======= ============= ============

[ 25.00%] ··· bench.MutualInfoCDRemovedCheck.time_mi_cd                                                                                                                                                                                                     ok
[ 25.00%] ··· ======= ============= ============
              --                  d             
              ------- --------------------------
                 n          10          100     
              ======= ============= ============
                100    16.7±0.08ms   166±0.6ms  
                1000    40.6±0.1ms    408±1ms   
               10000     387±1ms     3.87±0.01s 
              ======= ============= ============

[ 33.33%] ··· bench.NeighborsBaseCreationRemovedCheck.time_creation                                                                                                                                                                                         ok
[ 33.33%] ··· ======= ============= =============
              --                   d             
              ------- ---------------------------
                 n          10           100     
              ======= ============= =============
                100     149±0.6μs      301±1μs   
                1000     842±4μs     6.13±0.02ms 
               10000   10.7±0.04ms    95.1±0.4ms 
              ======= ============= =============

[ 41.67%] ··· bench.NeighborsBaseQueryRemovedCheck.time_kneighbors                                                                                                                                                                                          ok
[ 41.67%] ··· ======= ============ =============
              --                  d             
              ------- --------------------------
                 n         10           100     
              ======= ============ =============
                100     557±3μs     1.71±0.01ms 
                1000   22.8±0.1ms    184±0.9ms  
               10000   1.33±0.01s    24.0±0.03s 
              ======= ============ =============

[ 50.00%] ··· bench.NeighborsBaseQueryRemovedCheck.time_radius_neighbors                                                                                                                                                                                    ok
[ 50.00%] ··· ======= ============ =============
              --                  d             
              ------- --------------------------
                 n         10           100     
              ======= ============ =============
                100     629±4μs     2.16±0.02ms 
                1000   20.8±0.2ms     111±1ms   
               10000    875±6ms      4.64±0.04s 
              ======= ============ =============

[ 50.00%] · For scikit-learn commit b5e5db4a <main> (round 1/1):
[ 50.00%] ·· Building for conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl..
[ 50.00%] ·· Benchmarking conda-py3.9-cython-joblib-numpy-scipy-threadpoolctl
[ 58.33%] ··· bench.KernelDensityRemovedCheck.time_query                                                                                                                                                                                                    ok
[ 58.33%] ··· ======= ============ =============
              --                  d             
              ------- --------------------------
                 n         10           100     
              ======= ============ =============
                100     861±3μs     2.72±0.01ms 
                1000   77.3±0.2ms    272±0.5ms  
               10000   8.19±0.02s    32.1±0.01s 
              ======= ============ =============

[ 66.67%] ··· bench.MutualInfoCCRemovedCheck.time_mi_cc                                                                                                                                                                                                     ok
[ 66.67%] ··· ======= ============= ============
              --                  d             
              ------- --------------------------
                 n          10          100     
              ======= ============= ============
                100    12.6±0.08ms   120±0.4ms  
                1000    48.8±0.2ms    480±2ms   
               10000     548±4ms     5.40±0.01s 
              ======= ============= ============

[ 75.00%] ··· bench.MutualInfoCDRemovedCheck.time_mi_cd                                                                                                                                                                                                     ok
[ 75.00%] ··· ======= ============ ============
              --                  d            
              ------- -------------------------
                 n         10          100     
              ======= ============ ============
                100    17.5±0.1ms   173±0.9ms  
                1000   41.0±0.2ms    409±2ms   
               10000    378±2ms     3.80±0.01s 
              ======= ============ ============

[ 83.33%] ··· bench.NeighborsBaseCreationRemovedCheck.time_creation                                                                                                                                                                                         ok
[ 83.33%] ··· ======= ============= =============
              --                   d             
              ------- ---------------------------
                 n          10           100     
              ======= ============= =============
                100      159±1μs       313±1μs   
                1000     861±3μs     6.21±0.01ms 
               10000   10.8±0.04ms    95.5±0.4ms 
              ======= ============= =============

[ 91.67%] ··· bench.NeighborsBaseQueryRemovedCheck.time_kneighbors                                                                                                                                                                                          ok
[ 91.67%] ··· ======= ============ =============
              --                  d             
              ------- --------------------------
                 n         10           100     
              ======= ============ =============
                100     574±4μs     1.74±0.01ms 
                1000   22.8±0.1ms     185±1ms   
               10000    1.33±0s      24.1±0.07s 
              ======= ============ =============

[100.00%] ··· bench.NeighborsBaseQueryRemovedCheck.time_radius_neighbors                                                                                                                                                                                    ok
[100.00%] ··· ======= ============ =============
              --                  d             
              ------- --------------------------
                 n         10           100     
              ======= ============ =============
                100     645±4μs     2.17±0.02ms 
                1000   20.4±0.2ms     109±1ms   
               10000    862±5ms      4.56±0.02s 
              ======= ============ =============


BENCHMARKS NOT SIGNIFICANTLY CHANGED.

@rth
Member

rth commented Aug 6, 2021

If there is no measurable impact, I would say we keep the files unchanged and close this PR?

@jjerphan
Member Author

jjerphan commented Aug 6, 2021

I profiled it with viztracer, using the setup from this blog post.

It looks like the overhead is negligible compared to the rest of the runtime.

Script for the other, non-annotated checks
from sklearn.datasets import make_classification
from sklearn.feature_selection import mutual_info_regression, mutual_info_classif
from sklearn.neighbors import KernelDensity


def main(args=None):
    # Synthetic dataset: 10,000 samples with 10 continuous features.
    X_train, y_train = make_classification(n_samples=10_000, n_features=10)

    # Mutual information estimators (continuous features).
    mutual_info_regression(X_train, y_train, discrete_features=False)

    mutual_info_classif(X_train, y_train, discrete_features=False)

    # Kernel density estimation: fit, then score the training samples.
    kde = KernelDensity().fit(X_train, y_train)

    kde.score_samples(X_train)


if __name__ == "__main__":
    main()

Let's close.
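
One way to sanity-check the "negligible overhead" observation without a tracer is to time the validation step on its own against a full call that includes it. The sketch below does this for `KernelDensity.score_samples`. It uses `check_array` as a stand-in for the validation performed internally, and a reduced sample size to keep the run short; both are assumptions, and timings will vary by machine.

```python
from timeit import timeit

from sklearn.datasets import make_classification
from sklearn.neighbors import KernelDensity
from sklearn.utils import check_array

# Assumption: fewer samples than the script above, to keep the run short.
X, _ = make_classification(n_samples=2_000, n_features=10, random_state=0)
kde = KernelDensity().fit(X)

repeats = 5
# Time the input validation alone, then the full call that includes it.
t_check = timeit(lambda: check_array(X), number=repeats) / repeats
t_score = timeit(lambda: kde.score_samples(X), number=repeats) / repeats

print(f"check_array:      {t_check * 1e3:.3f} ms")
print(f"score_samples:    {t_score * 1e3:.3f} ms")
print(f"validation share: {t_check / t_score:.2%}")
```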

@jjerphan jjerphan closed this Aug 6, 2021
Copy link
Member Author

@jjerphan jjerphan left a comment

Comment on lines +779 to +783
with config_context(assume_finite=True):
    # We remove the validation of the query points
    # (in *parallel_kwargs) done at the beginning of
    # BinaryTree.query as those points already got
    # validated in the caller.
Copy link
Member Author

The running time of the checks (represented by the small spike at the beginning) is negligible here.
[Screenshot 2021-08-06 at 18-05-07]

Script
# profile.py
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors


def main(args=None):
    X_train, _ = make_regression(n_samples=100_000, n_features=10)
    X_test, _ = make_regression(n_samples=100_000, n_features=10)

    nn = NearestNeighbors(algorithm='kd_tree').fit(X_train)
    nn.kneighbors(X_test, n_neighbors=2)


if __name__ == "__main__":
    main()
giltracer --state-detect profile.py
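
For readers unfamiliar with the mechanism used in the diff excerpts, here is a minimal sketch of what `config_context(assume_finite=True)` changes. It calls `check_array` directly rather than going through a `BinaryTree`, which is a simplification for illustration.

```python
import numpy as np
from sklearn import config_context, get_config
from sklearn.utils import check_array

X = np.array([[0.0, 1.0], [np.nan, 2.0]])

# By default, validation rejects NaN and infinite values.
try:
    check_array(X)
    rejected_by_default = False
except ValueError:
    rejected_by_default = True

# Inside the context manager the finiteness check is skipped,
# so the array is returned even though it contains NaN.
with config_context(assume_finite=True):
    inside = get_config()["assume_finite"]
    X_checked = check_array(X)

# The previous setting is restored when the block exits.
outside = get_config()["assume_finite"]
print(rejected_by_default, inside, outside)
```

Note that `assume_finite=True` only skips the scan for non-finite values; other parts of validation, such as dtype and shape checks, still run.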

Comment on lines +1122 to +1126
with config_context(assume_finite=True):
    # We remove the validation of the query points
    # (in *parallel_kwargs) done at the beginning of
    # BinaryTree.query_radius as those points already
    # got validated in the caller.
Copy link
Member Author

The running time of the checks is similarly negligible here.

Script
# profile.py
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors


def main(args=None):
    X_train, _ = make_regression(n_samples=100_000, n_features=10)
    X_test, _ = make_regression(n_samples=100_000, n_features=10)

    nn = NearestNeighbors(algorithm='kd_tree').fit(X_train)
    nn.kneighbors(X_test, n_neighbors=2)


if __name__ == "__main__":
    main()
giltracer --state-detect profile.py
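
The script above calls `kneighbors`, while this comment is attached to the `BinaryTree.query_radius` path. A variant that goes through `radius_neighbors` instead might look like the sketch below. The `radius` value and the reduced sample size are arbitrary choices made for illustration and are not taken from the PR.

```python
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors


def run(n_samples=2_000, radius=3.0):
    # Assumptions: n_samples is reduced from the 100_000 used above,
    # and radius is an arbitrary illustrative value.
    X_train, _ = make_regression(n_samples=n_samples, n_features=10, random_state=0)
    X_test, _ = make_regression(n_samples=n_samples, n_features=10, random_state=1)

    nn = NearestNeighbors(algorithm="kd_tree").fit(X_train)
    # With a tree-based algorithm, radius_neighbors queries the fitted tree
    # through its query_radius method.
    distances, indices = nn.radius_neighbors(X_test, radius=radius)
    return distances, indices


distances, indices = run()
print(len(indices), "query points processed")
```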

Comment on lines +546 to +549
with config_context(assume_finite=True):
    # In the following cases, we remove the validation of X done at
    # the beginning of the BinaryTree's constructors as X already got
    # validated when calling this method, NeighborsBase._fit.
Copy link
Member Author

The running time of the checks (represented by the small spike at the beginning) is negligible here.
[Screenshot 2021-08-06 at 18-04-18]

Script
# profile.py
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors


def main(args=None):
    X_train, _ = make_regression(n_samples=100_000, n_features=10)
    X_test, _ = make_regression(n_samples=100_000, n_features=10)

    nn = NearestNeighbors(algorithm='kd_tree').fit(X_train)
    nn.kneighbors(X_test, n_neighbors=2)


if __name__ == "__main__":
    main()
giltracer --state-detect profile.py

Development

Successfully merging this pull request may close these issues.

Investigate and fix performances in *_tree dependent algorithms due to multiple input validation.
3 participants