
ENH replace Cython loss functions in SGD part 2 #28029


Merged

Conversation

lorentzenchr
Member

Reference Issues/PRs

Follow-up of #27999 (which needs to be merged first). Partly addresses #15123.

What does this implement/fix? Explain your changes.

This PR replaces the Cython loss functions of SGD and SAGA with the ones from _loss (SquaredLoss, Huber, LogLoss) and inherits from _loss._loss.CyLossFunction for the remaining ones (Hinge, ..., and Multinomial).

Also, the loss functions are removed from sklearn.linear_model.__init__.

Any other comments?

Only merge after release 1.5, i.e. this PR is to be released with v1.6.
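To make the scope of the replacement concrete, here is a pure-Python sketch of the per-sample loss/gradient pairs involved (squared error, Huber, log loss, hinge). The function names and signatures below are illustrative assumptions for this comment only, not scikit-learn's actual private `_loss` API:

```python
import math

# Illustrative sketch (hypothetical names): each loss returns
# (loss_value, gradient w.r.t. the raw prediction) for one sample.

def squared_loss(y_true, raw):
    """Half squared error: 0.5 * (raw - y)^2."""
    r = raw - y_true
    return 0.5 * r * r, r

def huber_loss(y_true, raw, delta=1.0):
    """Huber loss: quadratic near zero, linear in the tails."""
    r = raw - y_true
    if abs(r) <= delta:
        return 0.5 * r * r, r
    return delta * (abs(r) - 0.5 * delta), delta * math.copysign(1.0, r)

def log_loss(y_true, raw):
    """Binary log loss for y_true in {0, 1}, raw = log-odds."""
    p = 1.0 / (1.0 + math.exp(-raw))
    loss = -y_true * math.log(p) - (1.0 - y_true) * math.log(1.0 - p)
    return loss, p - y_true

def hinge_loss(y_true, raw):
    """Hinge loss for y_true in {-1, +1}; subgradient at the kink is 0."""
    margin = y_true * raw
    if margin < 1.0:
        return 1.0 - margin, -y_true
    return 0.0, 0.0
```

Sharing one such set of (loss, gradient) routines across SGD, SAGA and the other solvers is the point of the refactoring: one implementation to test and maintain instead of per-solver Cython copies.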


github-actions bot commented Dec 27, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 0dd76c3.

@lorentzenchr
Member Author

asv benchmarks

asv compare b8d783d2 68f35c3e                   

All benchmarks:

| Change | Before [b8d783d2] <main> | After [68f35c3e] <replace_sgd_with_common_loss_part_2~1> | Ratio | Benchmark (Parameter)                                                            |
|--------|--------------------------|----------------------------------------------------------|-------|----------------------------------------------------------------------------------|
|        | 109M                     | 109M                                                     |  1    | linear_model.LogisticRegressionBenchmark.peakmem_fit('dense', 'lbfgs', 1)        |
|        | 90.1M                    | 90.1M                                                    |  1    | linear_model.LogisticRegressionBenchmark.peakmem_fit('dense', 'saga', 1)         |
|        | 385M                     | 384M                                                     |  1    | linear_model.LogisticRegressionBenchmark.peakmem_fit('sparse', 'lbfgs', 1)       |
|        | 110M                     | 110M                                                     |  1    | linear_model.LogisticRegressionBenchmark.peakmem_fit('sparse', 'saga', 1)        |
|        | 108M                     | 108M                                                     |  1    | linear_model.LogisticRegressionBenchmark.peakmem_predict('dense', 'lbfgs', 1)    |
|        | 94.3M                    | 94.6M                                                    |  1    | linear_model.LogisticRegressionBenchmark.peakmem_predict('dense', 'saga', 1)     |
|        | 108M                     | 109M                                                     |  1    | linear_model.LogisticRegressionBenchmark.peakmem_predict('sparse', 'lbfgs', 1)   |
|        | 96.5M                    | 96.7M                                                    |  1    | linear_model.LogisticRegressionBenchmark.peakmem_predict('sparse', 'saga', 1)    |
| +      | 15.2±2ms                 | 17.7±2ms                                                 |  1.16 | linear_model.LogisticRegressionBenchmark.time_fit('dense', 'lbfgs', 1)           |
|        | 1.97±0.04s               | 1.93±0.07s                                               |  0.98 | linear_model.LogisticRegressionBenchmark.time_fit('dense', 'saga', 1)            |
|        | 963±20ms                 | 963±20ms                                                 |  1    | linear_model.LogisticRegressionBenchmark.time_fit('sparse', 'lbfgs', 1)          |
| +      | 1.87±0.01s               | 2.12±0.06s                                               |  1.14 | linear_model.LogisticRegressionBenchmark.time_fit('sparse', 'saga', 1)           |
|        | 3.00±0.5ms               | 3.05±0.5ms                                               |  1.02 | linear_model.LogisticRegressionBenchmark.time_predict('dense', 'lbfgs', 1)       |
|        | 1.42±0.01ms              | 1.44±0.02ms                                              |  1.01 | linear_model.LogisticRegressionBenchmark.time_predict('dense', 'saga', 1)        |
|        | 6.24±0.04ms              | 6.31±0.09ms                                              |  1.01 | linear_model.LogisticRegressionBenchmark.time_predict('sparse', 'lbfgs', 1)      |
|        | 4.81±0.02ms              | 4.95±0.08ms                                              |  1.03 | linear_model.LogisticRegressionBenchmark.time_predict('sparse', 'saga', 1)       |
|        | 0.17488127353035546      | 0.17488127353035546                                      |  1    | linear_model.LogisticRegressionBenchmark.track_test_score('dense', 'lbfgs', 1)   |
|        | 0.7789925817802057       | 0.7789925817802057                                       |  1    | linear_model.LogisticRegressionBenchmark.track_test_score('dense', 'saga', 1)    |
|        | 0.06538461538461539      | 0.06538461538461539                                      |  1    | linear_model.LogisticRegressionBenchmark.track_test_score('sparse', 'lbfgs', 1)  |
|        | 0.5765140080078162       | 0.5765140080078162                                       |  1    | linear_model.LogisticRegressionBenchmark.track_test_score('sparse', 'saga', 1)   |
|        | 0.17920161231776         | 0.17920161231776                                         |  1    | linear_model.LogisticRegressionBenchmark.track_train_score('dense', 'lbfgs', 1)  |
|        | 0.7998934724948512       | 0.7998934724948512                                       |  1    | linear_model.LogisticRegressionBenchmark.track_train_score('dense', 'saga', 1)   |
|        | 0.0681998556998557       | 0.0681998556998557                                       |  1    | linear_model.LogisticRegressionBenchmark.track_train_score('sparse', 'lbfgs', 1) |
|        | 0.6908414295256007       | 0.6908414295256007                                       |  1    | linear_model.LogisticRegressionBenchmark.track_train_score('sparse', 'saga', 1)  |
|        | 165M                     | 165M                                                     |  1    | linear_model.SGDRegressorBenchmark.peakmem_fit('dense')                          |
|        | 93.8M                    | 93.7M                                                    |  1    | linear_model.SGDRegressorBenchmark.peakmem_fit('sparse')                         |
|        | 166M                     | 166M                                                     |  1    | linear_model.SGDRegressorBenchmark.peakmem_predict('dense')                      |
|        | 93.7M                    | 93.7M                                                    |  1    | linear_model.SGDRegressorBenchmark.peakmem_predict('sparse')                     |
|        | 4.50±0.02s               | 4.40±0.02s                                               |  0.98 | linear_model.SGDRegressorBenchmark.time_fit('dense')                             |
|        | 3.46±0.02s               | 3.51±0.01s                                               |  1.01 | linear_model.SGDRegressorBenchmark.time_fit('sparse')                            |
|        | 7.77±0.4ms               | 8.08±0.1ms                                               |  1.04 | linear_model.SGDRegressorBenchmark.time_predict('dense')                         |
|        | 1.68±0.03ms              | 1.75±0.04ms                                              |  1.04 | linear_model.SGDRegressorBenchmark.time_predict('sparse')                        |
|        | 0.9636293915890342       | 0.9636293915890342                                       |  1    | linear_model.SGDRegressorBenchmark.track_test_score('dense')                     |
|        | 0.961311884809733        | 0.961311884809733                                        |  1    | linear_model.SGDRegressorBenchmark.track_test_score('sparse')                    |
|        | 0.9641785427112692       | 0.9641785427112692                                       |  1    | linear_model.SGDRegressorBenchmark.track_train_score('dense')                    |
|        | 0.9621441831314548       | 0.9621441831314548                                       |  1    | linear_model.SGDRegressorBenchmark.track_train_score('sparse')                   |
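A note on reading the table: asv's compare view marks a benchmark with `+` (regression) or `-` (improvement) only when the after/before ratio crosses a threshold factor; rows within the band are left unflagged, and the track_* score rows being bit-identical confirms the refactoring did not change results. A minimal sketch of that flagging logic, assuming a factor of 1.1 (the actual asv threshold is configurable, so treat the value as an assumption):

```python
def flag_change(before, after, factor=1.1):
    """Classify a benchmark change the way an asv-style compare view does.

    Returns ('+', ratio) for a regression (ratio > factor),
    ('-', ratio) for an improvement (ratio < 1/factor),
    and ('', ratio) when the change is within the threshold band.
    The 1.1 factor is an assumption for illustration.
    """
    ratio = after / before
    if ratio > factor:
        return "+", ratio
    if ratio < 1.0 / factor:
        return "-", ratio
    return "", ratio
```

For example, the `time_fit('dense', 'lbfgs', 1)` row (15.2 ms before, 17.7 ms after) lands above the threshold and is flagged `+`, while the `time_fit('dense', 'saga', 1)` row (ratio 0.98) stays unflagged.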

@lorentzenchr
Member Author

The failing tests seem to be caused by #28046.

@ogrisel
Member

ogrisel commented Jan 10, 2024

There are conflicts to resolve before we can confirm whether the merge of #28046 makes the tests pass.

@OmarManzoor
Contributor

OmarManzoor commented Jul 18, 2024

@lorentzenchr would it be okay if I work on these PRs to take them forward?

Contributor

@OmarManzoor OmarManzoor left a comment


LGTM. Thank you @lorentzenchr
The tests are passing and the benchmarks show no regression.

@OmarManzoor OmarManzoor added the Waiting for Second Reviewer First reviewer is done, need a second one! label Jul 19, 2024
@@ -1,26 +0,0 @@
# SPDX-License-Identifier: BSD-3-Clause
Member


Let's assume no other external project was using those functions.

@jjerphan jjerphan removed the Waiting for Second Reviewer First reviewer is done, need a second one! label Jul 23, 2024
@OmarManzoor
Contributor

Thanks for the review, Julien. I fixed the typo. I'll let you merge after taking another look.

@jjerphan jjerphan merged commit 936a391 into scikit-learn:main Jul 24, 2024
30 checks passed
@lorentzenchr lorentzenchr deleted the replace_sgd_with_common_loss_part_2 branch August 4, 2024 10:02
MarcBresson pushed a commit to MarcBresson/scikit-learn that referenced this pull request Sep 2, 2024
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>