From d357fc8a69bad27fb8dd916416e91934f4ece9ce Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Sun, 3 Apr 2022 21:10:03 +0200 Subject: [PATCH 01/12] DEP auto, binary_crossentropy, categorical_crossentropy in HGBT --- .../gradient_boosting.py | 93 +++++++++++++------ .../tests/test_gradient_boosting.py | 48 ++++++---- 2 files changed, 96 insertions(+), 45 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index c75bdb01ea789..44dd71625073e 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -37,13 +37,16 @@ _LOSSES = _LOSSES.copy() -# TODO: Remove least_squares and least_absolute_deviation in v1.2 +# TODO(1.2): Remove "least_squares" and "least_absolute_deviation" +# TODO(1.3): Remove "binary_crossentropy" and "categorical_crossentropy" _LOSSES.update( { "least_squares": HalfSquaredError, "least_absolute_deviation": AbsoluteError, "poisson": HalfPoissonLoss, "quantile": PinballLoss, + "binary_log_loss": HalfBinomialLoss, + "multiclass_log_loss": HalfMultinomialLoss, "binary_crossentropy": HalfBinomialLoss, "categorical_crossentropy": HalfMultinomialLoss, } @@ -1299,6 +1302,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): 0.92... """ + # TODO(1.2): remove "least_absolute_deviation" _VALID_LOSSES = ( "squared_error", "least_squares", @@ -1455,13 +1459,28 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): Parameters ---------- - loss : {'auto', 'binary_crossentropy', 'categorical_crossentropy'}, \ - default='auto' - The loss function to use in the boosting process. 'binary_crossentropy' - (also known as logistic loss) is used for binary classification and - generalizes to 'categorical_crossentropy' for multiclass - classification. 'auto' will automatically choose either loss depending - on the nature of the problem. + loss : {'log_loss', 'binary_log_loss', 'multiclass_log_loss'}, \ + default='log_loss' + The loss function to use in the boosting process. 'binary_log_loss' (also known + as logistic loss, binomial deviance or binary crossentropy) is used for + (probabilistic) binary classification and generalizes to 'multiclass_log_loss' + (aka multinomial deviance or categorical crossentropy) for multiclass + classification. 'log_loss' will automatically choose either loss depending on + the nature of the problem. + + .. deprecated:: 1.1 + The loss 'auto' was deprecated in v1.1 and will be removed + in version 1.3. Use `loss='log_loss'` which is equivalent. + + .. deprecated:: 1.1 + The loss 'binary_crossentropy' was deprecated in v1.1 and will be removed + in version 1.3. Use `loss='binary_log_loss'` which is equivalent. + + .. deprecated:: 1.1 + The loss 'categorical_crossentropy' was deprecated in v1.1 and will be + removed in version 1.3. Use `loss='multiclass_log_loss'` which is + equivalent. + learning_rate : float, default=0.1 The learning rate, also known as *shrinkage*. This is used as a multiplicative factor for the leaves values. 
Use ``1`` for no @@ -1617,11 +1636,19 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): 1.0 """ - _VALID_LOSSES = ("binary_crossentropy", "categorical_crossentropy", "auto") + # TODO(1.3): Remove "binary_crossentropy", "categorical_crossentropy", "auto" + _VALID_LOSSES = ( + "log_loss", + "binary_log_loss", + "multiclass_log_loss", + "binary_crossentropy", + "categorical_crossentropy", + "auto", + ) def __init__( self, - loss="auto", + loss="log_loss", *, learning_rate=0.1, max_iter=100, @@ -1798,33 +1825,45 @@ def _encode_y(self, y): return encoded_y def _get_loss(self, sample_weight): - if self.loss == "auto": + # TODO(1.3): Remove "auto", "binary_crossentropy", "categorical_crossentropy" + loss_replace = { + "auto": "log_loss", + "binary_crossentropy": "binary_log_loss", + "categorical_crossentropy": "multiclass_log_loss", + } + if self.loss in loss_replace.keys(): + warnings.warn( + f"The loss '{self.loss}' was deprecated in v1.1 and will be removed in " + f"version 1.3. Use '{loss_replace[self.loss]}' which is equivalent.", + FutureWarning, + ) + + if self.loss in ("log_loss", "auto"): if self.n_trees_per_iteration_ == 1: - return _LOSSES["binary_crossentropy"](sample_weight=sample_weight) + return HalfBinomialLoss(sample_weight=sample_weight) else: - return _LOSSES["categorical_crossentropy"]( - sample_weight=sample_weight, - n_classes=self.n_trees_per_iteration_, + return HalfMultinomialLoss( + sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - - if self.loss == "categorical_crossentropy": + elif self.loss in ("multiclass_log_loss", "categorical_crossentropy"): if self.n_trees_per_iteration_ == 1: raise ValueError( - "loss='categorical_crossentropy' is not suitable for " - "a binary classification problem. Please use " - "loss='auto' or loss='binary_crossentropy' instead." + f"loss='{self.loss}' is not suitable for a binary classification " + "problem. Please use loss='log_loss' instead." ) else: - return _LOSSES[self.loss]( + return HalfMultinomialLoss( sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - else: + elif self.loss in ("binary_log_loss", "binary_crossentropy"): if self.n_trees_per_iteration_ > 1: raise ValueError( - "loss='binary_crossentropy' is not defined for multiclass" - " classification with n_classes=" - f"{self.n_trees_per_iteration_}, use loss=" - "'categorical_crossentropy' instead." + f"loss='{self.loss}' is not defined for multiclass " + f"classification with n_classes={self.n_trees_per_iteration_}, " + "use loss='multiclass_log_loss' instead." 
) else: - return _LOSSES[self.loss](sample_weight=sample_weight) + return HalfBinomialLoss(sample_weight=sample_weight) + else: + # should be an instance of BaseLoss + return self.loss diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 642904138a2d9..be0e60aee1505 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -39,8 +39,8 @@ "absolute_error": AbsoluteError, "poisson": HalfPoissonLoss, "quantile": PinballLoss, - "binary_crossentropy": HalfBinomialLoss, - "categorical_crossentropy": HalfMultinomialLoss, + "binary_log_loss": HalfBinomialLoss, + "multiclass_los_loss": HalfMultinomialLoss, } @@ -93,11 +93,11 @@ def test_init_parameters_validation(GradientBoosting, X, y, params, err_msg): def test_invalid_classification_loss(): - binary_clf = HistGradientBoostingClassifier(loss="binary_crossentropy") + binary_clf = HistGradientBoostingClassifier(loss="binary_log_loss") err_msg = ( - "loss='binary_crossentropy' is not defined for multiclass " + "loss='binary_log_loss' is not defined for multiclass " "classification with n_classes=3, use " - "loss='categorical_crossentropy' instead" + "loss='multiclass_log_loss' instead" ) with pytest.raises(ValueError, match=err_msg): binary_clf.fit(np.zeros(shape=(3, 2)), np.arange(3)) @@ -430,7 +430,7 @@ def test_missing_values_resilience( make_classification(random_state=0, n_classes=2), make_classification(random_state=0, n_classes=3, n_informative=3), ], - ids=["binary_crossentropy", "categorical_crossentropy"], + ids=["binary_log_loss", "multiclass_log_loss"], ) def test_zero_division_hessians(data): # non regression test for issue #14018 @@ -622,13 +622,13 @@ def test_infinite_values_missing_values(): def test_crossentropy_binary_problem(): - # categorical_crossentropy should only be used if there are more than two + # multiclass_log_loss should only be used if there are more than two # classes present. 
PR #14869 X = [[1], [0]] y = [0, 1] - gbrt = HistGradientBoostingClassifier(loss="categorical_crossentropy") + gbrt = HistGradientBoostingClassifier(loss="multiclass_log_loss") with pytest.raises( - ValueError, match="loss='categorical_crossentropy' is not suitable for" + ValueError, match="loss='multiclass_log_loss' is not suitable for" ): gbrt.fit(X, y) @@ -665,7 +665,7 @@ def test_zero_sample_weights_classification(): y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] - gb = HistGradientBoostingClassifier(loss="binary_crossentropy", min_samples_leaf=1) + gb = HistGradientBoostingClassifier(loss="binary_log_loss", min_samples_leaf=1) gb.fit(X, y, sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) @@ -1125,22 +1125,34 @@ def test_uint8_predict(Est): est.predict(X) -# TODO: Remove in v1.2 @pytest.mark.parametrize( - "old_loss, new_loss", + "old_loss, new_loss, Estimator", [ - ("least_squares", "squared_error"), - ("least_absolute_deviation", "absolute_error"), + # TODO(1.2): Remove + ("least_squares", "squared_error", HistGradientBoostingRegressor), + ("least_absolute_deviation", "absolute_error", HistGradientBoostingRegressor), + # TODO(1.3): Remove + ("auto", "log_loss", HistGradientBoostingClassifier), + ("binary_crossentropy", "binary_log_loss", HistGradientBoostingClassifier), + ( + "categorical_crossentropy", + "multiclass_log_loss", + HistGradientBoostingClassifier, + ), ], ) -def test_loss_deprecated(old_loss, new_loss): - X, y = make_regression(n_samples=50, random_state=0) - est1 = HistGradientBoostingRegressor(loss=old_loss, random_state=0) +def test_loss_deprecated(old_loss, new_loss, Estimator): + X, y = X_classification[:10], y_classification[:10] + + if old_loss == "categorical_crossentropy": + y[0] = 3 # make sure it is multiclass + + est1 = Estimator(loss=old_loss, random_state=0) with pytest.warns(FutureWarning, match=f"The loss '{old_loss}' was deprecated"): est1.fit(X, y) - est2 = HistGradientBoostingRegressor(loss=new_loss, random_state=0) + est2 = Estimator(loss=new_loss, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) From 9cc49c1e66a2ce4c5ca82cb3cef88c5bda29fbf5 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 6 Apr 2022 21:30:37 +0200 Subject: [PATCH 02/12] CLN only "log_loss" --- .../gradient_boosting.py | 49 +++++++------------ .../tests/test_gradient_boosting.py | 42 ++++++---------- 2 files changed, 32 insertions(+), 59 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 44dd71625073e..201bdb24f6265 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -45,8 +45,6 @@ "least_absolute_deviation": AbsoluteError, "poisson": HalfPoissonLoss, "quantile": PinballLoss, - "binary_log_loss": HalfBinomialLoss, - "multiclass_log_loss": HalfMultinomialLoss, "binary_crossentropy": HalfBinomialLoss, "categorical_crossentropy": HalfMultinomialLoss, } @@ -1459,27 +1457,23 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): Parameters ---------- - loss : {'log_loss', 'binary_log_loss', 'multiclass_log_loss'}, \ - default='log_loss' - The loss function to use in the boosting process. 
'binary_log_loss' (also known - as logistic loss, binomial deviance or binary crossentropy) is used for - (probabilistic) binary classification and generalizes to 'multiclass_log_loss' - (aka multinomial deviance or categorical crossentropy) for multiclass - classification. 'log_loss' will automatically choose either loss depending on - the nature of the problem. + loss : {'log_loss'}, default='log_loss' + The loss function to use in the boosting process. - .. deprecated:: 1.1 - The loss 'auto' was deprecated in v1.1 and will be removed - in version 1.3. Use `loss='log_loss'` which is equivalent. + For binary classification problems, 'log_loss' is also known as logistic loss, + binomial deviance or binary crossentropy. Internally, the model fits one tree + per boosting iteration and uses the logistic sigmoid function (expit) as + inverse link function to compute the predicted positive class probability. - .. deprecated:: 1.1 - The loss 'binary_crossentropy' was deprecated in v1.1 and will be removed - in version 1.3. Use `loss='binary_log_loss'` which is equivalent. + For multiclass classification problems, 'log_loss' is also known as multinomial + deviance or categorical crossentropy. Internally, the model fits one tree per + boosting iteration and per class and uses the softmax function as inverse link + function to compute the predicted probabilities of the classes. .. deprecated:: 1.1 - The loss 'categorical_crossentropy' was deprecated in v1.1 and will be - removed in version 1.3. Use `loss='multiclass_log_loss'` which is - equivalent. + The loss arguments 'auto', 'binary_crossentropy' and + 'categorical_crossentropy' were deprecated in v1.1 and will be removed in + version 1.3. Use `loss='log_loss'` which is equivalent. learning_rate : float, default=0.1 The learning rate, also known as *shrinkage*. This is used as a @@ -1639,8 +1633,6 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): # TODO(1.3): Remove "binary_crossentropy", "categorical_crossentropy", "auto" _VALID_LOSSES = ( "log_loss", - "binary_log_loss", - "multiclass_log_loss", "binary_crossentropy", "categorical_crossentropy", "auto", @@ -1826,15 +1818,10 @@ def _encode_y(self, y): def _get_loss(self, sample_weight): # TODO(1.3): Remove "auto", "binary_crossentropy", "categorical_crossentropy" - loss_replace = { - "auto": "log_loss", - "binary_crossentropy": "binary_log_loss", - "categorical_crossentropy": "multiclass_log_loss", - } - if self.loss in loss_replace.keys(): + if self.loss in ("auto", "binary_crossentropy", "categorical_crossentropy"): warnings.warn( f"The loss '{self.loss}' was deprecated in v1.1 and will be removed in " - f"version 1.3. Use '{loss_replace[self.loss]}' which is equivalent.", + "version 1.3. 
Use 'log_loss' which is equivalent.", FutureWarning, ) @@ -1845,7 +1832,7 @@ def _get_loss(self, sample_weight): return HalfMultinomialLoss( sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - elif self.loss in ("multiclass_log_loss", "categorical_crossentropy"): + elif self.loss in ("categorical_crossentropy"): if self.n_trees_per_iteration_ == 1: raise ValueError( f"loss='{self.loss}' is not suitable for a binary classification " @@ -1855,12 +1842,12 @@ def _get_loss(self, sample_weight): return HalfMultinomialLoss( sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - elif self.loss in ("binary_log_loss", "binary_crossentropy"): + elif self.loss in ("binary_crossentropy"): if self.n_trees_per_iteration_ > 1: raise ValueError( f"loss='{self.loss}' is not defined for multiclass " f"classification with n_classes={self.n_trees_per_iteration_}, " - "use loss='multiclass_log_loss' instead." + "use loss='log_loss' instead." ) else: return HalfBinomialLoss(sample_weight=sample_weight) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index be0e60aee1505..4836707e57c4c 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -6,8 +6,6 @@ from sklearn._loss.loss import ( AbsoluteError, HalfBinomialLoss, - HalfMultinomialLoss, - HalfPoissonLoss, HalfSquaredError, PinballLoss, ) @@ -34,16 +32,6 @@ n_threads = _openmp_effective_n_threads() -_LOSSES = { - "squared_error": HalfSquaredError, - "absolute_error": AbsoluteError, - "poisson": HalfPoissonLoss, - "quantile": PinballLoss, - "binary_log_loss": HalfBinomialLoss, - "multiclass_los_loss": HalfMultinomialLoss, -} - - X_classification, y_classification = make_classification(random_state=0) X_regression, y_regression = make_regression(random_state=0) X_multi_classification, y_multi_classification = make_classification( @@ -92,12 +80,13 @@ def test_init_parameters_validation(GradientBoosting, X, y, params, err_msg): GradientBoosting(**params).fit(X, y) +# TODO(1.3): remove def test_invalid_classification_loss(): - binary_clf = HistGradientBoostingClassifier(loss="binary_log_loss") + binary_clf = HistGradientBoostingClassifier(loss="binary_crossentropy") err_msg = ( - "loss='binary_log_loss' is not defined for multiclass " + "loss='binary_crossentropy' is not defined for multiclass " "classification with n_classes=3, use " - "loss='multiclass_log_loss' instead" + "loss='log_loss' instead" ) with pytest.raises(ValueError, match=err_msg): binary_clf.fit(np.zeros(shape=(3, 2)), np.arange(3)) @@ -621,14 +610,15 @@ def test_infinite_values_missing_values(): assert stump_clf.fit(X, y_isnan).score(X, y_isnan) == 1 +# TODO(1.3): remove def test_crossentropy_binary_problem(): - # multiclass_log_loss should only be used if there are more than two + # categorical_crossentropy should only be used if there are more than two # classes present. 
PR #14869 X = [[1], [0]] y = [0, 1] - gbrt = HistGradientBoostingClassifier(loss="multiclass_log_loss") + gbrt = HistGradientBoostingClassifier(loss="categorical_crossentropy") with pytest.raises( - ValueError, match="loss='multiclass_log_loss' is not suitable for" + ValueError, match="loss='categorical_crossentropy' is not suitable for" ): gbrt.fit(X, y) @@ -665,7 +655,7 @@ def test_zero_sample_weights_classification(): y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] - gb = HistGradientBoostingClassifier(loss="binary_log_loss", min_samples_leaf=1) + gb = HistGradientBoostingClassifier(loss="log_loss", min_samples_leaf=1) gb.fit(X, y, sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) @@ -737,8 +727,8 @@ def test_sample_weight_effect(problem, duplication): assert np.allclose(est_sw._raw_predict(X_dup), est_dup._raw_predict(X_dup)) -@pytest.mark.parametrize("loss_name", ("squared_error", "absolute_error")) -def test_sum_hessians_are_sample_weight(loss_name): +@pytest.mark.parametrize("Loss", (HalfSquaredError, AbsoluteError)) +def test_sum_hessians_are_sample_weight(Loss): # For losses with constant hessians, the sum_hessians field of the # histograms must be equal to the sum of the sample weight of samples at # the corresponding bin. @@ -753,7 +743,7 @@ def test_sum_hessians_are_sample_weight(loss_name): # While sample weights are supposed to be positive, this still works. sample_weight = rng.normal(size=n_samples) - loss = _LOSSES[loss_name](sample_weight=sample_weight) + loss = Loss(sample_weight=sample_weight) gradients, hessians = loss.init_gradient_and_hessian( n_samples=n_samples, dtype=G_H_DTYPE ) @@ -1133,12 +1123,8 @@ def test_uint8_predict(Est): ("least_absolute_deviation", "absolute_error", HistGradientBoostingRegressor), # TODO(1.3): Remove ("auto", "log_loss", HistGradientBoostingClassifier), - ("binary_crossentropy", "binary_log_loss", HistGradientBoostingClassifier), - ( - "categorical_crossentropy", - "multiclass_log_loss", - HistGradientBoostingClassifier, - ), + ("binary_crossentropy", "log_loss", HistGradientBoostingClassifier), + ("categorical_crossentropy", "log_loss", HistGradientBoostingClassifier), ], ) def test_loss_deprecated(old_loss, new_loss, Estimator): From 658217d4e2decf6c115e3786d4e7587d39408ac9 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 6 Apr 2022 21:32:56 +0200 Subject: [PATCH 03/12] DOC add whatsnew --- doc/whats_new/v1.1.rst | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index b32d9891e633c..c24f9ad99403e 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -93,11 +93,16 @@ Changelog produce the same models, but are deprecated and will be removed in version 1.3. - - For :class:`ensemble.GradientBoostingClassifier`, The `loss` parameter name + - For :class:`ensemble.GradientBoostingClassifier`, the `loss` parameter name "deviance" is deprecated in favor of the new name "log_loss", which is now the default. :pr:`23036` by :user:`Christian Lorentzen `. + - For :class:`ensemble.HistGradientBoostingClassifier`, the `loss` parameter names + "auto", "binary_crossentropy" and "categorical_crossentropy" are deprecated in + favor of the new name "log_loss", which is now the default. + :pr:`23040` by :user:`Christian Lorentzen `. + - |Efficiency| Low-level routines for reductions on pairwise distances for dense float64 datasets have been refactored. 
The following functions and estimators now benefit from improved performances in terms of hardware From 96628942c99ac1289704eba9336670ed6c0583a2 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 6 Apr 2022 21:34:58 +0200 Subject: [PATCH 04/12] TST fix test_loss_deprecated --- .../tests/test_gradient_boosting.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 4836707e57c4c..7831c3c37c2e2 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1128,10 +1128,11 @@ def test_uint8_predict(Est): ], ) def test_loss_deprecated(old_loss, new_loss, Estimator): - X, y = X_classification[:10], y_classification[:10] - if old_loss == "categorical_crossentropy": - y[0] = 3 # make sure it is multiclass + X, y = X_multi_classification[:10], y_multi_classification[:10] + assert len(np.unique(y)) > 2 + else: + X, y = X_classification[:10], y_classification[:10] est1 = Estimator(loss=old_loss, random_state=0) From adb59af444111b3447c096205c8558bb8797c728 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Wed, 6 Apr 2022 23:36:34 +0200 Subject: [PATCH 05/12] MNT replace *_crossentropy by log_loss --- benchmarks/bench_hist_gradient_boosting.py | 13 +++++-------- benchmarks/bench_hist_gradient_boosting_adult.py | 7 +++++-- ...bench_hist_gradient_boosting_categorical_only.py | 4 ++-- .../bench_hist_gradient_boosting_higgsboson.py | 9 +++++---- .../bench_hist_gradient_boosting_threading.py | 10 +++++----- doc/modules/ensemble.rst | 9 +++++---- .../tests/test_compare_lightgbm.py | 6 ++++-- sklearn/ensemble/_hist_gradient_boosting/utils.pyx | 11 ++++------- 8 files changed, 35 insertions(+), 34 deletions(-) diff --git a/benchmarks/bench_hist_gradient_boosting.py b/benchmarks/bench_hist_gradient_boosting.py index 58477e8894fd1..163e21f98ed0d 100644 --- a/benchmarks/bench_hist_gradient_boosting.py +++ b/benchmarks/bench_hist_gradient_boosting.py @@ -115,12 +115,7 @@ def one_run(n_samples): loss = args.loss if args.problem == "classification": if loss == "default": - # loss='auto' does not work with get_equivalent_estimator() - loss = ( - "binary_crossentropy" - if args.n_classes == 2 - else "categorical_crossentropy" - ) + loss = "log_loss" else: # regression if loss == "default": @@ -159,7 +154,7 @@ def one_run(n_samples): xgb_score_duration = None if args.xgboost: print("Fitting an XGBoost model...") - xgb_est = get_equivalent_estimator(est, lib="xgboost") + xgb_est = get_equivalent_estimator(est, lib="xgboost", n_classes=args.n_classes) tic = time() xgb_est.fit(X_train, y_train, sample_weight=sample_weight_train) @@ -176,7 +171,9 @@ def one_run(n_samples): cat_score_duration = None if args.catboost: print("Fitting a CatBoost model...") - cat_est = get_equivalent_estimator(est, lib="catboost") + cat_est = get_equivalent_estimator( + est, lib="catboost", n_classes=args.n_classes + ) tic = time() cat_est.fit(X_train, y_train, sample_weight=sample_weight_train) diff --git a/benchmarks/bench_hist_gradient_boosting_adult.py b/benchmarks/bench_hist_gradient_boosting_adult.py index 6b85b3819fb0f..0e0ca911e6ed7 100644 --- a/benchmarks/bench_hist_gradient_boosting_adult.py +++ b/benchmarks/bench_hist_gradient_boosting_adult.py @@ -1,6 +1,8 @@ import argparse from time import time +import numpy as np + from 
sklearn.model_selection import train_test_split from sklearn.datasets import fetch_openml from sklearn.metrics import accuracy_score, roc_auc_score @@ -48,6 +50,7 @@ def predict(est, data_test, target_test): data = fetch_openml(data_id=179, as_frame=False) # adult dataset X, y = data.data, data.target +n_classes = len(np.unique(y)) n_features = X.shape[1] n_categorical_features = len(data.categories) n_numerical_features = n_features - n_categorical_features @@ -61,7 +64,7 @@ def predict(est, data_test, target_test): # already clean is_categorical = [name in data.categories for name in data.feature_names] est = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", learning_rate=lr, max_iter=n_trees, max_bins=max_bins, @@ -76,7 +79,7 @@ def predict(est, data_test, target_test): predict(est, X_test, y_test) if args.lightgbm: - est = get_equivalent_estimator(est, lib="lightgbm") + est = get_equivalent_estimator(est, lib="lightgbm", n_classes=n_classes) est.set_params(max_cat_to_onehot=1) # dont use OHE categorical_features = [ f_idx for (f_idx, is_cat) in enumerate(is_categorical) if is_cat diff --git a/benchmarks/bench_hist_gradient_boosting_categorical_only.py b/benchmarks/bench_hist_gradient_boosting_categorical_only.py index 5e6c63067f7cd..e8d215170f9c8 100644 --- a/benchmarks/bench_hist_gradient_boosting_categorical_only.py +++ b/benchmarks/bench_hist_gradient_boosting_categorical_only.py @@ -58,7 +58,7 @@ def predict(est, data_test): is_categorical = [True] * n_features est = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", learning_rate=lr, max_iter=n_trees, max_bins=max_bins, @@ -73,7 +73,7 @@ def predict(est, data_test): predict(est, X) if args.lightgbm: - est = get_equivalent_estimator(est, lib="lightgbm") + est = get_equivalent_estimator(est, lib="lightgbm", n_classes=2) est.set_params(max_cat_to_onehot=1) # dont use OHE categorical_features = list(range(n_features)) fit(est, X, y, "lightgbm", categorical_feature=categorical_features) diff --git a/benchmarks/bench_hist_gradient_boosting_higgsboson.py b/benchmarks/bench_hist_gradient_boosting_higgsboson.py index 8455ef177860c..abe8018adfd83 100644 --- a/benchmarks/bench_hist_gradient_boosting_higgsboson.py +++ b/benchmarks/bench_hist_gradient_boosting_higgsboson.py @@ -80,6 +80,7 @@ def predict(est, data_test, target_test): data_train, data_test, target_train, target_test = train_test_split( data, target, test_size=0.2, random_state=0 ) +n_classes = len(np.unique(target)) if subsample is not None: data_train, target_train = data_train[:subsample], target_train[:subsample] @@ -88,7 +89,7 @@ def predict(est, data_test, target_test): print(f"Training set with {n_samples} records with {n_features} features.") est = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", learning_rate=lr, max_iter=n_trees, max_bins=max_bins, @@ -101,16 +102,16 @@ def predict(est, data_test, target_test): predict(est, data_test, target_test) if args.lightgbm: - est = get_equivalent_estimator(est, lib="lightgbm") + est = get_equivalent_estimator(est, lib="lightgbm", n_classes=n_classes) fit(est, data_train, target_train, "lightgbm") predict(est, data_test, target_test) if args.xgboost: - est = get_equivalent_estimator(est, lib="xgboost") + est = get_equivalent_estimator(est, lib="xgboost", n_classes=n_classes) fit(est, data_train, target_train, "xgboost") predict(est, data_test, target_test) if args.catboost: - est = get_equivalent_estimator(est, lib="catboost") + est = 
get_equivalent_estimator(est, lib="catboost", n_classes=n_classes)
     fit(est, data_train, target_train, "catboost")
     predict(est, data_test, target_test)
diff --git a/benchmarks/bench_hist_gradient_boosting_threading.py b/benchmarks/bench_hist_gradient_boosting_threading.py
index dbeb4d7f7d7fa..70787fd2eb479 100644
--- a/benchmarks/bench_hist_gradient_boosting_threading.py
+++ b/benchmarks/bench_hist_gradient_boosting_threading.py
@@ -118,9 +118,7 @@ def get_estimator_and_data():
 if args.problem == "classification":
     if loss == "default":
         # loss='auto' does not work with get_equivalent_estimator()
-        loss = (
-            "binary_crossentropy" if args.n_classes == 2 else "categorical_crossentropy"
-        )
+        loss = "log_loss"
 else:
     # regression
     if loss == "default":
@@ -191,7 +189,7 @@ def one_run(n_threads, n_samples):
     xgb_score_duration = None
     if args.xgboost:
         print("Fitting an XGBoost model...")
-        xgb_est = get_equivalent_estimator(est, lib="xgboost")
+        xgb_est = get_equivalent_estimator(est, lib="xgboost", n_classes=args.n_classes)
         xgb_est.set_params(nthread=n_threads)

         tic = time()
@@ -209,7 +207,9 @@ def one_run(n_threads, n_samples):
     cat_score_duration = None
     if args.catboost:
         print("Fitting a CatBoost model...")
-        cat_est = get_equivalent_estimator(est, lib="catboost")
+        cat_est = get_equivalent_estimator(
+            est, lib="catboost", n_classes=args.n_classes
+        )
         cat_est.set_params(thread_count=n_threads)

         tic = time()
diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst
index 98041bef73e08..aba8acfc5534b 100644
--- a/doc/modules/ensemble.rst
+++ b/doc/modules/ensemble.rst
@@ -949,10 +949,11 @@ controls the number of iterations of the boosting process::
 Available losses for regression are 'squared_error',
 'absolute_error', which is less sensitive to outliers, and
 'poisson', which is well suited to model counts and frequencies. For
-classification, 'binary_crossentropy' is used for binary classification and
-'categorical_crossentropy' is used for multiclass classification. By default
-the loss is 'auto' and will select the appropriate loss depending on
-:term:`y` passed to :term:`fit`.
+classification, 'log_loss' is the only option. For binary classification it uses the
+binary log loss, also known as binomial deviance or binary cross-entropy. For
+`n_classes >= 3`, it uses the multi-class log loss function, with multinomial deviance
+and categorical cross-entropy as alternative names. The appropriate loss version is
+selected based on :term:`y` passed to :term:`fit`.

 The size of the trees can be controlled through the ``max_leaf_nodes``,
 ``max_depth``, and ``min_samples_leaf`` parameters. 
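For illustration only (not part of the patch): a minimal usage sketch of the unified loss
name described above, assuming a scikit-learn build that already contains this change. The
dataset and hyperparameter values are arbitrary; ``n_trees_per_iteration_`` and
``predict_proba`` are existing public API::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import HistGradientBoostingClassifier

    # y has 3 classes, so 'log_loss' resolves to the multinomial (softmax) loss;
    # with 2 classes it would resolve to the binomial (sigmoid) loss instead.
    X, y = make_classification(
        n_samples=200, n_classes=3, n_informative=4, random_state=0
    )
    clf = HistGradientBoostingClassifier(loss="log_loss", random_state=0).fit(X, y)
    print(clf.n_trees_per_iteration_)  # 3: one tree per class and per iteration
    print(clf.predict_proba(X[:2]))    # class probabilities, one row per sample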
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py index 95c0d5f53d640..82b18c5ce2b64 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py @@ -132,7 +132,9 @@ def test_same_predictions_classification( min_samples_leaf=min_samples_leaf, max_leaf_nodes=max_leaf_nodes, ) - est_lightgbm = get_equivalent_estimator(est_sklearn, lib="lightgbm") + est_lightgbm = get_equivalent_estimator( + est_sklearn, lib="lightgbm", n_classes=n_classes + ) est_lightgbm.fit(X_train, y_train) est_sklearn.fit(X_train, y_train) @@ -198,7 +200,7 @@ def test_same_predictions_multiclass_classification( X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) est_sklearn = HistGradientBoostingClassifier( - loss="categorical_crossentropy", + loss="log_loss", max_iter=max_iter, max_bins=max_bins, learning_rate=lr, diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx index b2de7614fe499..d2123ecc61510 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx @@ -40,8 +40,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): lightgbm_loss_mapping = { 'squared_error': 'regression_l2', 'absolute_error': 'regression_l1', - 'binary_crossentropy': 'binary', - 'categorical_crossentropy': 'multiclass' + 'log_loss': 'binary' if n_classes == 2 else 'multiclass', } lightgbm_params = { @@ -63,7 +62,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): 'subsample_for_bin': _BinMapper().subsample, } - if sklearn_params['loss'] == 'categorical_crossentropy': + if sklearn_params['loss'] == 'log_loss' and n_classes > 2: # LightGBM multiplies hessians by 2 in multiclass loss. lightgbm_params['min_sum_hessian_in_leaf'] *= 2 # LightGBM 3.0 introduced a different scaling of the hessian for the multiclass case. 
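For illustration only (not part of the patch): with the single name 'log_loss', the helper
can no longer infer binary versus multiclass from the loss string alone, which is why
``n_classes`` is now passed in. A condensed sketch of the LightGBM branch of the mapping
above, written as a hypothetical standalone function rather than the actual Cython code::

    def lightgbm_objective(loss, n_classes):
        # Mirrors the lightgbm_loss_mapping entry for 'log_loss'.
        if loss != "log_loss":
            raise ValueError(f"unsupported loss: {loss!r}")
        return "binary" if n_classes == 2 else "multiclass"

    assert lightgbm_objective("log_loss", 2) == "binary"
    assert lightgbm_objective("log_loss", 5) == "multiclass"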
@@ -76,8 +75,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): xgboost_loss_mapping = { 'squared_error': 'reg:linear', 'absolute_error': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED', - 'binary_crossentropy': 'reg:logistic', - 'categorical_crossentropy': 'multi:softmax' + 'log_loss': 'reg:logistic' if n_classes == 2 else 'multi:softmax', } xgboost_params = { @@ -101,8 +99,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): 'squared_error': 'RMSE', # catboost does not support MAE when leaf_estimation_method is Newton 'absolute_error': 'LEAST_ASBOLUTE_DEV_NOT_SUPPORTED', - 'binary_crossentropy': 'Logloss', - 'categorical_crossentropy': 'MultiClass' + 'log_loss': 'Logloss' if n_classes == 2 else 'MultiClass', } catboost_params = { From cec04a47eb56e3f8680dc737af05cfbb62f9c7bd Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 10:50:35 +0200 Subject: [PATCH 06/12] fix tests --- .../_hist_gradient_boosting/tests/test_compare_lightgbm.py | 2 +- .../_hist_gradient_boosting/tests/test_gradient_boosting.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py index 82b18c5ce2b64..f5c373ed84558 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py @@ -124,7 +124,7 @@ def test_same_predictions_classification( X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) est_sklearn = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", max_iter=max_iter, max_bins=max_bins, learning_rate=1, diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 7831c3c37c2e2..6535fcfcaacb0 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -664,7 +664,7 @@ def test_zero_sample_weights_classification(): # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1, 1] gb = HistGradientBoostingClassifier( - loss="categorical_crossentropy", min_samples_leaf=1 + loss="log_loss", min_samples_leaf=1 ) gb.fit(X, y, sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) From 5c142282f90884f42ee0260ad8477fc81467eb8e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 11:07:12 +0200 Subject: [PATCH 07/12] lint --- .../_hist_gradient_boosting/tests/test_gradient_boosting.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 6535fcfcaacb0..e66e961b57dab 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -663,9 +663,7 @@ def test_zero_sample_weights_classification(): y = [0, 0, 1, 0, 2] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1, 1] - gb = HistGradientBoostingClassifier( - loss="log_loss", min_samples_leaf=1 - ) + gb = HistGradientBoostingClassifier(loss="log_loss", min_samples_leaf=1) gb.fit(X, y, 
sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) From d91795a7d1480d7999adf529af1ff6339728270f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 12:57:15 +0200 Subject: [PATCH 08/12] fix tests again --- .../_hist_gradient_boosting/tests/test_gradient_boosting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index e66e961b57dab..efa1ac1a4d762 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -81,6 +81,7 @@ def test_init_parameters_validation(GradientBoosting, X, y, params, err_msg): # TODO(1.3): remove +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_invalid_classification_loss(): binary_clf = HistGradientBoostingClassifier(loss="binary_crossentropy") err_msg = ( @@ -611,6 +612,7 @@ def test_infinite_values_missing_values(): # TODO(1.3): remove +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_crossentropy_binary_problem(): # categorical_crossentropy should only be used if there are more than two # classes present. PR #14869 From 4c9e76c96bfd7b0ba2c300e9a74e6f58ca60d6d7 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 14:47:35 +0200 Subject: [PATCH 09/12] doc --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 201bdb24f6265..1484ac74647ef 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1457,7 +1457,8 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): Parameters ---------- - loss : {'log_loss'}, default='log_loss' + loss : {'log_loss', 'auto', 'binary_crossentropy', 'categorical_crossentropy'}, + default='log_loss' The loss function to use in the boosting process. 
For binary classification problems, 'log_loss' is also known as logistic loss, From 29a80c18b050302ad085c746909b19048ec4a1ea Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 15:01:46 +0200 Subject: [PATCH 10/12] remove unreachable code --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 1484ac74647ef..31e91ee509672 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1852,6 +1852,3 @@ def _get_loss(self, sample_weight): ) else: return HalfBinomialLoss(sample_weight=sample_weight) - else: - # should be an instance of BaseLoss - return self.loss From 4d610c7d893ca122034e4640aa028e00a830032f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 15:55:15 +0200 Subject: [PATCH 11/12] fix doc --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 31e91ee509672..859c8f5c49274 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1457,7 +1457,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): Parameters ---------- - loss : {'log_loss', 'auto', 'binary_crossentropy', 'categorical_crossentropy'}, + loss : {'log_loss', 'auto', 'binary_crossentropy', 'categorical_crossentropy'}, \ default='log_loss' The loss function to use in the boosting process. From 7152e377033ba5fa6ea639afcfe82b6cf4ae1f7e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 7 Apr 2022 16:49:21 +0200 Subject: [PATCH 12/12] nit --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 859c8f5c49274..e36f1beb86dd8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1833,7 +1833,7 @@ def _get_loss(self, sample_weight): return HalfMultinomialLoss( sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - elif self.loss in ("categorical_crossentropy"): + if self.loss == "categorical_crossentropy": if self.n_trees_per_iteration_ == 1: raise ValueError( f"loss='{self.loss}' is not suitable for a binary classification " @@ -1843,7 +1843,7 @@ def _get_loss(self, sample_weight): return HalfMultinomialLoss( sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - elif self.loss in ("binary_crossentropy"): + if self.loss == "binary_crossentropy": if self.n_trees_per_iteration_ > 1: raise ValueError( f"loss='{self.loss}' is not defined for multiclass "