diff --git a/benchmarks/bench_hist_gradient_boosting.py b/benchmarks/bench_hist_gradient_boosting.py index 58477e8894fd1..163e21f98ed0d 100644 --- a/benchmarks/bench_hist_gradient_boosting.py +++ b/benchmarks/bench_hist_gradient_boosting.py @@ -115,12 +115,7 @@ def one_run(n_samples): loss = args.loss if args.problem == "classification": if loss == "default": - # loss='auto' does not work with get_equivalent_estimator() - loss = ( - "binary_crossentropy" - if args.n_classes == 2 - else "categorical_crossentropy" - ) + loss = "log_loss" else: # regression if loss == "default": @@ -159,7 +154,7 @@ def one_run(n_samples): xgb_score_duration = None if args.xgboost: print("Fitting an XGBoost model...") - xgb_est = get_equivalent_estimator(est, lib="xgboost") + xgb_est = get_equivalent_estimator(est, lib="xgboost", n_classes=args.n_classes) tic = time() xgb_est.fit(X_train, y_train, sample_weight=sample_weight_train) @@ -176,7 +171,9 @@ def one_run(n_samples): cat_score_duration = None if args.catboost: print("Fitting a CatBoost model...") - cat_est = get_equivalent_estimator(est, lib="catboost") + cat_est = get_equivalent_estimator( + est, lib="catboost", n_classes=args.n_classes + ) tic = time() cat_est.fit(X_train, y_train, sample_weight=sample_weight_train) diff --git a/benchmarks/bench_hist_gradient_boosting_adult.py b/benchmarks/bench_hist_gradient_boosting_adult.py index 6b85b3819fb0f..0e0ca911e6ed7 100644 --- a/benchmarks/bench_hist_gradient_boosting_adult.py +++ b/benchmarks/bench_hist_gradient_boosting_adult.py @@ -1,6 +1,8 @@ import argparse from time import time +import numpy as np + from sklearn.model_selection import train_test_split from sklearn.datasets import fetch_openml from sklearn.metrics import accuracy_score, roc_auc_score @@ -48,6 +50,7 @@ def predict(est, data_test, target_test): data = fetch_openml(data_id=179, as_frame=False) # adult dataset X, y = data.data, data.target +n_classes = len(np.unique(y)) n_features = X.shape[1] n_categorical_features = len(data.categories) n_numerical_features = n_features - n_categorical_features @@ -61,7 +64,7 @@ def predict(est, data_test, target_test): # already clean is_categorical = [name in data.categories for name in data.feature_names] est = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", learning_rate=lr, max_iter=n_trees, max_bins=max_bins, @@ -76,7 +79,7 @@ def predict(est, data_test, target_test): predict(est, X_test, y_test) if args.lightgbm: - est = get_equivalent_estimator(est, lib="lightgbm") + est = get_equivalent_estimator(est, lib="lightgbm", n_classes=n_classes) est.set_params(max_cat_to_onehot=1) # dont use OHE categorical_features = [ f_idx for (f_idx, is_cat) in enumerate(is_categorical) if is_cat diff --git a/benchmarks/bench_hist_gradient_boosting_categorical_only.py b/benchmarks/bench_hist_gradient_boosting_categorical_only.py index 5e6c63067f7cd..e8d215170f9c8 100644 --- a/benchmarks/bench_hist_gradient_boosting_categorical_only.py +++ b/benchmarks/bench_hist_gradient_boosting_categorical_only.py @@ -58,7 +58,7 @@ def predict(est, data_test): is_categorical = [True] * n_features est = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", learning_rate=lr, max_iter=n_trees, max_bins=max_bins, @@ -73,7 +73,7 @@ def predict(est, data_test): predict(est, X) if args.lightgbm: - est = get_equivalent_estimator(est, lib="lightgbm") + est = get_equivalent_estimator(est, lib="lightgbm", n_classes=2) est.set_params(max_cat_to_onehot=1) # dont use OHE 
categorical_features = list(range(n_features)) fit(est, X, y, "lightgbm", categorical_feature=categorical_features) diff --git a/benchmarks/bench_hist_gradient_boosting_higgsboson.py b/benchmarks/bench_hist_gradient_boosting_higgsboson.py index 8455ef177860c..abe8018adfd83 100644 --- a/benchmarks/bench_hist_gradient_boosting_higgsboson.py +++ b/benchmarks/bench_hist_gradient_boosting_higgsboson.py @@ -80,6 +80,7 @@ def predict(est, data_test, target_test): data_train, data_test, target_train, target_test = train_test_split( data, target, test_size=0.2, random_state=0 ) +n_classes = len(np.unique(target)) if subsample is not None: data_train, target_train = data_train[:subsample], target_train[:subsample] @@ -88,7 +89,7 @@ def predict(est, data_test, target_test): print(f"Training set with {n_samples} records with {n_features} features.") est = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", learning_rate=lr, max_iter=n_trees, max_bins=max_bins, @@ -101,16 +102,16 @@ def predict(est, data_test, target_test): predict(est, data_test, target_test) if args.lightgbm: - est = get_equivalent_estimator(est, lib="lightgbm") + est = get_equivalent_estimator(est, lib="lightgbm", n_classes=n_classes) fit(est, data_train, target_train, "lightgbm") predict(est, data_test, target_test) if args.xgboost: - est = get_equivalent_estimator(est, lib="xgboost") + est = get_equivalent_estimator(est, lib="xgboost", n_classes=n_classes) fit(est, data_train, target_train, "xgboost") predict(est, data_test, target_test) if args.catboost: - est = get_equivalent_estimator(est, lib="catboost") + est = get_equivalent_estimator(est, lib="catboost", n_classes=n_classes) fit(est, data_train, target_train, "catboost") predict(est, data_test, target_test) diff --git a/benchmarks/bench_hist_gradient_boosting_threading.py b/benchmarks/bench_hist_gradient_boosting_threading.py index dbeb4d7f7d7fa..70787fd2eb479 100644 --- a/benchmarks/bench_hist_gradient_boosting_threading.py +++ b/benchmarks/bench_hist_gradient_boosting_threading.py @@ -118,9 +118,7 @@ def get_estimator_and_data(): if args.problem == "classification": if loss == "default": # loss='auto' does not work with get_equivalent_estimator() - loss = ( - "binary_crossentropy" if args.n_classes == 2 else "categorical_crossentropy" - ) + loss = "log_loss" else: # regression if loss == "default": @@ -191,7 +189,7 @@ def one_run(n_threads, n_samples): xgb_score_duration = None if args.xgboost: print("Fitting an XGBoost model...") - xgb_est = get_equivalent_estimator(est, lib="xgboost") + xgb_est = get_equivalent_estimator(est, lib="xgboost", n_classes=args.n_classes) xgb_est.set_params(nthread=n_threads) tic = time() @@ -209,7 +207,9 @@ def one_run(n_threads, n_samples): cat_score_duration = None if args.catboost: print("Fitting a CatBoost model...") - cat_est = get_equivalent_estimator(est, lib="catboost") + cat_est = get_equivalent_estimator( + est, lib="catboost", n_classes=args.n_classes + ) cat_est.set_params(thread_count=n_threads) tic = time() diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 98041bef73e08..aba8acfc5534b 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -949,10 +949,11 @@ controls the number of iterations of the boosting process:: Available losses for regression are 'squared_error', 'absolute_error', which is less sensitive to outliers, and 'poisson', which is well suited to model counts and frequencies. 
For -classification, 'binary_crossentropy' is used for binary classification and -'categorical_crossentropy' is used for multiclass classification. By default -the loss is 'auto' and will select the appropriate loss depending on -:term:`y` passed to :term:`fit`. +classification, 'log_loss' is the only option. For binary classification it uses the +binary log loss, also known as binomial deviance or binary cross-entropy. For +`n_classes >= 3`, it uses the multi-class log loss function, with multinomial deviance +and categorical cross-entropy as alternative names. The appropriate loss version is +selected based on :term:`y` passed to :term:`fit`. The size of the trees can be controlled through the ``max_leaf_nodes``, ``max_depth``, and ``min_samples_leaf`` parameters. diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 17290dc0d55c6..4a2cd29c70714 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -98,6 +98,11 @@ Changelog default. :pr:`23036` by :user:`Christian Lorentzen `. + - For :class:`ensemble.HistGradientBoostingClassifier`, the `loss` parameter names + "auto", "binary_crossentropy" and "categorical_crossentropy" are deprecated in + favor of the new name "log_loss", which is now the default. + :pr:`23040` by :user:`Christian Lorentzen `. + - For :class:`linear_model.SGDClassifier`, the `loss` parameter name "log" is deprecated in favor of the new name "log_loss". :pr:`23046` by :user:`Christian Lorentzen `. diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index c75bdb01ea789..e36f1beb86dd8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -37,7 +37,8 @@ _LOSSES = _LOSSES.copy() -# TODO: Remove least_squares and least_absolute_deviation in v1.2 +# TODO(1.2): Remove "least_squares" and "least_absolute_deviation" +# TODO(1.3): Remove "binary_crossentropy" and "categorical_crossentropy" _LOSSES.update( { "least_squares": HalfSquaredError, @@ -1299,6 +1300,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): 0.92... """ + # TODO(1.2): remove "least_absolute_deviation" _VALID_LOSSES = ( "squared_error", "least_squares", @@ -1455,13 +1457,25 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): Parameters ---------- - loss : {'auto', 'binary_crossentropy', 'categorical_crossentropy'}, \ default='auto' - The loss function to use in the boosting process. 'binary_crossentropy' - (also known as logistic loss) is used for binary classification and - generalizes to 'categorical_crossentropy' for multiclass - classification. 'auto' will automatically choose either loss depending - on the nature of the problem. + loss : {'log_loss', 'auto', 'binary_crossentropy', 'categorical_crossentropy'}, \ default='log_loss' + The loss function to use in the boosting process. + + For binary classification problems, 'log_loss' is also known as logistic loss, + binomial deviance or binary crossentropy. Internally, the model fits one tree + per boosting iteration and uses the logistic sigmoid function (expit) as + inverse link function to compute the predicted positive class probability. + + For multiclass classification problems, 'log_loss' is also known as multinomial + deviance or categorical crossentropy.
Internally, the model fits one tree per + boosting iteration and per class and uses the softmax function as inverse link + function to compute the predicted probabilities of the classes. + + .. deprecated:: 1.1 + The loss arguments 'auto', 'binary_crossentropy' and + 'categorical_crossentropy' were deprecated in v1.1 and will be removed in + version 1.3. Use `loss='log_loss'` which is equivalent. + learning_rate : float, default=0.1 The learning rate, also known as *shrinkage*. This is used as a multiplicative factor for the leaves values. Use ``1`` for no @@ -1617,11 +1631,17 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): 1.0 """ - _VALID_LOSSES = ("binary_crossentropy", "categorical_crossentropy", "auto") + # TODO(1.3): Remove "binary_crossentropy", "categorical_crossentropy", "auto" + _VALID_LOSSES = ( + "log_loss", + "binary_crossentropy", + "categorical_crossentropy", + "auto", + ) def __init__( self, - loss="auto", + loss="log_loss", *, learning_rate=0.1, max_iter=100, @@ -1798,33 +1818,37 @@ def _encode_y(self, y): return encoded_y def _get_loss(self, sample_weight): - if self.loss == "auto": + # TODO(1.3): Remove "auto", "binary_crossentropy", "categorical_crossentropy" + if self.loss in ("auto", "binary_crossentropy", "categorical_crossentropy"): + warnings.warn( + f"The loss '{self.loss}' was deprecated in v1.1 and will be removed in " + "version 1.3. Use 'log_loss' which is equivalent.", + FutureWarning, + ) + + if self.loss in ("log_loss", "auto"): if self.n_trees_per_iteration_ == 1: - return _LOSSES["binary_crossentropy"](sample_weight=sample_weight) + return HalfBinomialLoss(sample_weight=sample_weight) else: - return _LOSSES["categorical_crossentropy"]( - sample_weight=sample_weight, - n_classes=self.n_trees_per_iteration_, + return HalfMultinomialLoss( + sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - if self.loss == "categorical_crossentropy": if self.n_trees_per_iteration_ == 1: raise ValueError( - "loss='categorical_crossentropy' is not suitable for " - "a binary classification problem. Please use " - "loss='auto' or loss='binary_crossentropy' instead." + f"loss='{self.loss}' is not suitable for a binary classification " + "problem. Please use loss='log_loss' instead." ) else: - return _LOSSES[self.loss]( + return HalfMultinomialLoss( sample_weight=sample_weight, n_classes=self.n_trees_per_iteration_ ) - else: + if self.loss == "binary_crossentropy": if self.n_trees_per_iteration_ > 1: raise ValueError( - "loss='binary_crossentropy' is not defined for multiclass" - " classification with n_classes=" - f"{self.n_trees_per_iteration_}, use loss=" - "'categorical_crossentropy' instead." + f"loss='{self.loss}' is not defined for multiclass " + f"classification with n_classes={self.n_trees_per_iteration_}, " + "use loss='log_loss' instead." 
) else: - return _LOSSES[self.loss](sample_weight=sample_weight) + return HalfBinomialLoss(sample_weight=sample_weight) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py index 95c0d5f53d640..f5c373ed84558 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py @@ -124,7 +124,7 @@ def test_same_predictions_classification( X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) est_sklearn = HistGradientBoostingClassifier( - loss="binary_crossentropy", + loss="log_loss", max_iter=max_iter, max_bins=max_bins, learning_rate=1, @@ -132,7 +132,9 @@ def test_same_predictions_classification( min_samples_leaf=min_samples_leaf, max_leaf_nodes=max_leaf_nodes, ) - est_lightgbm = get_equivalent_estimator(est_sklearn, lib="lightgbm") + est_lightgbm = get_equivalent_estimator( + est_sklearn, lib="lightgbm", n_classes=n_classes + ) est_lightgbm.fit(X_train, y_train) est_sklearn.fit(X_train, y_train) @@ -198,7 +200,7 @@ def test_same_predictions_multiclass_classification( X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) est_sklearn = HistGradientBoostingClassifier( - loss="categorical_crossentropy", + loss="log_loss", max_iter=max_iter, max_bins=max_bins, learning_rate=lr, diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 642904138a2d9..efa1ac1a4d762 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -6,8 +6,6 @@ from sklearn._loss.loss import ( AbsoluteError, HalfBinomialLoss, - HalfMultinomialLoss, - HalfPoissonLoss, HalfSquaredError, PinballLoss, ) @@ -34,16 +32,6 @@ n_threads = _openmp_effective_n_threads() -_LOSSES = { - "squared_error": HalfSquaredError, - "absolute_error": AbsoluteError, - "poisson": HalfPoissonLoss, - "quantile": PinballLoss, - "binary_crossentropy": HalfBinomialLoss, - "categorical_crossentropy": HalfMultinomialLoss, -} - - X_classification, y_classification = make_classification(random_state=0) X_regression, y_regression = make_regression(random_state=0) X_multi_classification, y_multi_classification = make_classification( @@ -92,12 +80,14 @@ def test_init_parameters_validation(GradientBoosting, X, y, params, err_msg): GradientBoosting(**params).fit(X, y) +# TODO(1.3): remove +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_invalid_classification_loss(): binary_clf = HistGradientBoostingClassifier(loss="binary_crossentropy") err_msg = ( "loss='binary_crossentropy' is not defined for multiclass " "classification with n_classes=3, use " - "loss='categorical_crossentropy' instead" + "loss='log_loss' instead" ) with pytest.raises(ValueError, match=err_msg): binary_clf.fit(np.zeros(shape=(3, 2)), np.arange(3)) @@ -430,7 +420,7 @@ def test_missing_values_resilience( make_classification(random_state=0, n_classes=2), make_classification(random_state=0, n_classes=3, n_informative=3), ], - ids=["binary_crossentropy", "categorical_crossentropy"], + ids=["binary_log_loss", "multiclass_log_loss"], ) def test_zero_division_hessians(data): # non regression test for issue #14018 @@ -621,6 +611,8 @@ def test_infinite_values_missing_values(): assert stump_clf.fit(X, y_isnan).score(X, y_isnan) == 
1 +# TODO(1.3): remove +@pytest.mark.filterwarnings("ignore::FutureWarning") def test_crossentropy_binary_problem(): # categorical_crossentropy should only be used if there are more than two # classes present. PR #14869 @@ -665,7 +657,7 @@ def test_zero_sample_weights_classification(): y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] - gb = HistGradientBoostingClassifier(loss="binary_crossentropy", min_samples_leaf=1) + gb = HistGradientBoostingClassifier(loss="log_loss", min_samples_leaf=1) gb.fit(X, y, sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) @@ -673,9 +665,7 @@ def test_zero_sample_weights_classification(): y = [0, 0, 1, 0, 2] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1, 1] - gb = HistGradientBoostingClassifier( - loss="categorical_crossentropy", min_samples_leaf=1 - ) + gb = HistGradientBoostingClassifier(loss="log_loss", min_samples_leaf=1) gb.fit(X, y, sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) @@ -737,8 +727,8 @@ def test_sample_weight_effect(problem, duplication): assert np.allclose(est_sw._raw_predict(X_dup), est_dup._raw_predict(X_dup)) -@pytest.mark.parametrize("loss_name", ("squared_error", "absolute_error")) -def test_sum_hessians_are_sample_weight(loss_name): +@pytest.mark.parametrize("Loss", (HalfSquaredError, AbsoluteError)) +def test_sum_hessians_are_sample_weight(Loss): # For losses with constant hessians, the sum_hessians field of the # histograms must be equal to the sum of the sample weight of samples at # the corresponding bin. @@ -753,7 +743,7 @@ def test_sum_hessians_are_sample_weight(loss_name): # While sample weights are supposed to be positive, this still works. 
sample_weight = rng.normal(size=n_samples) - loss = _LOSSES[loss_name](sample_weight=sample_weight) + loss = Loss(sample_weight=sample_weight) gradients, hessians = loss.init_gradient_and_hessian( n_samples=n_samples, dtype=G_H_DTYPE ) @@ -1125,22 +1115,31 @@ def test_uint8_predict(Est): est.predict(X) -# TODO: Remove in v1.2 @pytest.mark.parametrize( - "old_loss, new_loss", + "old_loss, new_loss, Estimator", [ - ("least_squares", "squared_error"), - ("least_absolute_deviation", "absolute_error"), + # TODO(1.2): Remove + ("least_squares", "squared_error", HistGradientBoostingRegressor), + ("least_absolute_deviation", "absolute_error", HistGradientBoostingRegressor), + # TODO(1.3): Remove + ("auto", "log_loss", HistGradientBoostingClassifier), + ("binary_crossentropy", "log_loss", HistGradientBoostingClassifier), + ("categorical_crossentropy", "log_loss", HistGradientBoostingClassifier), ], ) -def test_loss_deprecated(old_loss, new_loss): - X, y = make_regression(n_samples=50, random_state=0) - est1 = HistGradientBoostingRegressor(loss=old_loss, random_state=0) +def test_loss_deprecated(old_loss, new_loss, Estimator): + if old_loss == "categorical_crossentropy": + X, y = X_multi_classification[:10], y_multi_classification[:10] + assert len(np.unique(y)) > 2 + else: + X, y = X_classification[:10], y_classification[:10] + + est1 = Estimator(loss=old_loss, random_state=0) with pytest.warns(FutureWarning, match=f"The loss '{old_loss}' was deprecated"): est1.fit(X, y) - est2 = HistGradientBoostingRegressor(loss=new_loss, random_state=0) + est2 = Estimator(loss=new_loss, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx index b2de7614fe499..d2123ecc61510 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx @@ -40,8 +40,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): lightgbm_loss_mapping = { 'squared_error': 'regression_l2', 'absolute_error': 'regression_l1', - 'binary_crossentropy': 'binary', - 'categorical_crossentropy': 'multiclass' + 'log_loss': 'binary' if n_classes == 2 else 'multiclass', } lightgbm_params = { @@ -63,7 +62,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): 'subsample_for_bin': _BinMapper().subsample, } - if sklearn_params['loss'] == 'categorical_crossentropy': + if sklearn_params['loss'] == 'log_loss' and n_classes > 2: # LightGBM multiplies hessians by 2 in multiclass loss. lightgbm_params['min_sum_hessian_in_leaf'] *= 2 # LightGBM 3.0 introduced a different scaling of the hessian for the multiclass case. @@ -76,8 +75,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): xgboost_loss_mapping = { 'squared_error': 'reg:linear', 'absolute_error': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED', - 'binary_crossentropy': 'reg:logistic', - 'categorical_crossentropy': 'multi:softmax' + 'log_loss': 'reg:logistic' if n_classes == 2 else 'multi:softmax', } xgboost_params = { @@ -101,8 +99,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm', n_classes=None): 'squared_error': 'RMSE', # catboost does not support MAE when leaf_estimation_method is Newton 'absolute_error': 'LEAST_ASBOLUTE_DEV_NOT_SUPPORTED', - 'binary_crossentropy': 'Logloss', - 'categorical_crossentropy': 'MultiClass' + 'log_loss': 'Logloss' if n_classes == 2 else 'MultiClass', } catboost_params = {
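
Illustrative usage (not part of the patch): a minimal sketch, assuming scikit-learn 1.1 with this change applied, showing that loss="log_loss" now covers both the binary and the multiclass case, while the old spellings "auto", "binary_crossentropy" and "categorical_crossentropy" still fit an equivalent model but emit a FutureWarning until their planned removal in 1.3.

import warnings

from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

# Three-class toy problem; the same loss name would be used for a binary one.
X, y = make_classification(n_classes=3, n_informative=3, random_state=0)

# New name: a single loss for binary and multiclass classification.
clf = HistGradientBoostingClassifier(loss="log_loss", random_state=0).fit(X, y)

# Deprecated alias: still fits the equivalent model but warns until 1.3.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", FutureWarning)
    HistGradientBoostingClassifier(
        loss="categorical_crossentropy", random_state=0
    ).fit(X, y)
assert any("deprecated" in str(w.message) for w in caught)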