ENH Speedup confusion_matrix #9843

Merged
merged 6 commits on Mar 11, 2021

Conversation

Erotemic
Contributor

@Erotemic Erotemic commented Sep 27, 2017

This PR speeds up the confusion_matrix function in the case where labels is specified and already in index form. I added a check to bypass the expensive label-to-index conversion, which results in a nice 16x speedup.
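For reference, here is a simplified sketch of the bypass (the final version of this logic appears in the diff quoted later in the thread; the helper name `to_index_form` is only for illustration):

    import numpy as np

    def to_index_form(y_true, y_pred, labels):
        # y_true and y_pred are assumed to already be numpy arrays
        labels = np.asarray(labels)
        n_labels = labels.size
        # Skip the dict-based conversion when labels are already consecutive
        # non-negative integers starting from zero and the targets are valid indices.
        need_index_conversion = not (
            labels.dtype.kind in {'i', 'u', 'b'} and
            np.all(labels == np.arange(n_labels)) and
            y_true.min() >= 0 and y_pred.min() >= 0
        )
        if need_index_conversion:
            label_to_ind = {y: x for x, y in enumerate(labels)}
            y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
            y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
        return y_true, y_pred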

I've labeled this as WIP because I've only tested the case where len(labels) = 12 and len(y_true) = len(y_pred) = 172800. I need to do more comprehensive testing to ensure that I didn't slow down other cases for the sake of one particular case. I believe this will be an overall improvement, but I want to make sure.

The one benchmark I've done so far (on the aforementioned data) resulted in a reduction of compute time from 90.0ms to 6.0ms. This is a 15x speedup. When computing several hundred confusion matrices, this becomes quite significant.

The running time can be further improved to a 36x speedup if we added an extra flag (called enable_checks=True) to allow the user to disable the _check_targets and check_consistent_length calls when appropriate. I didn't add this to the PR by default because I thought there might be some pushback on adding an argument to a function signature. However, if a reviewer thinks this is ok, let me know and I'll add it to get some extra speed.

TODO

  • baseline proof-of-concept
  • is adding an extra flag ok for an extra x2 speedup? (let's just keep this PR simple)
  • benchmark: test speed differences on arrays that satisfy the new check condition with different numbers of labels / items (with different data types). Ensure there is no significant slowdown, and find the point where the speedup becomes significant.
  • benchmark: test speed differences on arrays that do not satisfy the new check condition with different numbers of labels / items (with different data types). Ensure there is not a significant slowdown.
  • associated what's new / documentation changes if necessary

@Erotemic Erotemic changed the title from "[WIP] Added check to bypass expensive label to index conversion" to "[WIP] Speedup confusion_matrix by x16" on Sep 27, 2017
@Erotemic Erotemic changed the title from "[WIP] Speedup confusion_matrix by x16" to "[WIP] speedup confusion_matrix" on Sep 27, 2017
Member

@jnothman jnothman left a comment

this looks worthwhile. Efficient mapping to ints is something pandas does well... Sometimes I wish we could exploit that.

In terms of check_targets, I've wondered about ways to speed up type_of_target to make it linear at worst, and maybe even approximate from a sample when trying to identify number of labels (risky).

A function that behaved like np.unique but special-cased binary data might already be a big improvement.
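A minimal sketch of that idea (purely illustrative; `unique_binary_fast` is a hypothetical helper, not scikit-learn code):

    import numpy as np

    def unique_binary_fast(y):
        # Special-case non-negative integer data with values in {0, 1}:
        # a presence table avoids the sort that np.unique would do.
        y = np.asarray(y)
        if y.size and y.dtype.kind in {'i', 'u'} and y.min() >= 0 and y.max() <= 1:
            present = np.zeros(2, dtype=bool)
            present[y] = True
            return np.flatnonzero(present)
        return np.unique(y)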

@Erotemic
Contributor Author

I'm not sure exactly how pandas does its mapping.

Here is a script I used to get timings: https://gist.github.com/Erotemic/f4d6005f0c1a472997426fae3969840f

I tested over the following grid

    basis1 = {
        'n_input': [10, 1000, 10000, 200000],
        'n_classes': [2, 10, 100, 1000],
        'dtype': ['uint8', 'int64'],
    }

    basis2 = {
        'labelkw': [None, 'int-consec', 'int-nonconsec', 'str'],
        'weightkw': (None, 'float'),
    }

n_input dictates the size of y_pred and y_true, dtype is their data type, and n_classes is the number of possible labels.

labelkw indicates the type of labels: None just uses the default labels=None; int-consec provides labels as consecutive integers starting with zero (the main use case this speedup targets); finally, int-nonconsec and str provide labels that don't satisfy the conditions needed to avoid label re-mapping.

weightkw is defined similarly for sample_weight. I didn't report many configurations of this to avoid combinatoric blowup, but I haven't seen its value have much of an effect in any of my experiments. It has little influence over the modified code anyway; I just included it for completeness.

And here are the results:

We can see > 10x speedups on certain cases when int-consec is satisfied. The greatest speedup is achieved when the number of classes is relatively small and the number of inputs is large. When the number of classes increases, the speedup tends to decrease back towards a 1:1 ratio.

While integer inputs with specified labels achieve the greatest speedup, the common case where y_pred and y_true are given in normal integer format without specifying labels also achieves a 2-4x speedup.

When the check conditions are not satisfied, or the arrays are small, the slowdown is never worse than a factor of ~0.9, a large portion of which can be attributed to time-measurement inaccuracies.

-------
(('weightkw', None), ('labelkw', None))
                                                 v2        v1   speedup
dtype=uint8,n_input=10,n_classes=2         0.000356  0.000390  1.094441
dtype=uint8,n_input=10,n_classes=10        0.000415  0.000508  1.224009
dtype=uint8,n_input=10,n_classes=100       0.000402  0.000396  0.985172
dtype=uint8,n_input=10,n_classes=1000      0.000357  0.000337  0.942590
dtype=uint8,n_input=1000,n_classes=2       0.000463  0.000948  2.049485
dtype=uint8,n_input=1000,n_classes=10      0.000478  0.000934  1.952642
dtype=uint8,n_input=1000,n_classes=100     0.000605  0.001745  2.882631
dtype=uint8,n_input=1000,n_classes=1000    0.001876  0.001187  0.632736
dtype=uint8,n_input=10000,n_classes=2      0.001450  0.006236  4.300723
dtype=uint8,n_input=10000,n_classes=10     0.001647  0.006478  3.932127
dtype=uint8,n_input=10000,n_classes=100    0.002211  0.007076  3.200237
dtype=uint8,n_input=10000,n_classes=1000   0.002522  0.008753  3.470694
dtype=uint8,n_input=200000,n_classes=2     0.024391  0.116775  4.787687
dtype=uint8,n_input=200000,n_classes=10    0.026743  0.122020  4.562601
dtype=uint8,n_input=200000,n_classes=100   0.038025  0.133367  3.507361
dtype=uint8,n_input=200000,n_classes=1000  0.041600  0.160080  3.848077
dtype=int64,n_input=10,n_classes=2         0.000346  0.000334  0.964828
dtype=int64,n_input=10,n_classes=10        0.000348  0.000321  0.923288
dtype=int64,n_input=10,n_classes=100       0.000341  0.000327  0.960812
dtype=int64,n_input=10,n_classes=1000      0.000340  0.000328  0.965614
dtype=int64,n_input=1000,n_classes=2       0.000451  0.001059  2.350265
dtype=int64,n_input=1000,n_classes=10      0.000464  0.001065  2.292094
dtype=int64,n_input=1000,n_classes=100     0.000575  0.001192  2.071695
dtype=int64,n_input=1000,n_classes=1000    0.002033  0.001997  0.982178
dtype=int64,n_input=10000,n_classes=2      0.001486  0.007541  5.074603
dtype=int64,n_input=10000,n_classes=10     0.001628  0.007613  4.676479
dtype=int64,n_input=10000,n_classes=100    0.002272  0.008311  3.657329
dtype=int64,n_input=10000,n_classes=1000   0.003875  0.010727  2.768459
dtype=int64,n_input=200000,n_classes=2     0.026398  0.144496  5.473749
dtype=int64,n_input=200000,n_classes=10    0.026786  0.144446  5.392667
dtype=int64,n_input=200000,n_classes=100   0.038138  0.158187  4.147780
dtype=int64,n_input=200000,n_classes=1000  0.051936  0.178297  3.433009

-------
(('weightkw', None), ('labelkw', 'int-consec'))
                                                 v2        v1    speedup
dtype=uint8,n_input=10,n_classes=2         0.000279  0.000265   0.952096
dtype=uint8,n_input=10,n_classes=10        0.000280  0.000266   0.950638
dtype=uint8,n_input=10,n_classes=100       0.000493  0.000495   1.003868
dtype=uint8,n_input=10,n_classes=1000      0.003311  0.003505   1.058392
dtype=uint8,n_input=1000,n_classes=2       0.000353  0.000976   2.762483
dtype=uint8,n_input=1000,n_classes=10      0.000359  0.000969   2.700997
dtype=uint8,n_input=1000,n_classes=100     0.000724  0.001354   1.868664
dtype=uint8,n_input=1000,n_classes=1000    0.006196  0.006965   1.124216
dtype=uint8,n_input=10000,n_classes=2      0.000924  0.006798   7.357677
dtype=uint8,n_input=10000,n_classes=10     0.000843  0.006754   8.014144
dtype=uint8,n_input=10000,n_classes=100    0.001817  0.007756   4.269852
dtype=uint8,n_input=10000,n_classes=1000   0.027496  0.033706   1.225847
dtype=uint8,n_input=200000,n_classes=2     0.014402  0.133362   9.260177
dtype=uint8,n_input=200000,n_classes=10    0.011453  0.132465  11.565603
dtype=uint8,n_input=200000,n_classes=100   0.025776  0.145841   5.658031
dtype=uint8,n_input=200000,n_classes=1000  0.482134  0.605294   1.255448
dtype=int64,n_input=10,n_classes=2         0.000286  0.000257   0.898333
dtype=int64,n_input=10,n_classes=10        0.000274  0.000265   0.968668
dtype=int64,n_input=10,n_classes=100       0.000492  0.000498   1.012112
dtype=int64,n_input=10,n_classes=1000      0.003166  0.003300   1.042401
dtype=int64,n_input=1000,n_classes=2       0.000341  0.000924   2.710987
dtype=int64,n_input=1000,n_classes=10      0.000343  0.000927   2.706333
dtype=int64,n_input=1000,n_classes=100     0.000684  0.001290   1.886641
dtype=int64,n_input=1000,n_classes=1000    0.005632  0.006474   1.149382
dtype=int64,n_input=10000,n_classes=2      0.000958  0.006848   7.144776
dtype=int64,n_input=10000,n_classes=10     0.000858  0.006713   7.820833
dtype=int64,n_input=10000,n_classes=100    0.001809  0.007960   4.399262
dtype=int64,n_input=10000,n_classes=1000   0.012721  0.019166   1.506682
dtype=int64,n_input=200000,n_classes=2     0.014351  0.133434   9.297638
dtype=int64,n_input=200000,n_classes=10    0.011567  0.129226  11.171840
dtype=int64,n_input=200000,n_classes=100   0.025129  0.142739   5.680215
dtype=int64,n_input=200000,n_classes=1000  0.137428  0.261217   1.900759

-------
(('weightkw', None), ('labelkw', 'int-nonconsec'))
                                                 v2        v1   speedup
dtype=uint8,n_input=10,n_classes=2         0.000288  0.000260  0.900826
dtype=uint8,n_input=10,n_classes=10        0.000295  0.000264  0.895631
dtype=uint8,n_input=10,n_classes=100       0.000530  0.000500  0.942870
dtype=uint8,n_input=10,n_classes=1000      0.003524  0.003502  0.993978
dtype=uint8,n_input=1000,n_classes=2       0.000999  0.000989  0.989974
dtype=uint8,n_input=1000,n_classes=10      0.001006  0.000982  0.976766
dtype=uint8,n_input=1000,n_classes=100     0.001353  0.001415  1.045647
dtype=uint8,n_input=1000,n_classes=1000    0.007011  0.006872  0.980173
dtype=uint8,n_input=10000,n_classes=2      0.006899  0.006869  0.995646
dtype=uint8,n_input=10000,n_classes=10     0.006705  0.006732  1.004053
dtype=uint8,n_input=10000,n_classes=100    0.007770  0.007725  0.994231
dtype=uint8,n_input=10000,n_classes=1000   0.034355  0.034284  0.997925
dtype=uint8,n_input=200000,n_classes=2     0.132574  0.133330  1.005708
dtype=uint8,n_input=200000,n_classes=10    0.128076  0.129445  1.010691
dtype=uint8,n_input=200000,n_classes=100   0.144267  0.145128  1.005969
dtype=uint8,n_input=200000,n_classes=1000  0.610469  0.619754  1.015209
dtype=int64,n_input=10,n_classes=2         0.000297  0.000266  0.894779
dtype=int64,n_input=10,n_classes=10        0.000304  0.000300  0.985110
dtype=int64,n_input=10,n_classes=100       0.000560  0.000527  0.940400
dtype=int64,n_input=10,n_classes=1000      0.003448  0.003413  0.989629
dtype=int64,n_input=1000,n_classes=2       0.000974  0.000950  0.976004
dtype=int64,n_input=1000,n_classes=10      0.000950  0.000926  0.974912
dtype=int64,n_input=1000,n_classes=100     0.001317  0.001291  0.980087
dtype=int64,n_input=1000,n_classes=1000    0.006543  0.006444  0.984914
dtype=int64,n_input=10000,n_classes=2      0.006840  0.007099  1.037782
dtype=int64,n_input=10000,n_classes=10     0.006926  0.006887  0.994423
dtype=int64,n_input=10000,n_classes=100    0.007938  0.007927  0.998678
dtype=int64,n_input=10000,n_classes=1000   0.019603  0.019643  1.002055
dtype=int64,n_input=200000,n_classes=2     0.135128  0.133110  0.985068
dtype=int64,n_input=200000,n_classes=10    0.131158  0.133116  1.014933
dtype=int64,n_input=200000,n_classes=100   0.146218  0.146753  1.003659
dtype=int64,n_input=200000,n_classes=1000  0.271275  0.274917  1.013426

-------
(('weightkw', None), ('labelkw', 'str'))
                                                 v2        v1   speedup
dtype=uint8,n_input=10,n_classes=2         0.000287  0.000279  0.970930
dtype=uint8,n_input=10,n_classes=10        0.000281  0.000277  0.983885
dtype=uint8,n_input=10,n_classes=100       0.000489  0.000487  0.995612
dtype=uint8,n_input=10,n_classes=1000      0.003150  0.003116  0.989027
dtype=uint8,n_input=1000,n_classes=2       0.001443  0.001503  1.041467
dtype=uint8,n_input=1000,n_classes=10      0.001517  0.001523  1.003930
dtype=uint8,n_input=1000,n_classes=100     0.002575  0.002537  0.985372
dtype=uint8,n_input=1000,n_classes=1000    0.014328  0.014250  0.994559
dtype=uint8,n_input=10000,n_classes=2      0.011665  0.011521  0.987634
dtype=uint8,n_input=10000,n_classes=10     0.011535  0.011918  1.033194
dtype=uint8,n_input=10000,n_classes=100    0.020836  0.020827  0.999565
dtype=uint8,n_input=10000,n_classes=1000   0.118630  0.120804  1.018325
dtype=uint8,n_input=200000,n_classes=2     0.234595  0.230784  0.983754
dtype=uint8,n_input=200000,n_classes=10    0.225771  0.227541  1.007840
dtype=uint8,n_input=200000,n_classes=100   0.401081  0.408385  1.018211
dtype=uint8,n_input=200000,n_classes=1000  2.316457  2.337287  1.008992
dtype=int64,n_input=10,n_classes=2         0.000294  0.000286  0.974838
dtype=int64,n_input=10,n_classes=10        0.000291  0.000294  1.008183
dtype=int64,n_input=10,n_classes=100       0.000503  0.000493  0.978683
dtype=int64,n_input=10,n_classes=1000      0.003090  0.003484  1.127536
dtype=int64,n_input=1000,n_classes=2       0.001447  0.001503  1.039051
dtype=int64,n_input=1000,n_classes=10      0.001484  0.001492  1.004979
dtype=int64,n_input=1000,n_classes=100     0.002620  0.002817  1.075446
dtype=int64,n_input=1000,n_classes=1000    0.015361  0.014929  0.971830
dtype=int64,n_input=10000,n_classes=2      0.011936  0.011799  0.988494
dtype=int64,n_input=10000,n_classes=10     0.012654  0.012983  1.026058
dtype=int64,n_input=10000,n_classes=100    0.026321  0.020844  0.791924
dtype=int64,n_input=10000,n_classes=1000   0.105764  0.103030  0.974155
dtype=int64,n_input=200000,n_classes=2     0.230603  0.235212  1.019986
dtype=int64,n_input=200000,n_classes=10    0.233359  0.237916  1.019529
dtype=int64,n_input=200000,n_classes=100   0.426613  0.426941  1.000768
dtype=int64,n_input=200000,n_classes=1000  1.987660  2.002712  1.007573

-------
(('weightkw', 'float'), ('labelkw', None))
                                                 v2        v1   speedup
dtype=uint8,n_input=10,n_classes=2         0.000369  0.000329  0.893273
dtype=uint8,n_input=10,n_classes=10        0.000346  0.000311  0.900000
dtype=uint8,n_input=10,n_classes=100       0.000352  0.000319  0.907797
dtype=uint8,n_input=10,n_classes=1000      0.000354  0.000328  0.924630
dtype=uint8,n_input=1000,n_classes=2       0.000464  0.000953  2.055527
dtype=uint8,n_input=1000,n_classes=10      0.000462  0.000929  2.009283
dtype=uint8,n_input=1000,n_classes=100     0.000587  0.001075  1.832588
dtype=uint8,n_input=1000,n_classes=1000    0.000695  0.001348  1.939602
dtype=uint8,n_input=10000,n_classes=2      0.001429  0.006030  4.218849
dtype=uint8,n_input=10000,n_classes=10     0.001570  0.006187  3.940623
dtype=uint8,n_input=10000,n_classes=100    0.002171  0.006845  3.152756
dtype=uint8,n_input=10000,n_classes=1000   0.002447  0.008451  3.454297
dtype=uint8,n_input=200000,n_classes=2     0.023974  0.121703  5.076527
dtype=uint8,n_input=200000,n_classes=10    0.028682  0.126143  4.398005
dtype=uint8,n_input=200000,n_classes=100   0.037272  0.131645  3.531999
dtype=uint8,n_input=200000,n_classes=1000  0.041043  0.161512  3.935235
dtype=int64,n_input=10,n_classes=2         0.000336  0.000314  0.933996
dtype=int64,n_input=10,n_classes=10        0.000335  0.000303  0.903983
dtype=int64,n_input=10,n_classes=100       0.000323  0.000311  0.963072
dtype=int64,n_input=10,n_classes=1000      0.000325  0.000314  0.966226
dtype=int64,n_input=1000,n_classes=2       0.000434  0.001021  2.353491
dtype=int64,n_input=1000,n_classes=10      0.000446  0.001031  2.312133
dtype=int64,n_input=1000,n_classes=100     0.000554  0.001157  2.088640
dtype=int64,n_input=1000,n_classes=1000    0.002076  0.002008  0.967160
dtype=int64,n_input=10000,n_classes=2      0.001611  0.007474  4.639485
dtype=int64,n_input=10000,n_classes=10     0.001580  0.007472  4.728425
dtype=int64,n_input=10000,n_classes=100    0.002220  0.008608  3.877981
dtype=int64,n_input=10000,n_classes=1000   0.003958  0.010562  2.668634
dtype=int64,n_input=200000,n_classes=2     0.025976  0.146988  5.658626
dtype=int64,n_input=200000,n_classes=10    0.026449  0.147190  5.564951
dtype=int64,n_input=200000,n_classes=100   0.037667  0.159755  4.241254
dtype=int64,n_input=200000,n_classes=1000  0.051015  0.183432  3.595628

-------
(('weightkw', 'float'), ('labelkw', 'int-consec'))
                                                 v2        v1    speedup
dtype=uint8,n_input=10,n_classes=2         0.000273  0.000269   0.984279
dtype=uint8,n_input=10,n_classes=10        0.000284  0.000274   0.965575
dtype=uint8,n_input=10,n_classes=100       0.000488  0.000496   1.016113
dtype=uint8,n_input=10,n_classes=1000      0.003274  0.003428   1.046967
dtype=uint8,n_input=1000,n_classes=2       0.000333  0.000940   2.822477
dtype=uint8,n_input=1000,n_classes=10      0.000335  0.000959   2.864672
dtype=uint8,n_input=1000,n_classes=100     0.000694  0.001288   1.857045
dtype=uint8,n_input=1000,n_classes=1000    0.006042  0.006722   1.112497
dtype=uint8,n_input=10000,n_classes=2      0.000894  0.006819   7.631003
dtype=uint8,n_input=10000,n_classes=10     0.000824  0.006673   8.096037
dtype=uint8,n_input=10000,n_classes=100    0.001801  0.007721   4.285866
dtype=uint8,n_input=10000,n_classes=1000   0.027313  0.033189   1.215127
dtype=uint8,n_input=200000,n_classes=2     0.014398  0.134802   9.362771
dtype=uint8,n_input=200000,n_classes=10    0.011958  0.130961  10.951889
dtype=uint8,n_input=200000,n_classes=100   0.025424  0.143794   5.655786
dtype=uint8,n_input=200000,n_classes=1000  0.472963  0.596848   1.261933
dtype=int64,n_input=10,n_classes=2         0.000279  0.000267   0.955594
dtype=int64,n_input=10,n_classes=10        0.000287  0.000262   0.914309
dtype=int64,n_input=10,n_classes=100       0.000486  0.000525   1.079412
dtype=int64,n_input=10,n_classes=1000      0.003190  0.003453   1.082667
dtype=int64,n_input=1000,n_classes=2       0.000388  0.000969   2.494168
dtype=int64,n_input=1000,n_classes=10      0.000356  0.000921   2.583946
dtype=int64,n_input=1000,n_classes=100     0.000763  0.001399   1.832605
dtype=int64,n_input=1000,n_classes=1000    0.005826  0.006512   1.117731
dtype=int64,n_input=10000,n_classes=2      0.000895  0.006837   7.638785
dtype=int64,n_input=10000,n_classes=10     0.000831  0.006705   8.071757
dtype=int64,n_input=10000,n_classes=100    0.001804  0.007694   4.264006
dtype=int64,n_input=10000,n_classes=1000   0.012456  0.019045   1.529009
dtype=int64,n_input=200000,n_classes=2     0.014494  0.134389   9.271906
dtype=int64,n_input=200000,n_classes=10    0.011565  0.130495  11.283903
dtype=int64,n_input=200000,n_classes=100   0.026274  0.151295   5.758298
dtype=int64,n_input=200000,n_classes=1000  0.142133  0.267350   1.880991

-------
(('weightkw', 'float'), ('labelkw', 'int-nonconsec'))
                                                 v2        v1   speedup
dtype=uint8,n_input=10,n_classes=2         0.000291  0.000259  0.890074
dtype=uint8,n_input=10,n_classes=10        0.000288  0.000258  0.896523
dtype=uint8,n_input=10,n_classes=100       0.000521  0.000488  0.936871
dtype=uint8,n_input=10,n_classes=1000      0.003398  0.003444  1.013330
dtype=uint8,n_input=1000,n_classes=2       0.000991  0.000941  0.950181
dtype=uint8,n_input=1000,n_classes=10      0.000952  0.000925  0.972187
dtype=uint8,n_input=1000,n_classes=100     0.001305  0.001281  0.982273
dtype=uint8,n_input=1000,n_classes=1000    0.006898  0.006804  0.986485
dtype=uint8,n_input=10000,n_classes=2      0.006831  0.006817  0.998045
dtype=uint8,n_input=10000,n_classes=10     0.006815  0.006732  0.987826
dtype=uint8,n_input=10000,n_classes=100    0.007706  0.007681  0.996813
dtype=uint8,n_input=10000,n_classes=1000   0.033726  0.033710  0.999512
dtype=uint8,n_input=200000,n_classes=2     0.138235  0.136123  0.984721
dtype=uint8,n_input=200000,n_classes=10    0.132586  0.132749  1.001235
dtype=uint8,n_input=200000,n_classes=100   0.142736  0.143708  1.006805
dtype=uint8,n_input=200000,n_classes=1000  0.602438  0.603199  1.001262
dtype=int64,n_input=10,n_classes=2         0.000282  0.000252  0.894157
dtype=int64,n_input=10,n_classes=10        0.000290  0.000260  0.897119
dtype=int64,n_input=10,n_classes=100       0.000520  0.000487  0.935868
dtype=int64,n_input=10,n_classes=1000      0.003401  0.003377  0.992780
dtype=int64,n_input=1000,n_classes=2       0.000947  0.000918  0.968789
dtype=int64,n_input=1000,n_classes=10      0.000947  0.000969  1.022899
dtype=int64,n_input=1000,n_classes=100     0.001333  0.001282  0.961724
dtype=int64,n_input=1000,n_classes=1000    0.006523  0.006410  0.982711
dtype=int64,n_input=10000,n_classes=2      0.006819  0.006831  1.001713
dtype=int64,n_input=10000,n_classes=10     0.007017  0.006714  0.956847
dtype=int64,n_input=10000,n_classes=100    0.007693  0.007715  1.002882
dtype=int64,n_input=10000,n_classes=1000   0.019556  0.019140  0.978702
dtype=int64,n_input=200000,n_classes=2     0.136135  0.135994  0.998967
dtype=int64,n_input=200000,n_classes=10    0.130219  0.132589  1.018201
dtype=int64,n_input=200000,n_classes=100   0.143348  0.145672  1.016208
dtype=int64,n_input=200000,n_classes=1000  0.272905  0.266272  0.975695

-------
(('weightkw', 'float'), ('labelkw', 'str'))
                                                 v2        v1   speedup
dtype=uint8,n_input=10,n_classes=2         0.000267  0.000264  0.990161
dtype=uint8,n_input=10,n_classes=10        0.000276  0.000262  0.951557
dtype=uint8,n_input=10,n_classes=100       0.000468  0.000464  0.991849
dtype=uint8,n_input=10,n_classes=1000      0.003014  0.003025  1.003718
dtype=uint8,n_input=1000,n_classes=2       0.001404  0.001431  1.019022
dtype=uint8,n_input=1000,n_classes=10      0.001402  0.001392  0.993028
dtype=uint8,n_input=1000,n_classes=100     0.002476  0.002498  1.009150
dtype=uint8,n_input=1000,n_classes=1000    0.014284  0.014313  1.002070
dtype=uint8,n_input=10000,n_classes=2      0.011901  0.011628  0.977101
dtype=uint8,n_input=10000,n_classes=10     0.011538  0.011512  0.997706
dtype=uint8,n_input=10000,n_classes=100    0.020172  0.020309  1.006784
dtype=uint8,n_input=10000,n_classes=1000   0.117401  0.118236  1.007114
dtype=uint8,n_input=200000,n_classes=2     0.226870  0.232280  1.023845
dtype=uint8,n_input=200000,n_classes=10    0.235242  0.239202  1.016834
dtype=uint8,n_input=200000,n_classes=100   0.421127  0.409173  0.971614
dtype=uint8,n_input=200000,n_classes=1000  2.357309  2.323732  0.985756
dtype=int64,n_input=10,n_classes=2         0.000265  0.000258  0.974797
dtype=int64,n_input=10,n_classes=10        0.000268  0.000265  0.989305
dtype=int64,n_input=10,n_classes=100       0.000464  0.000463  0.996917
dtype=int64,n_input=10,n_classes=1000      0.002620  0.002903  1.107634
dtype=int64,n_input=1000,n_classes=2       0.001360  0.001370  1.007365
dtype=int64,n_input=1000,n_classes=10      0.001436  0.001411  0.982896
dtype=int64,n_input=1000,n_classes=100     0.002463  0.002455  0.996805
dtype=int64,n_input=1000,n_classes=1000    0.014024  0.013908  0.991788
dtype=int64,n_input=10000,n_classes=2      0.011310  0.011375  1.005797
dtype=int64,n_input=10000,n_classes=10     0.011390  0.011333  0.994997
dtype=int64,n_input=10000,n_classes=100    0.020109  0.020106  0.999822
dtype=int64,n_input=10000,n_classes=1000   0.101039  0.101266  1.002246
dtype=int64,n_input=200000,n_classes=2     0.230488  0.227511  0.987084
dtype=int64,n_input=200000,n_classes=10    0.225168  0.225930  1.003383
dtype=int64,n_input=200000,n_classes=100   0.399346  0.406895  1.018904
dtype=int64,n_input=200000,n_classes=1000  1.938386  1.936220  0.998882

@jnothman
Member

I think a key idea in Pandas is to abstract the mapping behind a polymorphic OO interface, such that if the data is numeric and canonical, it can handle that in one way, and if strings, it can handle that another way. And then to use low-level implementations wherever it helps.
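For illustration, this is the kind of mapping pandas exposes via `pd.factorize` (a hashtable-backed conversion of arbitrary labels to integer codes; not something this PR uses):

    import numpy as np
    import pandas as pd

    y = np.array(['cat', 'dog', 'cat', 'bird'])
    codes, uniques = pd.factorize(y)  # codes in order of first appearance
    print(codes)    # [0 1 0 2]
    print(uniques)  # ['cat' 'dog' 'bird']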

@Erotemic Erotemic changed the title from "[WIP] speedup confusion_matrix" to "[MRG] speedup confusion_matrix" on Sep 28, 2017
@Erotemic
Contributor Author

It would be interesting to dig into their implementation to see how they do it, but unfortunately I'm not going to be able to allocate any time for that.

However, I think this PR is pretty much complete. The speedup seems to apply to many common cases and the overhead for the other cases is minimal. There may be a small additional speedup to be gained by optionally disabling checks, but this is good enough for now.

I've added a what's new blurb, and I don't think there is any new documentation needed. I noticed that other what's new entries referenced an issue. Should I make an issue for this PR to address (or should I reference the PR number)?

@jnothman
Member

I realise it's unclear because of the :issue: macro, but what's new should ideally reference the PR not the issue, or multiple if need be.

Member

@jnothman jnothman left a comment

Other than wondering if we could/should share this kind of optimisation with other code, this LGTM.

sample_weight = sample_weight[ind]
# If labels are not consecutive integers starting from zero, then
# yt, yp must be converted into index form
need_index_conversion = not (
Member

Surely np.all(labels == np.arange(len(labels))) is just as fast for a reasonable number of classes and much more readable?

Contributor Author

@Erotemic Erotemic Sep 30, 2017

It is both faster and more readable. When I originally wrote this I didn't consider that the ordering of the labels was significant, so the np.diff was a quick fix after realizing that. Your solution is simpler and better.

In [1]: labels = np.arange(100)

In [2]: %timeit labels.min() == 0 and np.all(np.diff(labels) == 1) 
8.3 µs ± 162 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [3]: %timeit np.all(np.arange(len(labels)) == labels)
3.6 µs ± 47.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@Erotemic
Contributor Author

Ah, that is confusing. I added the PR number.

@lesteve
Member

lesteve commented Sep 30, 2017

Quickly browsing through your extensive benchmarks (thanks a lot for these!), it seems to me that when this PR improves things, v1 (which I am guessing is master) confusion_matrix takes no more than a few tens to hundreds of milliseconds.

Very naively I just think: why do we care? Maybe you can elaborate on when you would want to compute hundreds of confusion matrices, as you mention in your first post? Is confusion_matrix really the bottleneck in your use case (I would think fitting and predicting is more likely to be)?

@Erotemic
Contributor Author

So, the speed increase does seem to be linear in terms of the number of items being classified. It is true that the speedup is overshadowed when the number of classes is large, but in modern deep learning tasks the number of items greatly outnumbers the number of classes.

I noticed the slowness of this function when running evaluation scripts on a large number of pre-classified (for a semantic segmentation task) images. The learning and prediction was already precomputed. In the larger scheme of things the learning and prediction is the bottleneck, but when simply running the evaluation scripts, the confusion matrix was the bottleneck. Because the amount of data is extremely large, making this change caused the running time of my scripts to improve from minutes to seconds. I think this instance justifies this sort of optimization.

@lesteve
Member

lesteve commented Oct 2, 2017

Your use case seems completely reasonable, thanks for the details!

@cmarmo
Contributor

cmarmo commented Dec 15, 2020

Hi @Erotemic, many thanks for your patience! If you are still interested in working on this, do you mind synchronizing with upstream? This will help give your PR more visibility again. Thanks!

@Erotemic
Contributor Author

Wow, gotta brush the dust off this one. But yeah, I can do that.

@Erotemic Erotemic force-pushed the speedup_confusion_matrix branch from dfc9332 to 2372c75 on December 15, 2020 20:09
@Erotemic
Contributor Author

Erotemic commented Dec 15, 2020

@cmarmo I rebased this on master.

@cmarmo
Contributor

cmarmo commented Dec 15, 2020

@jnothman @lesteve, sorry for bothering you; I was under the impression that this PR was worth a review, and you have already commented on it... Thanks!

@cmarmo
Contributor

cmarmo commented Dec 15, 2020

@Erotemic thanks! I believe that an entry in the what's new file was also planned in the PR description? scikit-learn is heading for 1.0 now! :)

@Erotemic
Contributor Author

@cmarmo I re-added it. Let me know if I put it in the wrong place / got the formatting wrong.

Member

@glemaitre glemaitre left a comment

LGTM. Only a couple of nitpicks.

# convert yt, yp into index
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
# If labels are not consecutive integers starting from zero, then
Member

Since we are improving confusion_matrix, could you use sklearn.utils.validation._check_sample_weight in the lines above (312-315)?

Contributor Author

Hmm, it looks like _check_sample_weight might hurt efficiency in the case where sample_weight is None. Using that code will force us to use the same dtype in both (1) the case where sample_weight is None and (2) the case where sample_weight is specified. And if we are forced to specify a dtype, then it would have to be float64. At least IMO I'd prefer the ability to have an integral confusion matrix, especially in the case where the samples are not weighted.
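A minimal illustration of the dtype concern (relying on the behavior of the private helper _check_sample_weight as it currently stands):

    import numpy as np
    from sklearn.utils.validation import _check_sample_weight

    y_true = np.array([0, 1, 1, 0])

    # current confusion_matrix code: integer weights when sample_weight is None,
    # which keeps the resulting confusion matrix integral
    sw_old = np.ones(y_true.shape[0], dtype=np.int64)

    # _check_sample_weight defaults to float64 when sample_weight is None
    sw_new = _check_sample_weight(None, y_true)

    print(sw_old.dtype, sw_new.dtype)  # int64 float64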

I haven't benchmarked yet, but adding this might work against the purpose of this PR. (In general my opinion is that sklearn does far too many of these checks in a way that drastically reduces performance and does not offer a nice way to disable them).

It's possible that just forcing float64 isn't a terrible thing to do, but I'll need to benchmark it. As modifying this line is not a clear improvement, I would prefer if we could merge this (it's been >3 years) and look at this issue in a separate PR.

Contributor Author

@glemaitre So this change might have merit. I was worried that using float64 would impact how long it takes to build the coo_matrix, but it doesn't seem to:

    from sklearn.utils.validation import _check_sample_weight
    from scipy.sparse import coo_matrix
    import numpy as np
    import ubelt as ub
    import timerit
    results = []
    # ub, ti, and coo_matrix were implicit in the original snippet; defined here
    ti = timerit.Timerit(9, bestof=3, verbose=2)
    ns = np.logspace(1, 6, 100).astype(np.int)
    for n in ub.ProgIter(ns, desc='time-tradeoff', verbose=3):
        print('n = {!r}'.format(n))

        n_labels = 100
        y_true = np.random.randint(0, n_labels, n).astype(np.int64)
        y_pred = np.random.randint(0, n_labels, n).astype(np.int64)

        sample_weight = np.ones(y_true.shape[0], dtype=np.int64)

        for timer in ti.reset('use-old-uint8-sample-weight-default'):
            with timer:
                if sample_weight.dtype.kind in {'i', 'u', 'b'}:
                    dtype = np.int64
                else:
                    dtype = np.float64
                cm = coo_matrix((sample_weight, (y_true, y_pred)),
                                shape=(n_labels, n_labels), dtype=dtype,
                                ).toarray()
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

        sample_weight = _check_sample_weight(None, y_true, dtype=np.int64)
        for timer in ti.reset('use-new-float64-sample-weight-default'):
            with timer:
                if sample_weight.dtype.kind in {'i', 'u', 'b'}:
                    dtype = np.int64
                else:
                    dtype = np.float64
                cm = coo_matrix((sample_weight, (y_true, y_pred)),
                                shape=(n_labels, n_labels), dtype=dtype,
                                ).toarray()
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })
    import pandas as pd
    df = pd.DataFrame(results)

    import kwplot
    import seaborn as sns
    kwplot.autoplt()
    sns.set()
    ax = sns.lineplot(data=df, x='n', y='time', hue='label')
    ax.set_yscale('log')
    ax.set_xscale('log')

[figure: log-log lineplot of coo_matrix construction time vs n for the old int64 and new float64 sample_weight defaults]

However, the amount of time it takes to actually call _check_sample_weight is much greater than with the code as it currently is:

    import ubelt as ub

    from sklearn.utils.validation import _check_sample_weight
    import numpy as np
    results = []
    ns = np.logspace(1, 6, 100).astype(np.int)
    for n in ub.ProgIter(ns, desc='time-tradeoff', verbose=3):
        print('n = {!r}'.format(n))

        y_true = np.random.randint(0, 100, n).astype(np.int64)
        y_pred = np.random.randint(0, 100, n).astype(np.int64)
        sample_weight = np.random.rand(n)

        import timerit
        ti = timerit.Timerit(9, bestof=3, verbose=2)
        for timer in ti.reset('old-sample-weight-given'):
            with timer:
                np.asarray(sample_weight)
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

        for timer in ti.reset('new-sample-weight-given'):
            with timer:
                _check_sample_weight(sample_weight, y_true, dtype=np.int64)
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

        for timer in ti.reset('old-sample-weight-default'):
            with timer:
                np.ones(y_true.shape[0], dtype=np.int64)
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

        for timer in ti.reset('new-sample-weight-default'):
            with timer:
                _check_sample_weight(None, y_true, dtype=np.int64)
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

    import pandas as pd
    df = pd.DataFrame(results)

    import kwplot
    import seaborn as sns
    kwplot.autoplt()
    sns.set()
    ax = sns.lineplot(data=df, x='n', y='time', hue='label')
    ax.set_yscale('log')
    ax.set_xscale('log')

[figure: log-log lineplot of time vs n for the old/new sample_weight handling in the given and default cases]

But they do seem to converge towards each other as the number of items grows large. Perhaps it's reasonable to make this replacement.

I just ran all my benchmarks using _check_sample_weight and the speed isn't much different, so it makes sense to make this change for readability and API consistency. I'll patch that in.

Contributor Author

@Erotemic Erotemic Mar 11, 2021

I made a mistake: it is noticeably different, but it's not that significant. The old way averages out at about 0.1% of the function computation and the new way is 3x slower, taking about 0.3%. However, this computation doesn't grow as fast as other parts of the code, so it will never be a bottleneck; I'm still for replacing the 4 lines with 1 function call.

It would be nice in the future if _check_sample_weight could return a weight array with dtype uint8. While it may not have much speed benefit, it would be more memory efficient.

Contributor Author

Ok, I was wrong again. It is causing a noticeable slowdown to the point where I'm not ok with it anymore. In my previous test, I ran one variant right after the other, so it never hit all of its cases.

With the original way we have:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    82       432       1031.0      2.4      0.1      if sample_weight is None:
    83       216       3294.0     15.2      0.2          sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
    84                                               else:
    85       216        843.0      3.9      0.0          sample_weight = np.asarray(sample_weight)

But with the new way we have:

   201       432      40067.0     92.7      3.9      sample_weight = _check_sample_weight(sample_weight, y_true, dtype=np.uint8)

That's going from 0.3% to 3.9%. That's too big of a jump for my taste, so I'm going to revert to the previous method.

Member

@ogrisel ogrisel Mar 11, 2021

Are you sure this has any meaningful impact in practice?

Looking at the code of _check_sample_weight, it should be efficient enough. Try without the profiler enabled, on a call to confusion_matrix with a dataset that is large enough for performance to matter (e.g. at least one second).
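The timing script itself is not included in the thread; a minimal stand-in might look like this (all sizes and names here are made up for illustration):

    from time import perf_counter
    import numpy as np
    from sklearn.metrics import confusion_matrix

    rng = np.random.RandomState(0)
    n, n_labels = 5_000_000, 100
    y_true = rng.randint(0, n_labels, n)
    y_pred = rng.randint(0, n_labels, n)

    times = []
    for _ in range(5):
        t0 = perf_counter()
        confusion_matrix(y_true, y_pred, labels=np.arange(n_labels))
        times.append(perf_counter() - t0)
    print('{:.3f} +/- {:.3f} s'.format(np.mean(times), np.std(times)))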

Member

Here are some results on my machine:

(dev) ogrisel@mba scikit-learn % git diff
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index b4ab145d8..35c5aab5f 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -34,6 +34,7 @@ from ..utils import assert_all_finite
 from ..utils import check_array
 from ..utils import check_consistent_length
 from ..utils import column_or_1d
+from ..utils.validation import _check_sample_weight
 from ..utils.multiclass import unique_labels
 from ..utils.multiclass import type_of_target
 from ..utils.validation import _num_samples
@@ -312,12 +313,9 @@ def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None,
         elif len(np.intersect1d(y_true, labels)) == 0:
             raise ValueError("At least one label specified must be in y_true")
 
-    if sample_weight is None:
-        sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
-    else:
-        sample_weight = np.asarray(sample_weight)
-
-    check_consistent_length(y_true, y_pred, sample_weight)
+    check_consistent_length(y_true, y_pred)
+    sample_weight = _check_sample_weight(sample_weight, y_true,
+                                         dtype=np.float64)
 
     if normalize not in ['true', 'pred', 'all', None]:
         raise ValueError("normalize must be one of {'true', 'pred', "
(dev) ogrisel@mba scikit-learn % python  ~/tmp/time_confusion_matrix.py
1.399 +/- 0.005 s
(dev) ogrisel@mba scikit-learn % git stash 
Saved working directory and index state WIP on speedup_confusion_matrix: 4b8e1e864 Use faster intersect1d instead of list comprehension
(dev) ogrisel@mba scikit-learn % python  ~/tmp/time_confusion_matrix.py
1.408 +/- 0.005 s

So no significant difference. It's in the noise.

Member

Actually thinking about it, it's fine to keep the code the way it is.

If no sample_weight is passed, we would like to have int64 to allow for precise integer outputs.

If the user is passing float sample_weight, we do not want to cast it to integer because that would be wrong (and the output will be floating point, but that's fine).

Using _check_sample_weight will break one or the other of the above statements. I therefore think the current code is fine. We could probably add a test to check this behavior but that is out of the scope of this PR.

Member

@thomasjpfan thomasjpfan left a comment

Given how long this PR has been open, let's see if we can work toward merging it.

y_true = y_true[ind]
# also eliminate weights of eliminated items
sample_weight = sample_weight[ind]
if not np.all(ind):
Member

Does including this greatly improve performance?

I think we can revert this change and only keep need_index_conversion so we can get this merged quicker.

Contributor Author

I remember it being significant when I originally wrote this, which was a while ago. It will depend on the use case, but I believe it does give a noticeable speedup to the fastest case with a minimal impact on the slow cases.

I would have to benchmark again to be sure.

Contributor Author

So I was able to come up with a small benchmark that at least demonstrates the efficacy of the idea. In most cases the true and pred labels are all going to be "valid" -- i.e. between 0 and n_labels. In that case the question is: do we gain a benefit by checking whether we should do the __getitem__ operation before we do it? Or is the __getitem__ fast enough that the check isn't saving us time?

Furthermore, we have the less common case where there are invalid labels. In that case, does doing this extra check (which by definition will fail) cost us that much in overhead?

To test this I came up with the following code, which simply times the np.all check and the __getitem__ operation.

    import ubelt as ub
    import numpy as np
    import pandas as pd  # used for the DataFrame built at the end
    results = []
    ns = np.logspace(1, 7, 100).astype(np.int)
    for n in ub.ProgIter(ns, desc='time-tradeoff', verbose=3):
        print('n = {!r}'.format(n))
        y_true = np.random.randint(0, 100, n).astype(np.int64)
        y_pred = np.random.randint(0, 100, n).astype(np.int64)
        sample_weight = np.random.rand(n)

        isvalid = np.random.rand(n) > 0

        import timerit
        ti = timerit.Timerit(9, bestof=3, verbose=2)
        for timer in ti.reset('all-check'):
            with timer:
                np.all(isvalid)
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

        for timer in ti.reset('all-index'):
            with timer:
                y_true[isvalid]
                y_pred[isvalid]
                sample_weight[isvalid]
        results.append({
            'n': n,
            'label': ti.label,
            'time': ti.mean(),
        })

    df = pd.DataFrame(results)

    import kwplot
    import seaborn as sns
    kwplot.autoplt()
    sns.set()
    ax = sns.lineplot(data=df, x='n', y='time', hue='label')
    ax.set_yscale('log')
    ax.set_xscale('log')

which produces

[figure: log-log lineplot of time vs n for the 'all-check' and 'all-index' operations]

So in the case where there are more than 1000 elements (the case where we really start to care about speed!), we get an order of magnitude difference in how long each of those lines takes.

In the case where we must incur both costs, again the check is an order of magnitude faster than the actual __getitem__ call, so it doesn't add much overhead.

Contributor Author

@Erotemic Erotemic Mar 11, 2021

Now the question is, how much does it matter in the grand scheme of this function?

Well, running the benchmark suite we see on average this line takes about 0.8% of the total time (which in this case was 0.735286 s)

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   270                                               # intersect y_pred, y_true with labels, eliminate items not in labels
   271       324       4191.0     12.9      0.6      ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
   272       324       5657.0     17.5      0.8      if not np.all(ind):
   273                                                   y_pred = y_pred[ind]
   274                                                   y_true = y_true[ind]
   275                                                   # also eliminate weights of eliminated items
   276                                                   sample_weight = sample_weight[ind]

Now if we comment that portion of the code and reprofile we see:

   270                                               # intersect y_pred, y_true with labels, eliminate items not in labels
   271       324       4272.0     13.2      0.6      ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
   272                                               # if not np.all(ind):
   273       324       2541.0      7.8      0.4      y_pred = y_pred[ind]
   274       324       2061.0      6.4      0.3      y_true = y_true[ind]
   275                                               # also eliminate weights of eliminated items
   276       324       2182.0      6.7      0.3      sample_weight = sample_weight[ind]

which takes about 1.0% of the time on average. So there is benefit to having this check! But not much. However, that's only part of the story. The benchmark script is running over inputs of size 10, 1000, and 10000, and we know from above that this starts to make a difference when the input is larger. So how much?

Let's restrict input sizes to 10000 and 100000.

Now, with the check we see:

   270                                               # intersect y_pred, y_true with labels, eliminate items not in labels
   271       216      17120.0     79.3      0.4      ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
   272       216       5704.0     26.4      0.1      if not np.all(ind):
   273                                                   y_pred = y_pred[ind]
   274                                                   y_true = y_true[ind]
   275                                                   # also eliminate weights of eliminated items
   276                                                   sample_weight = sample_weight[ind]

and without we get

   270                                               # intersect y_pred, y_true with labels, eliminate items not in labels
   271       216      17222.0     79.7      0.4      ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
   272                                               # if not np.all(ind):
   273       216      19397.0     89.8      0.4      y_pred = y_pred[ind]
   274       216      18590.0     86.1      0.4      y_true = y_true[ind]
   275                                               # also eliminate weights of eliminated items
   276       216      22554.0    104.4      0.5      sample_weight = sample_weight[ind]

so that's 0.1% of the time versus 1.3% of the time. That is a very significant difference, and I believe it is enough to justify the existence of the check. Let me know what you think. @thomasjpfan

Member

@ogrisel ogrisel Mar 11, 2021

Honestly we don't really care about a 1% performance change of a function call (if that percentage stays constant with the dataset size). But as this specific code snippet is simple, I am fine with keeping it.

Base automatically changed from master to main January 22, 2021 10:49
@cmarmo
Contributor

cmarmo commented Mar 10, 2021

Hi @Erotemic will you be able to address the last review? Thanks!

@Erotemic
Contributor Author

@cmarmo I can do a deep dive and really address the concerns about which components cause a significant speedup and which don't. It might take me a week or two to get to that though. Let me know if that would be valuable.

@cmarmo
Contributor

cmarmo commented Mar 10, 2021

@Erotemic, I believe benchmarks are always very appreciated and surely improve the quality of the contribution, but your PR has already waited for so long; I guess this was also @thomasjpfan's concern... If you are motivated to dive in and you think this is doable in one or two weeks ...

@Erotemic
Contributor Author

While doing the benchmarks I noticed that the "At least one label specified must be in y_true" check is pretty slow as written:

elif np.all([l not in y_true for l in labels]):

Which took about 430 microseconds on average versus:

elif len(np.intersect1d(y_true, labels)) == 0:

which took 153 microseconds. That is actually reasonably significant, as it reduces that line from 10% of the average computation to about 4.7%. It also makes sense from a theoretical perspective: the latter is O(N + D) whereas the former is O(N * D). It seems like a clear win to make that upgrade.
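For reference, the two variants side by side (a toy reproduction; the timings quoted above come from the full benchmark run, not from this snippet):

    import numpy as np

    rng = np.random.RandomState(0)
    y_true = rng.randint(0, 100, 200000)
    labels = np.arange(100)

    # O(N * D): each `in` test scans y_true once per label
    no_overlap_slow = np.all([l not in y_true for l in labels])

    # single vectorized intersection instead of a Python-level loop
    no_overlap_fast = len(np.intersect1d(y_true, labels)) == 0

    assert bool(no_overlap_slow) == no_overlap_fast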

Benchmark with the list comprehension

Pystone time: 0.777921 s
File: /home/joncrall/misc/tests/python/bench_confusion_matrix.py
Function: confusion_matrix_2021_new at line 179

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   179                                           @xdev.profile
   180                                           def confusion_matrix_2021_new(y_true, y_pred, *, labels=None,
   181                                                                         sample_weight=None, normalize=None):
   182       324     143350.0    442.4     18.4      y_type, y_true, y_pred = _check_targets(y_true, y_pred)
   183       324        740.0      2.3      0.1      if y_type not in ("binary", "multiclass"):
   184                                                   raise ValueError("%s is not supported" % y_type)
   185                                           
   186       324        573.0      1.8      0.1      if labels is None:
   187       108      59921.0    554.8      7.7          labels = unique_labels(y_true, y_pred)
   188                                               else:
   189       216        766.0      3.5      0.1          labels = np.asarray(labels)
   190       216        427.0      2.0      0.1          n_labels = labels.size
   191       216        403.0      1.9      0.1          if n_labels == 0:
   192                                                       raise ValueError("'labels' should contains at least one label.")
   193       216        398.0      1.8      0.1          elif y_true.size == 0:
   194                                                       return np.zeros((n_labels, n_labels), dtype=int)
   195                                                   else:
   196                                                       # flag1 = len(np.intersect1d(y_true, labels)) == 0
   197       216      88756.0    410.9     11.4              flag2 = np.all([l not in y_true for l in labels])
   198       216        484.0      2.2      0.1              if flag2:
   199                                                           raise ValueError("At least one label specified must be in y_true")
   200                                           
   201       324        634.0      2.0      0.1      if sample_weight is None:
   202       162       2349.0     14.5      0.3          sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
   203                                               else:
   204       162        598.0      3.7      0.1          sample_weight = np.asarray(sample_weight)
   205                                           
   206       324      19000.0     58.6      2.4      check_consistent_length(y_true, y_pred, sample_weight)
   207                                           
   208       324        772.0      2.4      0.1      if normalize not in ['true', 'pred', 'all', None]:
   209                                                   raise ValueError("normalize must be one of {'true', 'pred', "
   210                                                                    "'all', None}")
   211                                           
   212       324        641.0      2.0      0.1      n_labels = labels.size
   213                                               # If labels are not consecutive integers starting from zero, then
   214                                               # y_true and y_pred must be converted into index form
   215       324        609.0      1.9      0.1      need_index_conversion = not (
   216       840       1849.0      2.2      0.2          labels.dtype.kind in {'i', 'u', 'b'} and
   217       324       7583.0     23.4      1.0          np.all(labels == np.arange(n_labels)) and
   218       384       5423.0     14.1      0.7          y_true.min() >= 0 and y_pred.min() >= 0
   219                                               )
   220       324        546.0      1.7      0.1      if need_index_conversion:
   221       132       1919.0     14.5      0.2          label_to_ind = {y: x for x, y in enumerate(labels)}
   222       132     171946.0   1302.6     22.1          y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
   223       132     169948.0   1287.5     21.8          y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
   224                                           
   225                                               # intersect y_pred, y_true with labels, eliminate items not in labels
   226       324       4107.0     12.7      0.5      ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
   227       324       5427.0     16.8      0.7      if not np.all(ind):
   228                                                   y_pred = y_pred[ind]
   229                                                   y_true = y_true[ind]
   230                                                   # also eliminate weights of eliminated items
   231                                                   sample_weight = sample_weight[ind]
   232                                           
   233                                               # Choose the accumulator dtype to always have high precision
   234       324        864.0      2.7      0.1      if sample_weight.dtype.kind in {'i', 'u', 'b'}:
   235       162        320.0      2.0      0.0          dtype = np.int64
   236                                               else:
   237       162        334.0      2.1      0.0          dtype = np.float64
   238                                           
   239       648      58290.0     90.0      7.5      cm = coo_matrix((sample_weight, (y_true, y_pred)),
   240       324        574.0      1.8      0.1                      shape=(n_labels, n_labels), dtype=dtype,
   241                                                               ).toarray()
   242                                           
   243       324       6118.0     18.9      0.8      with np.errstate(all='ignore'):
   244       324        678.0      2.1      0.1          if normalize == 'true':
   245                                                       cm = cm / cm.sum(axis=1, keepdims=True)
   246       324        585.0      1.8      0.1          elif normalize == 'pred':
   247                                                       cm = cm / cm.sum(axis=0, keepdims=True)
   248       324        557.0      1.7      0.1          elif normalize == 'all':
   249                                                       cm = cm / cm.sum()
   250       324      19772.0     61.0      2.5          cm = np.nan_to_num(cm)
   251                                           
   252       324        660.0      2.0      0.1      return cm

Versus the version with the intersect1d check:

Pystone time: 0.735806 s
File: /home/joncrall/misc/tests/python/bench_confusion_matrix.py
Function: confusion_matrix_2021_new at line 179

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   179                                           @xdev.profile
   180                                           def confusion_matrix_2021_new(y_true, y_pred, *, labels=None,
   181                                                                         sample_weight=None, normalize=None):
   182       324     150859.0    465.6     20.5      y_type, y_true, y_pred = _check_targets(y_true, y_pred)
   183       324        773.0      2.4      0.1      if y_type not in ("binary", "multiclass"):
   184                                                   raise ValueError("%s is not supported" % y_type)
   185                                           
   186       324        549.0      1.7      0.1      if labels is None:
   187       108      62184.0    575.8      8.5          labels = unique_labels(y_true, y_pred)
   188                                               else:
   189       216        802.0      3.7      0.1          labels = np.asarray(labels)
   190       216        437.0      2.0      0.1          n_labels = labels.size
   191       216        363.0      1.7      0.0          if n_labels == 0:
   192                                                       raise ValueError("'labels' should contains at least one label.")
   193       216        394.0      1.8      0.1          elif y_true.size == 0:
   194                                                       return np.zeros((n_labels, n_labels), dtype=int)
   195                                                   else:
   196       216      34387.0    159.2      4.7              flag1 = len(np.intersect1d(y_true, labels)) == 0
   197                                                       # flag2 = np.all([l not in y_true for l in labels])
   198       216        465.0      2.2      0.1              if flag1:
   199                                                           raise ValueError("At least one label specified must be in y_true")
   200                                           
   201       324        621.0      1.9      0.1      if sample_weight is None:
   202       162       2462.0     15.2      0.3          sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
   203                                               else:
   204       162        641.0      4.0      0.1          sample_weight = np.asarray(sample_weight)
   205                                           
   206       324      19952.0     61.6      2.7      check_consistent_length(y_true, y_pred, sample_weight)
   207                                           
   208       324        792.0      2.4      0.1      if normalize not in ['true', 'pred', 'all', None]:
   209                                                   raise ValueError("normalize must be one of {'true', 'pred', "
   210                                                                    "'all', None}")
   211                                           
   212       324        633.0      2.0      0.1      n_labels = labels.size
   213                                               # If labels are not consecutive integers starting from zero, then
   214                                               # y_true and y_pred must be converted into index form
   215       324        633.0      2.0      0.1      need_index_conversion = not (
   216       846       1806.0      2.1      0.2          labels.dtype.kind in {'i', 'u', 'b'} and
   217       324       8494.0     26.2      1.2          np.all(labels == np.arange(n_labels)) and
   218       396       6068.0     15.3      0.8          y_true.min() >= 0 and y_pred.min() >= 0
   219                                               )
   220       324        514.0      1.6      0.1      if need_index_conversion:
   221       126       2059.0     16.3      0.3          label_to_ind = {y: x for x, y in enumerate(labels)}
   222       126     168675.0   1338.7     22.9          y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
   223       126     165491.0   1313.4     22.5          y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
   224                                           
   225                                               # intersect y_pred, y_true with labels, eliminate items not in labels
   226       324       4633.0     14.3      0.6      ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
   227       324       6149.0     19.0      0.8      if not np.all(ind):
   228                                                   y_pred = y_pred[ind]
   229                                                   y_true = y_true[ind]
   230                                                   # also eliminate weights of eliminated items
   231                                                   sample_weight = sample_weight[ind]
   232                                           
   233                                               # Choose the accumulator dtype to always have high precision
   234       324        862.0      2.7      0.1      if sample_weight.dtype.kind in {'i', 'u', 'b'}:
   235       162        323.0      2.0      0.0          dtype = np.int64
   236                                               else:
   237       162        338.0      2.1      0.0          dtype = np.float64
   238                                           
   239       648      62685.0     96.7      8.5      cm = coo_matrix((sample_weight, (y_true, y_pred)),
   240       324        555.0      1.7      0.1                      shape=(n_labels, n_labels), dtype=dtype,
   241                                                               ).toarray()
   242                                           
   243       324       6615.0     20.4      0.9      with np.errstate(all='ignore'):
   244       324        667.0      2.1      0.1          if normalize == 'true':
   245                                                       cm = cm / cm.sum(axis=1, keepdims=True)
   246       324        559.0      1.7      0.1          elif normalize == 'pred':
   247                                                       cm = cm / cm.sum(axis=0, keepdims=True)
   248       324        533.0      1.6      0.1          elif normalize == 'all':
   249                                                       cm = cm / cm.sum()
   250       324      21223.0     65.5      2.9          cm = np.nan_to_num(cm)
   251                                           
   252       324        610.0      1.9      0.1      return cm

Hopefully it's not a big deal to make minor changes to this PR. Given that I'm putting in the time to do the analysis, I figure I might as well make simple improvements when I see them.
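
To make the difference concrete, here is a minimal sketch (toy data, not the exact PR diff) of the two "at least one label must appear in y_true" checks compared in the profiles above:

    import numpy as np

    y_true = np.array([0, 2, 2, 1, 0])
    labels = np.array([0, 1, 2, 3])

    # Old check: a Python-level loop over labels, roughly O(n_labels * n_samples).
    flag_old = np.all([l not in y_true for l in labels])

    # New check: a single vectorized set intersection.
    flag_new = len(np.intersect1d(y_true, labels)) == 0

    assert bool(flag_old) == flag_new  # both False here: labels 0, 1, 2 occur in y_true

For large y_true the vectorized form avoids the per-label Python loop, which is where the savings in the second profile come from.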

@Erotemic
Copy link
Contributor Author

Erotemic commented Mar 11, 2021

I added a patch for @glemaitre's comment to use _check_sample_weight. I do want to note that that function currently does not allow defaulting to a np.uint8 weight array, and confusion_matrix would be marginally more efficient if that were allowed. In the end I reverted that patch; it caused too noticeable a dip in performance.
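
As an aside on that uint8 remark, a quick illustrative sketch of the memory difference between the two default weight dtypes (the PR keeps the int64 default):

    import numpy as np

    n = 200_000
    w_int64 = np.ones(n, dtype=np.int64)   # current default: 8 bytes per weight
    w_uint8 = np.ones(n, dtype=np.uint8)   # hypothetical default: 1 byte per weight
    print(w_int64.nbytes, w_uint8.nbytes)  # 1600000 vs 200000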

I'm keeping my benchmark script here: https://github.com/Erotemic/misc/blob/master/tests/python/bench_confusion_matrix.py

Averaging over all results, I'm currently seeing these numbers. Let me know if you want detailed benchmarks; they're not much different from the original ones I posted.

I tested four versions of this function: the original 2017 implementation I started with, my 2017 patch, the 2021 master branch, and the 2021 variant of this patch.

2017_old                 0.003075
2017_new                 0.001756
2021_old                 0.003113
2021_new                 0.001662
speedup_2017_wrt_2017    2.068655
speedup_2021_wrt_2017    2.155297
speedup_2021_wrt_2021    2.203942

With the new intersect1d check instead of the list comprehension, we are even faster than we used to be!
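
For reference, a minimal sketch of the kind of timing harness behind those averages; the real parameter grid lives in the linked bench script, and the sizes and label count below are just placeholders:

    import timeit

    import numpy as np
    from sklearn.metrics import confusion_matrix

    rng = np.random.default_rng(0)
    n_input, n_classes = 200_000, 12          # placeholder sizes, not the full grid
    labels = np.arange(n_classes)
    y_true = rng.integers(0, n_classes, size=n_input)
    y_pred = rng.integers(0, n_classes, size=n_input)

    times = timeit.repeat(
        lambda: confusion_matrix(y_true, y_pred, labels=labels),
        number=10, repeat=3)
    print(min(times) / 10)  # seconds per call, best of 3 repeats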

I also rebased on main.

@Erotemic Erotemic force-pushed the speedup_confusion_matrix branch 2 times, most recently from d6dc056 to 4b8e1e8 on March 11, 2021 05:36
@cmarmo cmarmo added this to the 1.0 milestone Mar 11, 2021
Copy link
Member

@ogrisel ogrisel left a comment


LGTM.

@ogrisel
Copy link
Member

ogrisel commented Mar 11, 2021

@thomasjpfan I think all your comments have been addressed.

Copy link
Member

@thomasjpfan thomasjpfan left a comment


Thank you for working on this PR and your patience through the review process @Erotemic ! The benchmarks you provided were very useful :)

LGTM

@thomasjpfan thomasjpfan changed the title [MRG] speedup confusion_matrix ENH Speedup confusion_matrix Mar 11, 2021
@thomasjpfan thomasjpfan merged commit 5980455 into scikit-learn:main Mar 11, 2021
@Erotemic
Copy link
Contributor Author

Thanks everyone!

@glemaitre glemaitre mentioned this pull request Apr 22, 2021