
ENH multiclass/multinomial newton cholesky for LogisticRegression #28840


Merged: 24 commits merged into scikit-learn:main from multiclass_newton_cholesky on Oct 18, 2024

Conversation

@lorentzenchr (Member) commented Apr 15, 2024

Reference Issues/PRs

In a way a follow-up of #24767.

What does this implement/fix? Explain your changes.

This extends the "newton-cholesky" solver of LogisticRegression and LogisticRegressionCV to the full multinomial loss. In particular, the full Hessian is computed, so this solver no longer needs to resort to one-vs-rest (OvR) for multiclass targets.
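For concreteness, here is a minimal numpy sketch (illustration only, not the scikit-learn implementation) of the full multinomial Hessian structure: block (k, l) of the Hessian over the flattened coefficients is a weighted Gram matrix of X with per-sample weights p_k * (delta_kl - p_l).

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_classes = 50, 4, 3
X = rng.standard_normal((n_samples, n_features))
coef = rng.standard_normal((n_classes, n_features))

# Softmax probabilities, shape (n_samples, n_classes).
logits = X @ coef.T
logits -= logits.max(axis=1, keepdims=True)  # for numerical stability
proba = np.exp(logits)
proba /= proba.sum(axis=1, keepdims=True)

# Block (k, l) of the multinomial Hessian w.r.t. the flattened coefficients:
# X.T @ diag(proba[:, k] * (delta_kl - proba[:, l])) @ X.
H = np.empty((n_classes * n_features, n_classes * n_features))
for k in range(n_classes):
    for l in range(n_classes):
        w_kl = proba[:, k] * ((k == l) - proba[:, l])
        H[k * n_features:(k + 1) * n_features,
          l * n_features:(l + 1) * n_features] = X.T @ (w_kl[:, None] * X)

# The full Hessian is symmetric and positive semi-definite.
assert np.allclose(H, H.T)
assert np.linalg.eigvalsh(H).min() > -1e-8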

Any other comments?

There are two tricky parts:

  1. Some index bookkeeping: the coefficients are usually indexed hierarchically by n_classes and n_features, but in the end the Hessian must be a plain 2-dimensional matrix - and it is!
  2. The multinomial loss is over-parameterized in any unpenalized coefficient, so at least in the intercepts. We therefore choose the last class as reference and set its intercept to zero (see the sketch after this list).
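To illustrate the second point, a small numpy sketch (again illustration only) of why the unpenalized intercepts are only identified up to an additive constant, so that pinning the last class intercept to zero loses nothing:

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_classes = 5, 3, 4
X = rng.standard_normal((n_samples, n_features))
coef = rng.standard_normal((n_classes, n_features))
intercept = rng.standard_normal(n_classes)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    expz = np.exp(z)
    return expz / expz.sum(axis=1, keepdims=True)

# Shifting every class intercept by the same constant leaves the predicted
# probabilities unchanged, hence the over-parameterization.
proba = softmax(X @ coef.T + intercept)
proba_pinned = softmax(X @ coef.T + (intercept - intercept[-1]))
assert np.allclose(proba, proba_pinned)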

github-actions bot commented Apr 15, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit c690194.

@lorentzenchr lorentzenchr changed the title from "ENH multiclass newton cholesky for LogisticRegression" to "ENH multiclass/multinomial newton cholesky for LogisticRegression" on Apr 15, 2024
@lorentzenchr (Member Author)

Benchmark

As of 1de85b7

X_train.shape = (10000, 75)
sparse.issparse(X_train)=False
n_classes=12
[benchmark figure: suboptimality vs. n_iter and vs. train_time for the lbfgs, newton-cg and newton-cholesky solvers]

import warnings
from pathlib import Path
import numpy as np
from scipy import sparse
from sklearn._loss import HalfMultinomialLoss
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.linear_model import PoissonRegressor, LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer, OneHotEncoder
from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model._linear_loss import LinearModelLoss
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from time import perf_counter
import pandas as pd



def prepare_data():
    df = fetch_openml(data_id=41214, as_frame=True, parser='auto').frame
    df["Frequency"] = df["ClaimNb"] / df["Exposure"]
    log_scale_transformer = make_pipeline(
        FunctionTransformer(np.log, validate=False), StandardScaler()
    )
    linear_model_preprocessor = ColumnTransformer(
        [
            ("passthrough_numeric", "passthrough", ["BonusMalus"]),
            (
                "binned_numeric",
                KBinsDiscretizer(n_bins=10, subsample=None),
                ["VehAge", "DrivAge"],
            ),
            ("log_scaled_numeric", log_scale_transformer, ["Density"]),
            (
                "onehot_categorical",
                OneHotEncoder(),
                ["VehBrand", "VehPower", "VehGas", "Region", "Area"],
            ),
        ],
        remainder="drop",
    )
    y = df["Frequency"]
    w = df["Exposure"]
    X = linear_model_preprocessor.fit_transform(df)
    return X, y, w


X, y_orig, w = prepare_data()

print("binning the target...")
binner = KBinsDiscretizer(
    n_bins=300, encode="ordinal", strategy="quantile", subsample=int(2e5), random_state=0
)
y = binner.fit_transform(y_orig.to_numpy().reshape(-1, 1)).ravel().astype(float)

X = X.toarray()
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
    X, y, w, train_size=10_000, test_size=10_000, random_state=0
)
print(f"{X_train.shape = }")
print(f"{sparse.issparse(X_train)=}")
n_classes = len(np.unique(y_train))
print(f"{n_classes=}")
print("y_train.value_counts() :")
print(pd.Series(y_train).value_counts())


results = []
slow_solvers = set()
loss_sw = np.full_like(y_train, fill_value=(1. / y_train.shape[0]))
alpha = 1e-6  # A bit larger than in the LSMR benchmarks to avoid ConvergenceWarnings
for tol in np.logspace(-1, -10, 10):
    for solver in ["lbfgs", "newton-cg", "newton-cholesky"]:
        if solver in slow_solvers:
            # skip slow solvers to keep the benchmark runtime reasonable
            continue
        tic = perf_counter()
        # with warnings.catch_warnings():
        #     warnings.filterwarnings("ignore", category=ConvergenceWarning)
        clf = LogisticRegression(
            C=1/alpha,
            solver=solver,
            tol=tol,
            max_iter=10_000 if solver=="lbfgs" else 1000,
        ).fit(X_train, y_train)
        toc = perf_counter()
        train_time = toc - tic
        n_iter = clf.n_iter_[0]
        if train_time > 200 or n_iter >= clf.max_iter:
            # skip this solver from now on...
            slow_solvers.add(solver)
        # Look inside _GeneralizedLinearRegressor to check the parameters.
        # Or run once with verbose=1 and compare to the reported loss.
        train_loss = LinearModelLoss(
            base_loss=HalfMultinomialLoss(n_classes=n_classes), fit_intercept=clf.fit_intercept
        ).loss(
            coef=np.c_[clf.coef_, clf.intercept_],
            X=X_train,
            y=y_train,
            l2_reg_strength=alpha / X_train.shape[0],
            sample_weight=loss_sw,
        )
        result = {
            "solver": solver,
            "tol": tol,
            "train_loss": train_loss,
            "train_time": train_time,
            "train_score": clf.score(X_train, y_train),
            "test_score": clf.score(X_test, y_test),
            "n_iter": n_iter,
            "converged": n_iter < clf.max_iter,
        }
        print(result)
        results.append(result)


results = pd.DataFrame.from_records(results)
filepath = Path().resolve() / "bench_multinomial_logistic_regression_mtpl_dense_newton_cholesky.csv"
results.to_csv(filepath)


import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt


filepath = Path().resolve() / "bench_multinomial_logistic_regression_mtpl_dense_newton_cholesky.csv"

results = pd.read_csv(filepath)
results["suboptimality"] = results["train_loss"] - results["train_loss"].min() + 1e-16

fig, axes = plt.subplots(ncols=2, figsize=(8*2, 6))
for label, group in results.groupby("solver"):
    group.sort_values("tol").plot(
        x="n_iter", y="suboptimality", loglog=True, marker="o", label=label, ax=axes[0]
    )
axes[0].set_ylabel("suboptimality")
axes[0].set_title("Suboptimality by iterations")

for label, group in results.groupby("solver"):
    group.sort_values("tol").plot(
        x="train_time", y="suboptimality", loglog=True, marker="o", label=label, ax=axes[1]
    )
axes[1].set_ylabel("suboptimality")
axes[1].set_title("Suboptimality by time")
plt.show()

@lorentzenchr lorentzenchr added this to the 1.5 milestone Apr 18, 2024
@ogrisel (Member) left a comment

Thanks @lorentzenchr, this is a very interesting PR. Here is a first pass of feedback.

nitpick: I think Hessian should always be capitalized in the docstrings and comments.

@lorentzenchr (Member Author)

nitpick: I think Hessian should always be capitalized in the docstrings and comments.

That's right, but not the standard in our code base. If you wish to correct that, I propose a separate PR.

@ogrisel (Member) commented Apr 25, 2024

That's right, but not the standard in our code base. If you wish to correct that, I propose a separate PR.

I think we can just make sure that we don't propagate this error in new docstrings / comments and use a follow-up PR to fix existing docstrings/comments that are not logically related to the scope of this PR.

# While a dedicated Cython routine could exploit the symmetry, it is very hard to
# beat BLAS GEMM, even though the latter cannot exploit the symmetry, unless one
# pays the price of taking square roots and implements
# sqrtWX = sqrt(W)[:, None] * X
@ogrisel (Member) commented Apr 25, 2024

Note that exploiting symmetry is not the only reason why a dedicated sandwich product kernel would make sense.

The line above would trigger a read/write round trip between RAM and the CPU of the size of X (when X is too large to fit in the CPU cache, which is typically the case of interest). When n_samples >> n_features, a dedicated fused sandwich product kernel would only have to:

  • read n_samples * (n_features + 1) / 2 from RAM;
  • write n_features ** 2 to RAM;

while what you propose would:

  • read n_samples * (n_features + 1) from RAM, # sqrtWX = sqrt(W)[:, None] * X
  • write n_samples * n_features to RAM, # sqrtWX = sqrt(W)[:, None] * X
  • read n_samples * n_features / 2 from RAM, # np.dot(sqrtWX.T, sqrtWX)
  • write n_features ** 2 to RAM. # np.dot(sqrtWX.T, sqrtWX)

Assuming that this kernel is memory bound, I would expect a ~3x speed-up from the fused kernel over the 2-step numpy code.

The problem is that writing an efficient blocked sandwich product kernel in Cython with OpenMP threading and hardware adapted SIMD vector instructions is far from trivial.

For CPU, https://github.com/Quantco/tabmat already presumably does that.

For GPU, something like https://github.com/openai/triton/ might be able to do it in a vendor agnostic way.

EDIT: some triton developers are working on a CPU backend. It's very preliminary at this point but may be interesting in the medium term, as it would allow a single-source code base for optimized CPU + GPU kernels written in a high-level programming language: https://github.com/triton-lang/triton-cpu/
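To make the discussion concrete, here is a minimal numpy sketch of the two-step GEMM-based sandwich product X.T @ diag(W) @ X discussed above (variable names are mine; the einsum call is only a single-pass correctness reference, not a proposal for the implementation):

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 100_000, 20
X = rng.standard_normal((n_samples, n_features))
W = rng.uniform(0.1, 1.0, size=n_samples)  # per-sample Hessian weights

# Two-step GEMM approach: materializes the intermediate array sqrtWX
# (n_samples x n_features) in RAM before handing it to BLAS GEMM.
sqrtWX = np.sqrt(W)[:, None] * X
H_gemm = sqrtWX.T @ sqrtWX

# Single-pass reference via einsum: no intermediate array, but without an
# optimized GEMM path it is typically much slower than the two-step version.
H_ref = np.einsum("ji,j,jk->ik", X, W, X)

assert np.allclose(H_gemm, H_ref)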

@lorentzenchr (Member Author)

I can remove some of that comment. I wanted to stress two facts:

  • this is the CPU bottleneck;
  • replacing it with some self-written BLAS-like function is a ludicrous undertaking (even tabmat is only faster when there are categoricals!). GEMM might be the algorithm on which the most human time has been spent writing and optimizing.

@lorentzenchr (Member Author)

Let’s not get off-topic too much.

@ogrisel (Member) commented Jun 21, 2024

Replacing it with some self-written BLAS-like function is a ludicrous undertaking (even tabmat is only faster when there are categoricals!). GEMM might be the algorithm on which the most human time has been spent writing and optimizing.

tabmat is nearly 4x faster than GEMM on dense numeric inputs according to:

This matches my memory bandwidth analysis above. My suggestion is not to change the implementation as part of this PR or in the near future but rather to improve the comment and converge on a shared understanding of the remaining achievable performance improvements if we move away from GEMM to a dedicated sandwich product fused kernel (as the one implemented in tabmat).

@lorentzenchr (Member Author)

@agramfort @TomDLT @rth friendly ping in case you find time. IMO, this PR closes a gap in the linear model solvers and enables high-precision solutions with unprecedented speed (orders of magnitude faster) for multiclass problems when the Hessian fits into memory.

@lorentzenchr lorentzenchr force-pushed the multiclass_newton_cholesky branch from 4c48a40 to a0428e2 on October 1, 2024 16:50
@agramfort (Member) left a comment

I did not review in detail, but since this is a convex problem, if the loss reaches the global minimum, as evidenced by the plot and the tests, then the numerics must be correct. This is a new feature, so there is no risk of regression for the existing algorithms. I would just suggest someone carefully checks that this does not lead to any API change.

@lorentzenchr (Member Author)

I would just suggest someone carefully checks that this does not lead to any API change.

This PR makes the following public API change:
if multi_class="auto" (effectively the default), n_classes >= 3, and solver="newton-cholesky", then the effectively used multi_class changes from "ovr" to "multinomial".
Please note that multi_class is deprecated and will be removed in 1.7.
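For reference, a minimal check (my own sketch, assuming a scikit-learn build that includes this PR): with the full multinomial Newton-Cholesky solver, newton-cholesky and lbfgs now minimize the same penalized multinomial loss on a multiclass problem and should agree closely, whereas the previous OvR fallback would in general yield different probabilities.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(
    n_samples=500, n_features=10, n_informative=5, n_classes=3, random_state=0
)
clf_nc = LogisticRegression(solver="newton-cholesky", tol=1e-8).fit(X, y)
clf_lbfgs = LogisticRegression(solver="lbfgs", tol=1e-8, max_iter=10_000).fit(X, y)

# Both solvers minimize the same penalized multinomial loss, so the predicted
# probabilities should agree closely.
assert np.allclose(clf_nc.predict_proba(X), clf_lbfgs.predict_proba(X), atol=1e-4)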

@ogrisel (Member) commented Oct 18, 2024

While reviewing this PR and testing it on some data, I realized that the reported number of iterations was always 0 whenever the L-BFGS-B fallback would kick in.

I pushed a quick fix in 1fa8e14.

I made it such that any completed iteration of the newton-cholesky solver is subtracted from max_iter before calling lbfgs, and the sum of the two solvers' iteration counts is reported in the end. In practice, this does not seem to change anything, because the Hessian conditioning problem always happens during the first iteration in my experiments with weakly regularized, rank-deficient problems that typically trigger the lbfgs fallback mechanism.

Now I realize that this bug was already present for the binary classification problem, so I should have opened a separate PR with a proper changelog entry. Let me revert this commit and do that instead.

Note that I find the lbfgs fallback warning quite annoying when it is triggered while tuning the regularization level (e.g. using LogisticRegressionCV or RandomizedSearchCV). I have the feeling that this should be a regular verbose print instead, but we can tackle that in a separate PR.
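A rough sketch of the iteration bookkeeping described above (purely illustrative pseudocode with hypothetical helper names, not the actual scikit-learn internals):

def fit_with_lbfgs_fallback(newton_solve, lbfgs_solve, max_iter):
    # Hypothetical helpers: each returns the coefficients, the number of
    # completed iterations, and a convergence flag.
    coef, n_iter_newton, converged = newton_solve(max_iter=max_iter)
    if converged:
        return coef, n_iter_newton
    # Give the lbfgs fallback only the remaining iteration budget and report
    # the total number of iterations across both solvers.
    remaining = max_iter - n_iter_newton
    coef, n_iter_lbfgs, _ = lbfgs_solve(coef_init=coef, max_iter=remaining)
    return coef, n_iter_newton + n_iter_lbfgs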

@ogrisel (Member) left a comment

I had another pass over the code. The tests look good; I trust them. I will open a follow-up PR for the bug I found in the LBFGS fallback mechanism.

@ogrisel ogrisel enabled auto-merge (squash) October 18, 2024 10:11
@lorentzenchr (Member Author)

@ogrisel Thanks for reviewing. I would prefer if you did not push commits on (my) PRs. Please first communicate with me.

@ogrisel ogrisel merged commit c08b433 into scikit-learn:main Oct 18, 2024
29 checks passed
@jjerphan (Member)

Thank you very much for this contribution, Christian. 🙌
