ENH speedup coordinate descent by avoiding calls to axpy in innermost loop #31956

Open · wants to merge 4 commits into main

Conversation

@lorentzenchr (Member) commented Aug 16, 2025

Reference Issues/PRs

Similar to #31880.
Continues and fixes #15931.

What does this implement/fix? Explain your changes.

This PR avoids calls to _axpy in the innermost loop of all coordinate descent solvers (Lasso and Enet), except enet_coordinate_descent_gram, which was already handled in #31880.
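
To give readers unfamiliar with the solver internals a feel for the kind of restructuring involved, here is a rough plain-NumPy sketch of a coordinate descent inner loop that folds the old coefficient's contribution into the gradient term and touches the residual at most once per coordinate. It is an illustration under simplifying assumptions (unscaled 1/2 * ||y - Xw||^2 + alpha * ||w||_1 objective, no intercept, no all-zero columns), with a made-up function name; it is not the actual Cython code in _cd_fast.pyx changed by this PR.

import numpy as np


def lasso_cd_sketch(X, y, alpha, n_iter=100):
    """Illustrative Lasso coordinate descent (not the scikit-learn solver)."""
    n_samples, n_features = X.shape
    w = np.zeros(n_features)
    R = y.astype(float).copy()        # residual y - X @ w (w starts at 0)
    norm_cols = (X ** 2).sum(axis=0)  # precomputed squared column norms

    for _ in range(n_iter):
        for j in range(n_features):
            # Gradient term with the old w[j] contribution folded in, so the
            # residual does not need to be "reset" before the update.
            tmp = X[:, j] @ R + w[j] * norm_cols[j]
            w_j_new = np.sign(tmp) * max(abs(tmp) - alpha, 0.0) / norm_cols[j]
            d_w_j = w_j_new - w[j]
            if d_w_j != 0.0:
                # Single scaled addition (the role played by _axpy in the
                # Cython code), executed only when the coefficient moves.
                R -= d_w_j * X[:, j]
                w[j] = w_j_new
    return w

In a restructuring of this kind, the residual update runs at most once per coordinate and is skipped entirely for coordinates that do not change, which is roughly where the savings come from.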

Any other comments?

Ironically, this improvement also reduces code size 😄

For reviewers: it is better to merge #31957 first.

github-actions bot commented Aug 16, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 6b4b72c. Link to the linter CI: here

@lorentzenchr mentioned this pull request Aug 16, 2025

@lorentzenchr (Member, Author) commented:

Benchmarking

Benchmarking code from #17021.
Fit time in seconds, new (with this PR):

                                   time
n_samples n_features n_tasks
100       300        2         0.045615
                     10        0.157946
                     20        0.293871
                     50        0.746838
          1000       2         0.132510
                     10        0.446886
                     20        0.891311
                     50        2.198235
          4000       2         0.460885
                     10        1.470541
                     20        2.614343
                     50        6.619404
500       300        2         0.003803
                     10        0.014251
                     20        0.025303
                     50        0.056270
          1000       2         0.084915
                     10        0.850941
                     20        1.947459
                     50        7.439827
          4000       2         2.597164
                     10        6.517381
                     20       10.983630
                     50       27.935731

Fit time in seconds, old (without this PR):

                                   time
n_samples n_features n_tasks
100       300        2         0.051384
                     10        0.200898
                     20        0.394395
                     50        1.013719
          1000       2         0.143790
                     10        0.555918
                     20        1.072670
                     50        2.938247
          4000       2         0.474804
                     10        1.558145
                     20        2.952639
                     50        7.527227
500       300        2         0.004047
                     10        0.017960
                     20        0.030633
                     50        0.078154
          1000       2         0.097164
                     10        1.097646
                     20        3.198540
                     50       10.687953
          4000       2         2.673210
                     10        7.666225
                     20       14.532951
                     50       40.475398

Time Ratios (values above 1 mean the new code is faster)

                              time (old) / time (new)
n_samples n_features n_tasks                         
100       300        2                       1.126472
                     10                      1.271942
                     20                      1.342068
                     50                      1.357348
          1000       2                       1.085124
                     10                      1.243982
                     20                      1.203474
                     50                      1.336639
          4000       2                       1.030201
                     10                      1.059572
                     20                      1.129400
                     50                      1.137146
500       300        2                       1.064197
                     10                      1.260285
                     20                      1.210649
                     50                      1.388911
          1000       2                       1.144253
                     10                      1.289920
                     20                      1.642417
                     50                      1.436586
          4000       2                       1.029280
                     10                      1.176274
                     20                      1.323146
                     50                      1.448876

Code

File: mtl_bench.py

"""
Benchmark of MultiTaskLasso
"""
import gc
from itertools import product
from time import time
import numpy as np
import pandas as pd

from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskLasso


def compute_bench(alpha, n_samples, n_features, n_tasks):
    results = []
    n_bench = len(n_samples) * len(n_features) * len(n_tasks)

    for it, (ns, nf, nt) in enumerate(product(n_samples, n_features, n_tasks)):
        print('==================')
        print('Iteration %s of %s' % (it + 1, n_bench))
        print('==================')
        n_informative = nf // 10
        X, Y, coef_ = make_regression(n_samples=ns, n_features=nf,
                                      n_informative=n_informative,
                                      n_targets=nt,
                                      noise=0.1, coef=True)

        X /= np.sqrt(np.sum(X ** 2, axis=0))  # Normalize data

        gc.collect()
        clf = MultiTaskLasso(alpha=alpha, fit_intercept=False)
        tstart = time()
        clf.fit(X, Y)
        results.append(
            dict(n_samples=ns, n_features=nf, n_tasks=nt, time=time() - tstart)
        )

    return pd.DataFrame(results)


def compare_results():
    results_new = pd.read_csv('mlt_new.csv').set_index(['n_samples', 'n_features', 'n_tasks'])
    results_old = pd.read_csv('mlt_old.csv').set_index(['n_samples', 'n_features', 'n_tasks'])
    results_ratio = (results_old / results_new)
    results_ratio.columns = ['time (old) / time (new)']
    print(results_new)
    print(results_old)
    print(results_ratio)


if __name__ == '__main__':
    alpha = 0.01  # regularization parameter

    list_n_features = [300, 1000, 4000]
    list_n_samples = [100, 500]
    list_n_tasks = [2, 10, 20, 50]
    results = compute_bench(alpha, list_n_samples,
                            list_n_features, list_n_tasks)

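    # Assumed workflow: run this script once on the baseline with the
    # 'mlt_old.csv' line below enabled, and once on this branch with the
    # 'mlt_new.csv' line; compare_results() then prints the new, old and
    # ratio tables shown above.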
    # results.to_csv('mlt_old.csv', index=False)
    results.to_csv('mlt_new.csv', index=False)

    compare_results()
