FEA Add Gamma deviance as loss function to HGBT #22409

Merged
28 commits merged on Jan 30, 2023

Commits
7045b71
FEA add gamma loss to HGBT
lorentzenchr Feb 7, 2022
55ad3cd
DOC add whatsnew
lorentzenchr Feb 7, 2022
82ad819
CLN address review comments
lorentzenchr Apr 4, 2022
e67bbe4
Merge branch 'main' into hgbt_gamma
lorentzenchr Apr 4, 2022
df76d92
Merge branch 'main' into hgbt_gamma
lorentzenchr Apr 14, 2022
d8e5037
TST make test_gamma pass by not testing out-of-sample
lorentzenchr Apr 20, 2022
c8f9bfe
TST compare gamma and poisson to LightGBM
lorentzenchr Apr 20, 2022
74caaf7
Merge branch 'main' into hgbt_gamma
lorentzenchr Oct 7, 2022
bb234ee
TST fix test_gamma by comparing to MSE HGBT instead of Poisson HGBT
lorentzenchr Oct 7, 2022
0cc8716
TST fix for test_same_predictions_regression for poisson
lorentzenchr Oct 7, 2022
930572d
Merge branch 'main' into hgbt_gamma
lorentzenchr Dec 15, 2022
5f043a1
CLN address review comments
lorentzenchr Dec 28, 2022
f31d541
Merge branch 'main' into hgbt_gamma
lorentzenchr Dec 28, 2022
7f783ee
Merge branch 'main' into hgbt_gamma
lorentzenchr Dec 28, 2022
e8a1a42
CLN nits
lorentzenchr Dec 28, 2022
aa360c0
CLN better comments
lorentzenchr Jan 11, 2023
1f4a243
Merge branch 'main' into hgbt_gamma
lorentzenchr Jan 11, 2023
0776dec
Merge branch 'main' into hgbt_gamma
jjerphan Jan 12, 2023
3321e3f
TST use pytest.param with skip mark
lorentzenchr Jan 12, 2023
0ad5afa
Merge branch 'hgbt_gamma' of https://github.com/lorentzenchr/scikit-l…
lorentzenchr Jan 12, 2023
0b33f82
Merge remote-tracking branch 'upstream/main' into hgbt_gamma
jjerphan Jan 13, 2023
fcff47b
TST Correct conditional test parametrization mark
jjerphan Jan 13, 2023
74964d0
CI Trigger CI
jjerphan Jan 13, 2023
038bfde
Merge branch 'main' into hgbt_gamma
jjerphan Jan 16, 2023
7b4abb6
DOC add comment for lax comparison with LightGBM
lorentzenchr Jan 30, 2023
e6041f4
CLN tuple needs trailing comma
lorentzenchr Jan 30, 2023
015824c
Merge branch 'hgbt_gamma' of https://github.com/lorentzenchr/scikit-l…
lorentzenchr Jan 30, 2023
6902f6f
Merge branch 'main' into hgbt_gamma
lorentzenchr Jan 30, 2023
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -114,6 +114,12 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |Feature| :class:`ensemble.HistGradientBoostingRegressor` now supports
the Gamma deviance loss via `loss="gamma"`.
Using the Gamma deviance as loss function comes in handy for modelling
skewed-distributed, strictly positive-valued targets.
:pr:`22409` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Feature| Compute a custom out-of-bag score by passing a callable to
:class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
:class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`.
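Editor's illustration (not part of the diff): a minimal usage sketch of the new option described in the :class:`ensemble.HistGradientBoostingRegressor` changelog entry above, assuming scikit-learn >= 1.3 with this PR merged.

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 3))
# Strictly positive, skewed target, as required by loss="gamma"
y = rng.gamma(shape=2.0, scale=np.exp(0.3 * X[:, 0]))

reg = HistGradientBoostingRegressor(loss="gamma", random_state=0).fit(X, y)
# The internal log-link guarantees strictly positive predictions.
print(reg.predict(X[:5]))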
32 changes: 25 additions & 7 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -13,6 +13,7 @@
_LOSSES,
BaseLoss,
HalfBinomialLoss,
HalfGammaLoss,
HalfMultinomialLoss,
HalfPoissonLoss,
PinballLoss,
@@ -43,6 +44,7 @@
_LOSSES.update(
{
"poisson": HalfPoissonLoss,
"gamma": HalfGammaLoss,
"quantile": PinballLoss,
"binary_crossentropy": HalfBinomialLoss,
"categorical_crossentropy": HalfMultinomialLoss,
@@ -1204,13 +1206,14 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

Parameters
----------
loss : {'squared_error', 'absolute_error', 'poisson', 'quantile'}, \
loss : {'squared_error', 'absolute_error', 'gamma', 'poisson', 'quantile'}, \
default='squared_error'
The loss function to use in the boosting process. Note that the
"squared error" and "poisson" losses actually implement
"half least squares loss" and "half poisson deviance" to simplify the
computation of the gradient. Furthermore, "poisson" loss internally
uses a log-link and requires ``y >= 0``.
"squared error", "gamma" and "poisson" losses actually implement
"half least squares loss", "half gamma deviance" and "half poisson
deviance" to simplify the computation of the gradient. Furthermore,
"gamma" and "poisson" losses internally use a log-link, "gamma"
requires ``y > 0`` and "poisson" requires ``y >= 0``.
"quantile" uses the pinball loss.

.. versionchanged:: 0.23
@@ -1219,6 +1222,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
.. versionchanged:: 1.1
Added option 'quantile'.

.. versionchanged:: 1.3
Added option 'gamma'.

quantile : float, default=None
If loss is "quantile", this parameter specifies which quantile to be estimated
and must be between 0 and 1.
@@ -1418,7 +1424,15 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
_parameter_constraints: dict = {
**BaseHistGradientBoosting._parameter_constraints,
"loss": [
StrOptions({"squared_error", "absolute_error", "poisson", "quantile"}),
StrOptions(
{
"squared_error",
"absolute_error",
"poisson",
"gamma",
"quantile",
}
),
BaseLoss,
],
"quantile": [Interval(Real, 0, 1, closed="both"), None],
@@ -1514,7 +1528,11 @@ def _encode_y(self, y):
# Just convert y to the expected dtype
self.n_trees_per_iteration_ = 1
y = y.astype(Y_DTYPE, copy=False)
if self.loss == "poisson":
if self.loss == "gamma":
# Ensure y > 0
if not np.all(y > 0):
raise ValueError("loss='gamma' requires strictly positive y.")
elif self.loss == "poisson":
# Ensure y >= 0 and sum(y) > 0
if not (np.all(y >= 0) and np.sum(y) > 0):
raise ValueError(
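Editor's note (not part of the diff): the "half gamma deviance" mentioned in the docstring above is, per sample, half of the quantity computed by sklearn.metrics.mean_gamma_deviance. A small sketch to make the relation concrete:

import numpy as np
from sklearn.metrics import mean_gamma_deviance

rng = np.random.RandomState(0)
y = rng.gamma(shape=2.0, scale=1.0, size=1000)    # strictly positive targets
mu = y * rng.uniform(0.8, 1.2, size=y.shape)      # some strictly positive predictions

# Unit gamma deviance: d(y, mu) = 2 * (log(mu / y) + y / mu - 1); the "half" loss uses d / 2.
half_deviance = np.mean(np.log(mu / y) + y / mu - 1)
assert np.isclose(half_deviance, mean_gamma_deviance(y, mu) / 2)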
@@ -11,6 +11,17 @@


@pytest.mark.parametrize("seed", range(5))
@pytest.mark.parametrize(
"loss",
[
"squared_error",
"poisson",
pytest.param(
"gamma",
marks=pytest.mark.skip("LightGBM with gamma loss has larger deviation."),
),
],
)
@pytest.mark.parametrize("min_samples_leaf", (1, 20))
@pytest.mark.parametrize(
"n_samples, max_leaf_nodes",
@@ -19,7 +30,9 @@
(1000, 8),
],
)
def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf_nodes):
def test_same_predictions_regression(
seed, loss, min_samples_leaf, n_samples, max_leaf_nodes
):
# Make sure sklearn has the same predictions as lightgbm for easy targets.
#
# In particular when the size of the trees are bound and the number of
@@ -33,7 +46,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
# is not exactly the same. To avoid this issue we only compare the
# predictions on the test set when the number of samples is large enough
# and max_leaf_nodes is low enough.
# - To ignore discrepancies caused by small differences the binning
# - To ignore discrepancies caused by small differences in the binning
# strategy, data is pre-binned if n_samples > 255.
# - We don't check the absolute_error loss here. This is because
# LightGBM's computation of the median (used for the initial value of
@@ -52,6 +65,10 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
n_samples=n_samples, n_features=5, n_informative=5, random_state=0
)

if loss in ("gamma", "poisson"):
# make the target positive
y = np.abs(y) + np.mean(np.abs(y))

if n_samples > 255:
# bin data and convert it to float32 so that the estimator doesn't
# treat it as pre-binned
@@ -60,6 +77,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

est_sklearn = HistGradientBoostingRegressor(
loss=loss,
max_iter=max_iter,
max_bins=max_bins,
learning_rate=1,
@@ -68,6 +86,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf
max_leaf_nodes=max_leaf_nodes,
)
est_lightgbm = get_equivalent_estimator(est_sklearn, lib="lightgbm")
est_lightgbm.set_params(min_sum_hessian_in_leaf=0)

est_lightgbm.fit(X_train, y_train)
est_sklearn.fit(X_train, y_train)
@@ -77,14 +96,24 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, max_leaf

pred_lightgbm = est_lightgbm.predict(X_train)
pred_sklearn = est_sklearn.predict(X_train)
# less than 1% of the predictions are different up to the 3rd decimal
assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-3) < 0.011

if max_leaf_nodes < 10 and n_samples >= 1000:
if loss in ("gamma", "poisson"):
# More than 65% of the predictions must be close up to the 2nd decimal.
Comment (Member):

This looks a bit lax. Any idea why we can be stricter in the case of the squared error? I wouldn't require making it stricter as a prerequisite to merging this PR, since the Poisson loss is already in main; its LightGBM equivalence was previously just not tested.

But we could at least add a TODO comment to inform the reader that we are not 100% satisfied with the state of things :)

Suggested change
# More than 65% of the predictions must be close up to the 2nd decimal.
# More than 65% of the predictions must be close up to the 2nd decimal.
# TODO: investigate the cause of the remaining discrepancies with lightgbm.

Comment (Member):

Maybe this is because, in the context of the squared error, we only test for equivalence on the larger dataset with shallow trees (max_leaf_nodes < 10 and n_samples >= 1000): this condition might allow the averaging effect in the leaves to mitigate the impact of rounding errors and reduce the likelihood of ties when deciding which features and thresholds to split on?

Comment (lorentzenchr, author):

Maybe LightGBM's algorithm deviates more from ours for general losses (i.e. not squared error). For instance, in the Poisson case LightGBM has a poisson_max_delta_step parameter, which I set to 1e-12, but I'm not 100% sure of the effective difference this causes.

I'll add a comment as suggested.

Comment (lorentzenchr, author):

And BTW, our losses have a better test suite than the ones in LightGBM :smirk:

# TODO: We are not entirely satisfied with this lax comparison, but the root
# cause is not clear, maybe algorithmic differences. One such example is the
# poisson_max_delta_step parameter of LightGBM which does not exist in HGBT.
assert (
np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-2, atol=1e-2))
> 0.65
)
else:
# Less than 1% of the predictions may deviate more than 1e-3 in relative terms.
assert np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-3)) > 1 - 0.01

if max_leaf_nodes < 10 and n_samples >= 1000 and loss in ("squared_error",):
pred_lightgbm = est_lightgbm.predict(X_test)
pred_sklearn = est_sklearn.predict(X_test)
# less than 1% of the predictions are different up to the 4th decimal
assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-4) < 0.01
# Less than 1% of the predictions may deviate more than 1e-4 in relative terms.
assert np.mean(np.isclose(pred_lightgbm, pred_sklearn, rtol=1e-4)) > 1 - 0.01


@pytest.mark.parametrize("seed", range(5))
@@ -17,7 +17,7 @@
from sklearn.base import clone, BaseEstimator, TransformerMixin
from sklearn.base import is_regressor
from sklearn.pipeline import make_pipeline
from sklearn.metrics import mean_poisson_deviance
from sklearn.metrics import mean_gamma_deviance, mean_poisson_deviance
from sklearn.dummy import DummyRegressor
from sklearn.exceptions import NotFittedError
from sklearn.compose import make_column_transformer
@@ -248,8 +248,64 @@ def test_absolute_error_sample_weight():
gbdt.fit(X, y, sample_weight=sample_weight)


@pytest.mark.parametrize("y", [([1.0, -2.0, 0.0]), ([0.0, 1.0, 2.0])])
def test_gamma_y_positive(y):
# Test that ValueError is raised if any y_i <= 0.
err_msg = r"loss='gamma' requires strictly positive y."
gbdt = HistGradientBoostingRegressor(loss="gamma", random_state=0)
with pytest.raises(ValueError, match=err_msg):
gbdt.fit(np.zeros(shape=(len(y), 1)), y)


def test_gamma():
Comment (lorentzenchr, author):

Currently, this test fails because the HGBT with Poisson loss outperforms all other models on the test set.
I observe the exact same behaviour, with (often) similar numbers, with LightGBM:

from lightgbm import LGBMRegressor

lgbm_gamma = LGBMRegressor(objective="gamma")
lgbm_pois = LGBMRegressor(objective="poisson")

I don't know how to construct a better test - for now.

Comment (ogrisel, Member), Feb 8, 2022:

Is this true only for the small sample sizes that we typically use in our tests, or would this still be true for a very large n_samples?

Is it similar to the fact that the median can be a better estimate of the expected value than the sample mean for a finite sample of a long-tailed distribution (robustness to outliers)? This was observed empirically when computing metrics on a left-out validation set here: https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html

Comment (lorentzenchr, author):

I played with n_samples and n_features without success. The problem that the Poisson HGBT always beats the Gamma HGBT out-of-sample persists. I think one reason is that the log-link is not canonical for the Gamma deviance, but it is canonical for the Poisson deviance. Therefore, the Poisson HGBT is better calibrated, which here also translates into better out-of-sample performance. Maybe the features are not informative enough in this example.

I can try to generate a decision tree and use it to generate a Gamma-distributed sample per tree node. That should be easier for tree-based models to fit than this GLM-based data generation.
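Editor's aside on the canonical-link remark above (standard GLM theory, not stated in the thread): for an exponential dispersion family with variance function V(mu), the canonical link g satisfies g'(mu) = 1 / V(mu). For Poisson, V(mu) = mu, hence g(mu) = log(mu) and the log-link is canonical; for Gamma, V(mu) = mu^2, hence g(mu) = -1/mu, so the log-link is not canonical.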

Comment (Member):

I can try to generate a decision tree and use it to generate a Gamma-distributed sample per tree node. That should be easier for tree-based models to fit than this GLM-based data generation.

That would be interesting to know indeed.

Comment (lorentzenchr, author):

I tried to create Gamma-distributed data with a decision-tree structure for the mean value, see the details below.
This gives the same result: the Poisson loss beats the Gamma loss when comparing out-of-sample Gamma deviance.

Proposed solutions:

  1. Remove this test and add (and rely on) a test comparing with LightGBM.
  2. Do not compare with the Poisson loss, just with the squared error (which depends a lot on the clipping constant!).

I'm in favor of the first option.

import numpy as np
from sklearn._loss import HalfGammaLoss
from sklearn.dummy import DummyRegressor
from sklearn.datasets import make_low_rank_matrix
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.metrics import mean_gamma_deviance
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

from lightgbm import LGBMRegressor


### Data Generation
rng = np.random.RandomState(42)
n_train, n_test, n_features = 500, 500, 20
X = make_low_rank_matrix(
    n_samples=n_train + n_test, n_features=n_features, random_state=rng, 
)
# First, we create normally distributed data
coef = rng.uniform(low=-10, high=20, size=n_features)
y_normal = np.abs(rng.normal(loc=np.exp(X @ coef)))

# Second, we fit a decision tree
dtree = DecisionTreeRegressor().fit(X, y_normal)

# Finally, we generate gamma distributed data according to the tree
# mean(y) = dtree.predict(X)
# var(y) = dispersion * mean(y)**2
dispersion = 0.5
y = rng.gamma(shape=1 / dispersion, scale=dispersion * dtree.predict(X))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=n_test, random_state=rng
)


### Model fitting
params = {"early_stopping": False, "random_state": 123}
p_lgbt = {"min_child_weight": 0, "random_state": 123}

gbdt_gamma = HistGradientBoostingRegressor(loss="gamma", **params)
gbdt_pois = HistGradientBoostingRegressor(loss="poisson", **params)
gbdt_ls = HistGradientBoostingRegressor(loss="squared_error", **params)
lgbm_gamma = LGBMRegressor(objective="gamma", **p_lgbt)
lgbm_pois = LGBMRegressor(objective="poisson", **p_lgbt)
for model in (gbdt_gamma, gbdt_pois, gbdt_ls, lgbm_gamma, lgbm_pois):
    model.fit(X_train, y_train)
dummy = DummyRegressor(strategy="mean").fit(X_train, y_train)

for m in (dummy, gbdt_gamma, gbdt_ls, gbdt_pois, lgbm_gamma, lgbm_pois):
    message = f"training gamma deviance {m.__class__.__name__}"
    if hasattr(m, "loss"):
        message += " " + m.loss
    elif hasattr(m, "objective"):
        message += " " + m.objective
    print(f"{message: <68}: {mean_gamma_deviance(y_train, np.clip(m.predict(X_train), 1e-12, None))}")

for m in (dummy, gbdt_gamma, gbdt_ls, gbdt_pois, lgbm_gamma, lgbm_pois):
    message = f"test gamma deviance {m.__class__.__name__}"
    if hasattr(m, "loss"):
        message += " " + m.loss
    elif hasattr(m, "objective"):
        message += " " + m.objective
    print(f"{message: <68}: {mean_gamma_deviance(y_test, np.clip(m.predict(X_test), 1e-12, None))}")

Comment (lorentzenchr, author):

I replaced Poisson with squared error as the model to compare against, see bb234ee. Now it works. The comments on the comparison to Poisson are preserved.

# For a Gamma distributed target, we expect an HGBT trained with the Gamma deviance
# (loss) to give better results than an HGBT with any other loss function, measured
# in out-of-sample Gamma deviance as metric/score.
# Note that squared error could potentially predict negative values which is
# invalid (np.inf) for the Gamma deviance. A Poisson HGBT (having a log link)
# does not have that defect.
# Important note: It seems that a Poisson HGBT almost always has better
# out-of-sample performance than the Gamma HGBT, measured in Gamma deviance.
# LightGBM shows the same behaviour. Hence, we only compare to a squared error
# HGBT, but not to a Poisson deviance HGBT.
rng = np.random.RandomState(42)
n_train, n_test, n_features = 500, 100, 20
X = make_low_rank_matrix(
n_samples=n_train + n_test,
n_features=n_features,
random_state=rng,
)
# We create a log-linear Gamma model. This gives y.min ~ 1e-2, y.max ~ 1e2
coef = rng.uniform(low=-10, high=20, size=n_features)
# Numpy parametrizes gamma(shape=k, scale=theta) with mean = k * theta and
# variance = k * theta^2. We parametrize it instead with mean = exp(X @ coef)
# and variance = dispersion * mean^2 by setting k = 1 / dispersion,
# theta = dispersion * mean.
dispersion = 0.5
y = rng.gamma(shape=1 / dispersion, scale=dispersion * np.exp(X @ coef))
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=n_test, random_state=rng
)
gbdt_gamma = HistGradientBoostingRegressor(loss="gamma", random_state=123)
gbdt_mse = HistGradientBoostingRegressor(loss="squared_error", random_state=123)
dummy = DummyRegressor(strategy="mean")
for model in (gbdt_gamma, gbdt_mse, dummy):
model.fit(X_train, y_train)

for X, y in [(X_train, y_train), (X_test, y_test)]:
loss_gbdt_gamma = mean_gamma_deviance(y, gbdt_gamma.predict(X))
# We restrict the squared error HGBT to predict at least the minimum seen y at
# train time to make it strictly positive.
loss_gbdt_mse = mean_gamma_deviance(
y, np.maximum(np.min(y_train), gbdt_mse.predict(X))
)
loss_dummy = mean_gamma_deviance(y, dummy.predict(X))
assert loss_gbdt_gamma < loss_dummy
assert loss_gbdt_gamma < loss_gbdt_mse


@pytest.mark.parametrize("quantile", [0.2, 0.5, 0.8])
def test_asymmetric_error(quantile):
def test_quantile_asymmetric_error(quantile):
"""Test quantile regression for asymmetric distributed targets."""
n_samples = 10_000
rng = np.random.RandomState(42)
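Editor's illustration (not part of the diff): a quick check of the NumPy Gamma parametrization used in test_gamma above, i.e. shape k = 1 / dispersion and scale theta = dispersion * mean give E[y] = mean and Var[y] = dispersion * mean**2.

import numpy as np

rng = np.random.RandomState(0)
mean, dispersion, n = 3.0, 0.5, 1_000_000
y = rng.gamma(shape=1 / dispersion, scale=dispersion * mean, size=n)

# mean = k * theta = (1 / dispersion) * (dispersion * mean)
assert np.isclose(y.mean(), mean, rtol=5e-3)
# var = k * theta**2 = dispersion * mean**2
assert np.isclose(y.var(), dispersion * mean**2, rtol=2e-2)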
9 changes: 8 additions & 1 deletion sklearn/ensemble/_hist_gradient_boosting/utils.pyx
@@ -41,6 +41,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
'squared_error': 'regression_l2',
'absolute_error': 'regression_l1',
'log_loss': 'binary' if n_classes == 2 else 'multiclass',
'gamma': 'gamma',
'poisson': 'poisson',
}

lightgbm_params = {
@@ -53,13 +55,14 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
'reg_lambda': sklearn_params['l2_regularization'],
'max_bin': sklearn_params['max_bins'],
'min_data_in_bin': 1,
'min_child_weight': 1e-3,
'min_child_weight': 1e-3, # alias for 'min_sum_hessian_in_leaf'
'min_sum_hessian_in_leaf': 1e-3,
'min_split_gain': 0,
'verbosity': 10 if sklearn_params['verbose'] else -10,
'boost_from_average': True,
'enable_bundle': False, # also makes feature order consistent
'subsample_for_bin': _BinMapper().subsample,
'poisson_max_delta_step': 1e-12,
}

if sklearn_params['loss'] == 'log_loss' and n_classes > 2:
@@ -76,6 +79,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
'squared_error': 'reg:linear',
'absolute_error': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED',
'log_loss': 'reg:logistic' if n_classes == 2 else 'multi:softmax',
'gamma': 'reg:gamma',
'poisson': 'count:poisson',
}

xgboost_params = {
@@ -100,6 +105,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None):
# catboost does not support MAE when leaf_estimation_method is Newton
'absolute_error': 'LEAST_ASBOLUTE_DEV_NOT_SUPPORTED',
'log_loss': 'Logloss' if n_classes == 2 else 'MultiClass',
'gamma': None,
'poisson': 'Poisson',
}

catboost_params = {
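Editor's sketch (not part of the diff): how the comparison tests above use this helper to build a LightGBM model with the matching objective. get_equivalent_estimator is a private test utility, so the import path below mirrors the test module and may change; lightgbm must be installed.

from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.ensemble._hist_gradient_boosting.utils import get_equivalent_estimator

est_sklearn = HistGradientBoostingRegressor(loss="gamma", max_iter=50, learning_rate=0.1)
# Maps loss="gamma" to LightGBM's objective="gamma" via the table above.
est_lightgbm = get_equivalent_estimator(est_sklearn, lib="lightgbm")
print(est_lightgbm.get_params()["objective"])  # expected: 'gamma'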