
ENH Loss module LogisticRegression #21808


Merged: 57 commits merged into scikit-learn:main from lorentzenchr:loss_module_logistic on Feb 14, 2022

Conversation

lorentzenchr (Member)

Reference Issues/PRs

Follow-up of #20567.

What does this implement/fix? Explain your changes.

This PR replaces the losses of LogisticRegression by the new common loss module of #20567.

Any other comments?

It implements LinearLoss in linear_model/_linear_loss.py as a helper/wrapper around loss functions so that raw_prediction = X @ coef is taken into account. Ideally, this class can also be used for GLMs like PoissonRegressor.
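
As an illustration of the idea (a minimal sketch with hypothetical names, not the actual class added by this PR): such a wrapper evaluates a point-wise loss on raw_prediction = X @ coef and maps the gradient back to coefficient space via the chain rule.

import numpy as np


class LinearLossSketch:
    """Sketch: wrap a point-wise loss so that it operates on coef through X @ coef.

    base_loss is assumed to expose loss(y_true, raw_prediction) and
    gradient(y_true, raw_prediction), both returning per-sample arrays.
    """

    def __init__(self, base_loss):
        self.base_loss = base_loss

    def loss(self, coef, X, y, l2_reg_strength=0.0):
        raw_prediction = X @ coef
        penalty = 0.5 * l2_reg_strength * np.dot(coef, coef)
        return self.base_loss.loss(y, raw_prediction).sum() + penalty

    def gradient(self, coef, X, y, l2_reg_strength=0.0):
        raw_prediction = X @ coef
        grad_pointwise = self.base_loss.gradient(y, raw_prediction)
        # Chain rule: d loss / d coef = X.T @ (d loss / d raw_prediction).
        return X.T @ grad_pointwise + l2_reg_strength * coef

A solver such as lbfgs can then minimize loss with jac=gradient over the flat coef array.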

This is similar to #19089, but replaces losses only for LogisticRegression.

@lorentzenchr lorentzenchr changed the title [MRG] ENH Loss module LogisticRegression [WIP] ENH Loss module LogisticRegression Nov 28, 2021
@lorentzenchr (Member Author)

lorentzenchr commented Dec 1, 2021

Benchmark of Fit Time of LogisticRegression lbfgs and newton-cg

Edit: Updated 17.12.2021

N=n_samples, n_features=50
Error bars from 10 runs each.
Full code: https://github.com/lorentzenchr/notebooks/blob/master/bench_loss_module_logistic_and_hgbt.ipynb

Hardware: Intel Core i7-8559U, 8th generation, 16 GB RAM
Software: Python 3.7.9, numpy 1.21.4, scipy 1.7.3
This PR (based on dea9bf0) and master (d09e1d7) both compiled the same way.

Results

[Figure logistic_benchmark: fit time vs. n_samples for binary (left) and multinomial (right) logistic regression with lbfgs and newton-cg]

n_threads=6 set via OMP_NUM_THREADS=6 (hardware maximum is 8).

Summary

Within error bounds, this PR only improves fit times for n_threads=1.
Note that improved fit times are not the (primary) goal of this PR.

Binary Logistic Regression (left plots)

Using multiple cores adds a runtime overhead that only starts to pay off for larger sample sizes.

Multiclass/Multinomial Logistic Regression (right plots)

Same as binary. Note that memory consumption should be lower because this PR uses a label-encoded target (aka ordinal encoding) while master uses a binarized target (aka one-hot encoding).
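
For illustration only (not code from this PR), the difference in target memory footprint between the two encodings:

import numpy as np
from sklearn.preprocessing import LabelBinarizer, LabelEncoder

rng = np.random.default_rng(0)
y = rng.integers(0, 10, size=60_000)  # 10 classes, MNIST-sized

y_ordinal = LabelEncoder().fit_transform(y)    # shape (60000,)
y_one_hot = LabelBinarizer().fit_transform(y)  # shape (60000, 10)

# The one-hot representation is roughly n_classes times larger.
print(y_ordinal.nbytes, y_one_hot.nbytes)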

@lorentzenchr lorentzenchr changed the title [WIP] ENH Loss module LogisticRegression [MRG] ENH Loss module LogisticRegression Dec 1, 2021
@lorentzenchr (Member Author)

@agramfort @rth You might be interested.

@agramfort (Member) left a comment

@lorentzenchr can you point me to the lines that do the magic here?

@lorentzenchr (Member Author) left a comment

Some comments on important code segments.

Comment on lines 1210 to 1211
if solver in ["lbfgs", "newton-cg"] and len(classes_) == 1:
    n_threads = _openmp_effective_n_threads()
@lorentzenchr (Member Author)

Maybe this should also depend on n_jobs, e.g. n_threads = min(n_jobs, _openmp_effective_n_threads()).

Member

We decided to restrict n_jobs to controlling joblib-level parallelism in other estimators (KMeans, HGBDT). So for consistency I would stick to that policy for now. To control the number of OpenMP (and possibly also OpenBLAS) threads, users can rely on threadpoolctl, which is a dependency of scikit-learn.
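
For example (a minimal sketch, not part of this PR), capping the native thread pools around a fit with threadpoolctl:

from threadpoolctl import threadpool_limits

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=10_000, n_features=50, random_state=0)

# Limit all native thread pools (OpenMP and BLAS) to one thread for this fit.
with threadpool_limits(limits=1):
    LogisticRegression(max_iter=200).fit(X, y)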

However, I suspect that the repeated, interleaved calls to OpenMP Cython code and OpenBLAS code can lead to performance slowdowns when OpenBLAS uses its own native threading layer instead of the same OpenMP runtime as scikit-learn. This could be confirmed by comparing the performance of this PR in 2 environments:

@ogrisel (Member) commented on Dec 15, 2021

Profiling on a MNIST-shaped problem shows that this does not seem to be a problem in practice. Maybe this is because the 2 OpenBLAS calls dominate and the Cython calls are comparatively negligible.

I only tried on macOS though. Maybe the OpenBLAS/OpenMP story is different on Linux.

@lorentzenchr (Member Author)

What is the conclusion? No code change needed?

Member

I only tried on macOS though. Maybe the OpenBLAS/OpenMP story is different on Linux.

@ogrisel: I can experiment with that, do you have your script at hand? If not, no problem.

@ogrisel (Member) commented on Dec 16, 2021

No code change needed?

I don't think a code change is required now to alter the way Cython/prange/OpenMP threading is used in this PR. But it might be important to keep this potential problem in mind if we observe slower than expected performance on data with different shapes (e.g. smaller data sizes, where the interacting-threadpool problem might still be present), or in other locations in the scikit-learn code base whenever we have the following code pattern:

while not converged:
     step 1: Cython prange loop that uses OpenMP threads
     step 2: numpy matrix multiply / scipy.linalg BLAS op that uses OpenBLAS threads

The long-term fix for this kind of problem would ideally happen at the community level in upstream packages (scipy/scipy#15129), but if we ever need to work around it for a specific case in scikit-learn before that happens, we can always decide to use threadpoolctl (or just pass num_threads=1 to prange) to disable threading either for step 1 or step 2 when we detect that OpenBLAS is packaged in a way that relies on its native threadpool instead of sharing the OpenMP threadpool.
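
For instance (sketch only, with a hypothetical helper name), threadpoolctl can target just the BLAS thread pool so that only step 2 runs single-threaded:

from threadpoolctl import threadpool_limits

# Disable OpenBLAS threading (step 2) while leaving the OpenMP pool used by
# the Cython prange loop (step 1) untouched.
with threadpool_limits(limits=1, user_api="blas"):
    run_solver_iterations()  # hypothetical placeholder for the solver loop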

Member

OpenMathLib/OpenBLAS#3187 relates to the issue @ogrisel mentions regarding the code pattern above.

@ogrisel (Member) commented on Dec 16, 2021

Actually, looking at this again, and independently of the hypothetical OpenBLAS / OpenMP threadpool interaction problem, there might be a problematic case when n_jobs > 1 because we could get over-subscription between joblib parallelism and OpenMP parallelism. Maybe we should always hardcode n_threads=1 to stay on the safe side for this PR.

We might want to add a TODO comment to say that we should refactor this to avoid using joblib parallelism entirely for binary and multinomial multiclass classification, and use joblib only for the one-vs-rest multiclass case. This is probably beyond the scope of this PR.

@lorentzenchr (Member Author)

Done in dea9bf0 (sorry for the stupid commit message).

Comment on lines +308 to +312
# Here we may safely assume HalfMultinomialLoss aka categorical
# cross-entropy.
# HalfMultinomialLoss computes only the diagonal part of the hessian, i.e.
# diagonal in the classes. Here, we want the matrix-vector product of the
# full hessian. Therefore, we call gradient_proba.
@lorentzenchr (Member Author)

This is really specific to the multinomial case / n_classes >= 3. The hessian returned by HalfMultinomialLoss is only the diagonal part, diagonal in the classes. Therefore, we call the method gradient_proba, which exists only for HalfMultinomialLoss for exactly this reason: to define hessp such that it computes products of a vector with the full hessian, for which the gradient and the predicted probabilities are enough.
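
As a side note (standard algebra for the multinomial loss, not text from the PR): for a single sample with predicted probabilities $p \in \mathbb{R}^K$, the Hessian of the multinomial loss with respect to the raw predictions is

$$
H = \operatorname{diag}(p) - p\,p^{\top},
\qquad
H v = p \odot v - p\,(p^{\top} v),
$$

so the product of the full Hessian with a vector can be computed from the probabilities alone, without ever forming $H$; in coefficient space, hessp then only needs applications of $X$ and $X^{\top}$ around this per-sample operation.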

@ogrisel (Member) left a comment

Here is a first round of feedback. I did some profiling on macOS with MNIST.

Here is the flamegraph generated by py-spy:

[py-spy flamegraph: one LogisticRegression fit on the MNIST-shaped problem]

The two numpy matrix multiplications that involve X clearly dominate the duration of one iteration (as seen in the line numbers in the flamegraph above: _linear_loss.py lines 91 and 182).

The performance is very similar to that of main. The Cython code is almost invisible compared to the numpy matrix multiplications.

Here is the benchmark script I used:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from joblib import Memory
from sklearn.preprocessing import scale
from time import perf_counter


# Cache in memory to avoid profiling data generation
m = Memory(location="/tmp/joblib")
make_classification = m.cache(make_classification)

# MNIST-sized dataset
X, y = make_classification(
    n_samples=60_000, n_features=784, n_informative=500, n_classes=10,
    random_state=0
)
X = scale(X)

print("Fitting logistic regression...")
tic = perf_counter()
clf = LogisticRegression(max_iter=300).fit(X, y)
toc = perf_counter()
print(f"Fitting took {toc - tic:.3f} seconds ({clf.n_iter_} iterations)")

You can profile it with py-spy or viztracer to interactively explore the reports.


The intercept term is at the end of the coef array:
if loss.is_multiclass:
    coef[n_features::n_dof] = coef[(n_dof-1)::n_dof]
Member

Suggested change
coef[n_features::n_dof] = coef[(n_dof-1)::n_dof]
intercept = coef[n_features::n_dof] = coef[(n_dof-1)::n_dof]
intercept.shape = (n_classes,)

Member

This notation implies that all intercept values are contiguously packed at the end of the flat coef array. However in the code we have

            w = coef.reshape(self._loss.n_classes, -1)
            if self.fit_intercept:
                intercept = w[:, -1]
                w = w[:, :-1]

which means that even if coef is contiguous, neither w nor intercept is contiguous. Furthermore, the intercept values are no longer packed at the end of coef.

Member

I tried to evaluate the performance impact of using a non-contiguous coef on a MNIST-shaped problem, and at least on my machine it does not have a significant impact:

In [1]: import numpy as np

In [2]: n_samples, n_features, n_classes = 60000, 784, 10

In [3]: X = np.random.randn(n_samples, n_features)

In [5]: flat_coef = np.random.normal(size=(n_classes * (n_features + 1)))

In [6]: coef = flat_coef.reshape(n_classes, -1)[:, :-1]

In [7]: intercept = flat_coef.reshape(n_classes, -1)[:, -1]

In [8]: coef.flags
Out[8]: 
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

In [9]: intercept.flags
Out[9]: 
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

In [10]: coef_contig = np.ascontiguousarray(coef)

In [11]: intercept_contig = np.ascontiguousarray(intercept)

In [12]: %timeit X @ coef.T + intercept
54.8 ms ± 2.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [13]: %timeit X @ coef_contig.T + intercept_contig
55.5 ms ± 2.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Still, I would feel more comfortable if we shaped the problem to use contiguous arrays everywhere.

@lorentzenchr (Member Author) commented on Dec 15, 2021

You are right! D*m*d.

The scipy solvers only understand 1-d arrays x to solve for, i.e. minimize(func, x, ...). The x in logistic regression is initialized as

w0 = np.zeros(
    (classes.size, n_features + int(fit_intercept)), order="F", dtype=X.dtype
)

Then this is ravelled for lbfgs and newton-cg:

w0 = w0.ravel()

w0 is now (C- and F-) contiguous. Then in _w_intercept_raw

w = coef.reshape(self._loss.n_classes, -1)

This is surprisingly C-contiguous.

Here is a full example on its own:

import numpy as np

n_classes, n_features = 3,5
w0 = np.zeros((n_classes, n_features), order="F")
w1 = w0.ravel()
w1.flags
#  C_CONTIGUOUS : True
#  F_CONTIGUOUS : True

w2 = w1.reshape(n_classes, -1)
w2.flags
#  C_CONTIGUOUS : True
#  F_CONTIGUOUS : False

w3 = w2[:, :-1]
w3.flags
# C_CONTIGUOUS : False
# F_CONTIGUOUS : False

Possible solution
Should we use np.ravel(..., order="F") and np.reshape(..., order='F')?
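
For what it's worth, a quick check of this idea (illustration only, not code from the PR) suggests the F-order round trip keeps the coefficient and intercept views contiguous:

import numpy as np

n_classes, n_features = 3, 5
w0 = np.zeros((n_classes, n_features), order="F")

w1 = w0.ravel(order="F")                   # 1-d view, still contiguous
w2 = w1.reshape(n_classes, -1, order="F")  # F-contiguous 2-d view
coef, intercept = w2[:, :-1], w2[:, -1]

print(coef.flags.f_contiguous)       # True
print(intercept.flags.c_contiguous)  # True (a single column is contiguous)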

Member

w0 = np.zeros(
    (classes.size, n_features + int(fit_intercept)), order="F", dtype=X.dtype
)

Oh I missed the order="F" here! So my analysis was flawed, I think.

Maybe we should rename w0 to packed_coef when it is rectangular with the Fortran layout, as is the case above, and then use the view flat_coef = np.ravel(packed_coef) for solvers that need the 1d view representation.

Then we can have centralized documentation somewhere to explain the memory layout and contiguity assumptions for packed_coef and flat_coef, and how to extract the contiguous views for the coef and intercept sub-components throughout this code.

@lorentzenchr (Member Author)

39ffe9d is my attempt at solving this. It is a larger change:

  • coef.shape is either (n_dof,), (n_classes, n_dof), or ravelled (n_classes * n_dof,).
    In the multiclass case, this enables handling both ravelled and unravelled coefficients at the same time.
  • In all steps involving ravel or reshape, the order is set explicitly, mostly "F".
  • The benchmark script above gives 1.821 seconds with this PR vs 2.070 seconds on master.


@lorentzenchr (Member Author)

@rth @TomDLT might be interested in this PR.

@agramfort (Member) left a comment

otherwise LGTM

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
@lorentzenchr (Member Author)

@agramfort Can you read minds from far distances? After about 1.5 years in the making of losses, this week, I got a bit frustrated - for no particular reason. Now, there are 2 approvals, a first linear model to profit from it and a path for further 2nd order solvers 🥳

@lorentzenchr lorentzenchr changed the title [MRG] ENH Loss module LogisticRegression ENH Loss module LogisticRegression Feb 9, 2022
@agramfort (Member)

[GIF: yes-im-so-excited]

@agramfort agramfort merged commit d8d5637 into scikit-learn:main Feb 14, 2022
@agramfort (Member)

Thx @lorentzenchr !

@lorentzenchr lorentzenchr deleted the loss_module_logistic branch February 19, 2022 20:33
thomasjpfan pushed a commit to thomasjpfan/scikit-learn that referenced this pull request Mar 1, 2022
* ENH replace loss in linear logistic regression

* MNT remove logistic regression's own loss functions

* CLN remove comment

* DOC add whatsnew

* DOC more precise whatsnew

* CLN restore improvements of scikit-learn#19571

* ENH improve fit time by separating mat-vec in multiclass

* DOC update whatsnew

* not only a bit ;-)

* DOC note memory benefit for multiclass case

* trigger CI

* trigger CI

* CLN rename variable to hess_prod

* DOC address reviewer comments

* CLN remove C/F for 1d arrays

* CLN rename to gradient_per_sample

* CLN rename alpha to l2_reg_strength

* ENH respect F-contiguity

* TST fix sag tests

* CLN rename to LinearModelLoss

* CLN improve comments according to review

* CLN liblinear comment

* TST add / move test to test_linear_loss.py

* CLN comment placement

* trigger CI

* CLN add comment about contiguity of raw_prediction

* DEBUG debian-32

* DEBUG test only linear_model module

* Revert "DEBUG test only linear_model module"

This reverts commit 9d6e698.

* DEBUG test -k LogisticRegression

* Revert "DEBUG test -k LogisticRegression"

This reverts commit c203167.

* Revert "DEBUG debian-32"

This reverts commit ef0b98f.

* DEBUG set n_jobs=1

* Revert "DEBUG set n_jobs=1"

This reverts commit c7f6f72.

* CLN always use n_threads=1

* CLN address review

* ENH avoid array copy

* CLN simplify L2 norm

* CLN rename w to weights

* CLN rename to hessian_sum and hx_sum

* CLN address review

* CLN rename to init arg and attribute to base_loss

* CLN apply review suggestion

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* CLN base_loss instead of _loss attribute

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>