
Commit ff9344f

lorentzenchr, ogrisel, and jjerphan authored
FEA add (single) Cholesky Newton solver to GLMs (#24637)
* FEA add NewtonSolver, CholeskyNewtonSolver and QRCholeskyNewtonSolver
* ENH better singular hessian special solve
* CLN fix some typos found by reviewer
* TST assert ConvergenceWarning is raised
* MNT add BaseCholeskyNewtonSolver
* WIP collinear design in GLMs
* FIX _solve_singular
* FIX false unpacking in
* TST add tests for unpenalized GLMs
* TST fix solutions of glm_dataset
* ENH add SVDFallbackSolver
* CLN remove SVDFallbackSolver
* ENH use gradient step for singular hessians
* ENH print iteration number in warnings
* TST improve test_linalg_warning_with_newton_solver
* CLN LinAlgWarning from scipy.linalg
* ENH more robust hessian
* ENH increase maxls for lbfgs to make it more robust
* ENH add hessian_warning for too many negative hessian values
* CLN some warning messages
* ENH add lbfgs_step
* ENH use lbfgs_step for hessian_warning
* TST make them pass
* TST tweak rtol for lbfgs
* TST add rigorous test for GLMs
* TST improve test_warm_start
* ENH improve lbfgs options for better convergence
* CLN fix test_warm_start
* TST fix assert singular values in datasets
* CLN address most review comments
* ENH enable more verbosity levels for lbfgs
* DOC add whatsnew
* CLN remove xfail and clean a bit
* CLN docstring about minimum norm
* More informative repr for the glm_dataset fixture cases
* Forgot to run black
* CLN remove unnecessary filterwarnings
* CLN address review comments
* Trigger [all random seeds] on the following tests: test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* CLN add comment for lbfgs ftol=64 * machine precision
* CLN XXX code comment
* Trigger [all random seeds] on the following tests: test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* CLN link issue and remove code snippet in comment
* Trigger [all random seeds] on the following tests: test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* CLN add catch_warnings
* Trigger [all random seeds] on the following tests: test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* Trigger [all random seeds] on the following tests: test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* [all random seeds] test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* Trigger with -Werror [all random seeds] test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* ENH increase maxls to 50
* [all random seeds] test_glm_regression test_glm_regression_hstacked_X test_glm_regression_vstacked_X test_glm_regression_unpenalized test_glm_regression_unpenalized_hstacked_X test_glm_regression_unpenalized_vstacked_X test_warm_start
* Revert "Trigger with -Werror [all random seeds]". This reverts commit 99f4cf9.
* TST add catch_warnings to filterwarnings
* TST adapt tests for newton solvers
* CLN cleaner gradient step with gradient_times_newton
* DOC add whatsnew
* ENH always use lbfgs as fallback
* TST adapt rtol
* TST fix test_linalg_warning_with_newton_solver
* CLN address some review comments
* Improve tests related to convergence warning on collinear data
* overfit -> fit
* Typo in comment
* Apply suggestions from code review
* ENH fallback_lbfgs_solve: do not use lbfgs steps, fall back completely to lbfgs
* ENH adapt rtol
* Improve test_linalg_warning_with_newton_solver
* Better comments
* Fixed Hessian casing and improved warning messages
* [all random seeds] test_linalg_warning_with_newton_solver
* Ignore ConvergenceWarnings for now if convergence is good
* CLN remove counting of warnings
* ENH fall back to lbfgs if line search did not converge
* DOC better comment on performance bottleneck
* Update GLM related examples to use the new solver
* CLN address reviewer comments
* EXA improve some wordings
* CLN do not pop "solver" in parameter constraints
* CLN fix typos
* DOC fix docstring
* CLN remove solver newton-qr-cholesky
* DOC update PR number in whatsnew
* CLN address review comments
* CLN remove unnecessary catch_warnings
* CLN address some review comments
* DOC more precise whatsnew
* CLN use init_zero_coef
* CLN use and test init_zero_coef
* CLN address some review comments
* CLN mark NewtonSolver as private by leading underscore
* CLN exact comments for inner_solve
* TST add test_newton_solver_verbosity
* TST extend test_newton_solver_verbosity
* TST logic in test_glm_regression_unpenalized
* TST use count_nonzero
* CLN remove super rare line search checks
* MNT move Newton solver to new file _newton_solver.py

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 2335a8e commit ff9344f

File tree

8 files changed: +1321 -119 lines changed


doc/whats_new/v1.2.rst

+10
@@ -353,6 +353,16 @@ Changelog
 :mod:`sklearn.linear_model`
 ...........................
 
+- |Enhancement| :class:`linear_model.GammaRegressor`,
+  :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor` got
+  a new solver `solver="newton-cholesky"`. This is a 2nd order (Newton) optimisation
+  routine that uses a Cholesky decomposition of the hessian matrix.
+  When `n_samples >> n_features`, the `"newton-cholesky"` solver has been observed to
+  converge both faster and to a higher precision solution than the `"lbfgs"` solver on
+  problems with one-hot encoded categorical variables with some rare categorical
+  levels.
+  :pr:`24637` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Enhancement| :class:`linear_model.GammaRegressor`,
   :class:`linear_model.PoissonRegressor` and :class:`linear_model.TweedieRegressor`
   can reach higher precision with the lbfgs solver, in particular when `tol` is set
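For intuition, the core of a Newton-Cholesky iteration can be sketched in a few lines of NumPy/SciPy. This is a simplified illustration with a synthetic Hessian and gradient, not scikit-learn's internal solver code:

import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(0)
n_features = 5

# Synthetic stand-ins for the Hessian H and gradient g of the penalized
# deviance; the L2 penalty term keeps H symmetric positive definite, so a
# Cholesky factorization exists.
A = rng.standard_normal((n_features, n_features))
H = A @ A.T + 1e-3 * np.eye(n_features)
g = rng.standard_normal(n_features)

# One Newton step: solve H @ step = -g via Cholesky instead of inverting H.
c, low = cho_factor(H)
newton_step = cho_solve((c, low), -g)

coef = np.zeros(n_features) + newton_step  # a line search would scale this step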

examples/linear_model/plot_poisson_regression_non_normal_loss.py

+6 -2

@@ -110,7 +110,11 @@
 linear_model_preprocessor = ColumnTransformer(
     [
         ("passthrough_numeric", "passthrough", ["BonusMalus"]),
-        ("binned_numeric", KBinsDiscretizer(n_bins=10), ["VehAge", "DrivAge"]),
+        (
+            "binned_numeric",
+            KBinsDiscretizer(n_bins=10, subsample=int(2e5), random_state=0),
+            ["VehAge", "DrivAge"],
+        ),
         ("log_scaled_numeric", log_scale_transformer, ["Density"]),
         (
             "onehot_categorical",
@@ -247,7 +251,7 @@ def score_estimator(estimator, df_test):
 poisson_glm = Pipeline(
     [
         ("preprocessor", linear_model_preprocessor),
-        ("regressor", PoissonRegressor(alpha=1e-12, max_iter=300)),
+        ("regressor", PoissonRegressor(alpha=1e-12, solver="newton-cholesky")),
     ]
 )
 poisson_glm.fit(
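As a quick, self-contained illustration of the new option (the data below is synthetic and made up for this sketch; the example script itself uses the French MTPL dataset through a full preprocessing pipeline):

import numpy as np
from sklearn.linear_model import PoissonRegressor

rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 3))
# Synthetic counts drawn from a log-linear Poisson model.
y = rng.poisson(lam=np.exp(X @ np.array([0.3, -0.2, 0.1])))

reg = PoissonRegressor(alpha=1e-12, solver="newton-cholesky").fit(X, y)
print(reg.coef_)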

examples/linear_model/plot_tweedie_regression_insurance_claims.py

+76 -16

@@ -56,12 +56,12 @@
 from sklearn.metrics import mean_squared_error
 
 
-def load_mtpl2(n_samples=100000):
+def load_mtpl2(n_samples=None):
     """Fetch the French Motor Third-Party Liability Claims dataset.
 
     Parameters
     ----------
-    n_samples: int, default=100000
+    n_samples: int, default=None
         number of samples to select (for faster run time). Full dataset has
         678013 samples.
     """
@@ -215,7 +215,7 @@ def score_estimator(
 from sklearn.compose import ColumnTransformer
 
 
-df = load_mtpl2(n_samples=60000)
+df = load_mtpl2()
 
 # Note: filter out claims with zero amount, as the severity model
 # requires strictly positive target values.
@@ -233,7 +233,11 @@ def score_estimator(
 
 column_trans = ColumnTransformer(
     [
-        ("binned_numeric", KBinsDiscretizer(n_bins=10), ["VehAge", "DrivAge"]),
+        (
+            "binned_numeric",
+            KBinsDiscretizer(n_bins=10, subsample=int(2e5), random_state=0),
+            ["VehAge", "DrivAge"],
+        ),
         (
             "onehot_categorical",
             OneHotEncoder(),
@@ -276,10 +280,26 @@ def score_estimator(
 
 df_train, df_test, X_train, X_test = train_test_split(df, X, random_state=0)
 
+# %%
+#
+# Let us keep in mind that despite the seemingly large number of data points in
+# this dataset, the number of evaluation points where the claim amount is
+# non-zero is quite small:
+len(df_test)
+
+# %%
+len(df_test[df_test["ClaimAmount"] > 0])
+
+# %%
+#
+# As a consequence, we expect a significant variability in our
+# evaluation upon random resampling of the train test split.
+#
 # The parameters of the model are estimated by minimizing the Poisson deviance
-# on the training set via a quasi-Newton solver: l-BFGS. Some of the features
-# are collinear, we use a weak penalization to avoid numerical issues.
-glm_freq = PoissonRegressor(alpha=1e-3, max_iter=400)
+# on the training set via a Newton solver. Some of the features are collinear
+# (e.g. because we did not drop any categorical level in the `OneHotEncoder`),
+# we use a weak L2 penalization to avoid numerical issues.
+glm_freq = PoissonRegressor(alpha=1e-4, solver="newton-cholesky")
 glm_freq.fit(X_train, df_train["Frequency"], sample_weight=df_train["Exposure"])
 
 scores = score_estimator(
@@ -295,6 +315,12 @@ def score_estimator(
 print(scores)
 
 # %%
+#
+# Note that the score measured on the test set is surprisingly better than on
+# the training set. This might be specific to this random train-test split.
+# Proper cross-validation could help us to assess the sampling variability of
+# these results.
+#
 # We can visually compare observed and predicted values, aggregated by the
 # drivers age (``DrivAge``), vehicle age (``VehAge``) and the insurance
 # bonus/malus (``BonusMalus``).
@@ -374,7 +400,7 @@ def score_estimator(
 mask_train = df_train["ClaimAmount"] > 0
 mask_test = df_test["ClaimAmount"] > 0
 
-glm_sev = GammaRegressor(alpha=10.0, max_iter=10000)
+glm_sev = GammaRegressor(alpha=10.0, solver="newton-cholesky")
 
 glm_sev.fit(
     X_train[mask_train.values],
@@ -395,13 +421,44 @@ def score_estimator(
 print(scores)
 
 # %%
-# Here, the scores for the test data call for caution as they are
-# significantly worse than for the training data indicating an overfit despite
-# the strong regularization.
 #
-# Note that the resulting model is the average claim amount per claim. As
-# such, it is conditional on having at least one claim, and cannot be used to
-# predict the average claim amount per policy in general.
+# Those values of the metrics are not necessarily easy to interpret. It can be
+# insightful to compare them with a model that does not use any input
+# features and always predicts a constant value, i.e. the average claim
+# amount, in the same setting:
+
+from sklearn.dummy import DummyRegressor
+
+dummy_sev = DummyRegressor(strategy="mean")
+dummy_sev.fit(
+    X_train[mask_train.values],
+    df_train.loc[mask_train, "AvgClaimAmount"],
+    sample_weight=df_train.loc[mask_train, "ClaimNb"],
+)
+
+scores = score_estimator(
+    dummy_sev,
+    X_train[mask_train.values],
+    X_test[mask_test.values],
+    df_train[mask_train],
+    df_test[mask_test],
+    target="AvgClaimAmount",
+    weights="ClaimNb",
+)
+print("Evaluation of a mean predictor on target AvgClaimAmount")
+print(scores)
+
+# %%
+#
+# We conclude that the claim amount is very challenging to predict. Still, the
+# :class:`~sklearn.linear_model.GammaRegressor` is able to leverage some
+# information from the input features to slightly improve upon the mean
+# baseline in terms of D².
+#
+# Note that the resulting model is the average claim amount per claim. As such,
+# it is conditional on having at least one claim, and cannot be used to predict
+# the average claim amount per policy. For this, it needs to be combined with
+# a claims frequency model.
 
 print(
     "Mean AvgClaim Amount per policy: %.2f "
@@ -415,7 +472,10 @@ def score_estimator(
     "Predicted Mean AvgClaim Amount | NbClaim > 0: %.2f"
     % glm_sev.predict(X_train).mean()
 )
-
+print(
+    "Predicted Mean AvgClaim Amount (dummy) | NbClaim > 0: %.2f"
+    % dummy_sev.predict(X_train).mean()
+)
 
 # %%
 # We can visually compare observed and predicted values, aggregated for
@@ -481,7 +541,7 @@ def score_estimator(
 from sklearn.linear_model import TweedieRegressor
 
 
-glm_pure_premium = TweedieRegressor(power=1.9, alpha=0.1, max_iter=10000)
+glm_pure_premium = TweedieRegressor(power=1.9, alpha=0.1, solver="newton-cholesky")
 glm_pure_premium.fit(
     X_train, df_train["PurePremium"], sample_weight=df_train["Exposure"]
 )
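The comment above notes that the severity model must be combined with a frequency model to predict per-policy amounts. A minimal sketch of that product-model idea on synthetic data (all names and numbers below are illustrative stand-ins, not the script's actual objects):

import numpy as np
from sklearn.linear_model import GammaRegressor, PoissonRegressor

rng = np.random.default_rng(0)
X = rng.standard_normal((2000, 4))
# Synthetic claim counts and strictly positive average claim amounts.
freq = rng.poisson(lam=np.exp(X @ np.array([0.2, -0.1, 0.0, 0.1])))
sev = rng.gamma(shape=2.0, scale=500.0, size=X.shape[0])

glm_freq = PoissonRegressor(solver="newton-cholesky").fit(X, freq)
mask = freq > 0  # the severity model is conditional on having at least one claim
glm_sev = GammaRegressor(solver="newton-cholesky").fit(X[mask], sev[mask])

# Expected pure premium per policy ~ predicted frequency * predicted severity.
pure_premium_pred = glm_freq.predict(X) * glm_sev.predict(X)
print("Mean predicted pure premium:", pure_premium_pred.mean())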
