-
-
Notifications
You must be signed in to change notification settings - Fork 26.1k
ENH speedup coordinate descent by avoiding calls to axpy in innermost loop #31956
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lorentzenchr
wants to merge
4
commits into
scikit-learn:main
Choose a base branch
from
lorentzenchr:cd_remove_one_axpy
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+81
−105
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
BenchmarkingBenchmarking code from #17021.
Time Ratios
Code file mtl_bench.py """
Benchmark of MultiTaskLasso
"""
import gc
from itertools import product
from time import time
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskLasso
def compute_bench(alpha, n_samples, n_features, n_tasks):
results = []
n_bench = len(n_samples) * len(n_features) * len(n_tasks)
for it, (ns, nf, nt) in enumerate(product(n_samples, n_features, n_tasks)):
print('==================')
print('Iteration %s of %s' % (it, n_bench))
print('==================')
n_informative = nf // 10
X, Y, coef_ = make_regression(n_samples=ns, n_features=nf,
n_informative=n_informative,
n_targets=nt,
noise=0.1, coef=True)
X /= np.sqrt(np.sum(X ** 2, axis=0)) # Normalize data
gc.collect()
clf = MultiTaskLasso(alpha=alpha, fit_intercept=False)
tstart = time()
clf.fit(X, Y)
results.append(
dict(n_samples=ns, n_features=nf, n_tasks=nt, time=time() - tstart)
)
return pd.DataFrame(results)
def compare_results():
results_new = pd.read_csv('mlt_new.csv').set_index(['n_samples', 'n_features', 'n_tasks'])
results_old = pd.read_csv('mlt_old.csv').set_index(['n_samples', 'n_features', 'n_tasks'])
results_ratio = (results_old / results_new)
results_ratio.columns = ['time (old) / time (new)']
print(results_new)
print(results_old)
print(results_ratio)
if __name__ == '__main__':
import matplotlib.pyplot as plt
alpha = 0.01 # regularization parameter
list_n_features = [300, 1000, 4000]
list_n_samples = [100, 500]
list_n_tasks = [2, 10, 20, 50]
results = compute_bench(alpha, list_n_samples,
list_n_features, list_n_tasks)
# results.to_csv('mlt_old.csv', index=False)
results.to_csv('mlt_new.csv', index=False)
compare_results() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Reference Issues/PRs
Similar to #31880.
Continues and fixes #15931.
What does this implement/fix? Explain your changes.
This PR avoids calls to
_axpy
in the innermost loop of all coordinate descent solvers (Lasso and Enet), exceptenet_coordinate_descent_gram
which was done in #31880.Any other comments?
Ironically, this improvement also reduces code size 😄
For reviewers: better merge #31957 first.