diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 329215406c39c..21610228b9b37 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -944,7 +944,7 @@ controls the number of iterations of the boosting process:: 0.8965 Available losses for regression are 'squared_error', -'least_absolute_deviation', which is less sensitive to outliers, and +'absolute_error', which is less sensitive to outliers, and 'poisson', which is well suited to model counts and frequencies. For classification, 'binary_crossentropy' is used for binary classification and 'categorical_crossentropy' is used for multiclass classification. By default diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index d26c5dd0c347d..8ad8a295d72e0 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -76,6 +76,35 @@ Changelog - For :class:`tree.ExtraTreeRegressor`, `criterion="mse"` is deprecated, use `"squared_error"` instead which is now the default. +- |API| The option for using the absolute error via ``loss`` and + ``criterion`` parameters was made more consistent. The preferred way is by + setting the value to `"absolute_error"`. Old option names are still valid, + produce the same models, but are deprecated and will be removed in version + 1.2. + :pr:`19733` by :user:`Christian Lorentzen <lorentzenchr>`. + + - For :class:`ensemble.ExtraTreesRegressor`, `criterion="mae"` is deprecated, + use `"absolute_error"` instead. + + - For :class:`ensemble.GradientBoostingRegressor`, `loss="lad"` is deprecated, + use `"absolute_error"` instead. + + - For :class:`ensemble.RandomForestRegressor`, `criterion="mae"` is deprecated, + use `"absolute_error"` instead. + + - For :class:`ensemble.HistGradientBoostingRegressor`, + `loss="least_absolute_deviation"` is deprecated, use `"absolute_error"` + instead. + + - For :class:`linear_model.RANSACRegressor`, `loss="absolute_loss"` is + deprecated, use `"absolute_error"` instead which is now the default. 
+ + - For :class:`tree.DecisionTreeRegressor`, `criterion="mae"` is deprecated, + use `"absolute_error"` instead. + + - For :class:`tree.ExtraTreeRegressor`, `criterion="mae"` is deprecated, + use `"absolute_error"` instead. + :mod:`sklearn.base` ................... diff --git a/sklearn/ensemble/_base.py b/sklearn/ensemble/_base.py index 095d801de166d..c58a0c7dbe9c7 100644 --- a/sklearn/ensemble/_base.py +++ b/sklearn/ensemble/_base.py @@ -153,13 +153,13 @@ def _make_estimator(self, append=True, random_state=None): for p in self.estimator_params}) # TODO: Remove in v1.2 - # criterion "mse" would cause warnings in every call to + # criterion "mse" and "mae" would cause warnings in every call to # DecisionTreeRegressor.fit(..) - if ( - isinstance(estimator, (DecisionTreeRegressor, ExtraTreeRegressor)) - and getattr(estimator, "criterion", None) == "mse" - ): - estimator.set_params(criterion="squared_error") + if isinstance(estimator, (DecisionTreeRegressor, ExtraTreeRegressor)): + if getattr(estimator, "criterion", None) == "mse": + estimator.set_params(criterion="squared_error") + elif getattr(estimator, "criterion", None) == "mae": + estimator.set_params(criterion="absolute_error") if random_state is not None: _set_random_states(estimator, random_state) diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 140c1c93e8eef..8eef1f3429227 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -346,16 +346,21 @@ def fit(self, X, y, sample_weight=None): # Check parameters self._validate_estimator() # TODO: Remove in v1.2 - if ( - isinstance(self, (RandomForestRegressor, ExtraTreesRegressor)) - and self.criterion == "mse" - ): - warn( - "Criterion 'mse' was deprecated in v1.0 and will be " - "removed in version 1.2. 
Use `criterion='squared_error'` " - "which is equivalent.", - FutureWarning - ) + if isinstance(self, (RandomForestRegressor, ExtraTreesRegressor)): + if self.criterion == "mse": + warn( + "Criterion 'mse' was deprecated in v1.0 and will be " + "removed in version 1.2. Use `criterion='squared_error'` " + "which is equivalent.", + FutureWarning + ) + elif self.criterion == "mae": + warn( + "Criterion 'mae' was deprecated in v1.0 and will be " + "removed in version 1.2. Use `criterion='absolute_error'` " + "which is equivalent.", + FutureWarning + ) if not self.bootstrap and self.oob_score: raise ValueError("Out of bag estimation only available" @@ -1321,11 +1326,12 @@ class RandomForestRegressor(ForestRegressor): The default value of ``n_estimators`` changed from 10 to 100 in 0.22. - criterion : {"squared_error", "mse", "mae"}, default="squared_error" + criterion : {"squared_error", "mse", "absolute_error", "mae"}, \ + default="squared_error" The function to measure the quality of a split. Supported criteria are "squared_error" for the mean squared error, which is equal to - variance reduction as feature selection criterion, and "mae" for the - mean absolute error. + variance reduction as feature selection criterion, and "absolute_error" + for the mean absolute error. .. versionadded:: 0.18 Mean Absolute Error (MAE) criterion. @@ -1334,6 +1340,10 @@ class RandomForestRegressor(ForestRegressor): Criterion "mse" was deprecated in v1.0 and will be removed in version 1.2. Use `criterion="squared_error"` which is equivalent. + .. deprecated:: 1.0 + Criterion "mae" was deprecated in v1.0 and will be removed in + version 1.2. Use `criterion="absolute_error"` which is equivalent. + max_depth : int, default=None The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than @@ -1936,10 +1946,11 @@ class ExtraTreesRegressor(ForestRegressor): The default value of ``n_estimators`` changed from 10 to 100 in 0.22. 
- criterion : {"squared_error", "mse", "mae"}, default="squared_error" + criterion : {"squared_error", "mse", "absolute_error", "mae"}, \ + default="squared_error" The function to measure the quality of a split. Supported criteria - are "squared_error" and "mse" for the mean squared error, which is - equal to variance reduction as feature selection criterion, and "mae" + are "squared_error" for the mean squared error, which is equal to + variance reduction as feature selection criterion, and "absolute_error" for the mean absolute error. .. versionadded:: 0.18 @@ -1949,6 +1960,10 @@ class ExtraTreesRegressor(ForestRegressor): Criterion "mse" was deprecated in v1.0 and will be removed in version 1.2. Use `criterion="squared_error"` which is equivalent. + .. deprecated:: 1.0 + Criterion "mae" was deprecated in v1.0 and will be removed in + version 1.2. Use `criterion="absolute_error"` which is equivalent. + max_depth : int, default=None The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain less than diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 4984575bce8c3..527bbcb559b5f 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -238,11 +238,17 @@ def _check_params(self): or self.loss not in _gb_losses.LOSS_FUNCTIONS): raise ValueError("Loss '{0:s}' not supported. ".format(self.loss)) + # TODO: Remove in v1.2 if self.loss == "ls": warnings.warn("The loss 'ls' was deprecated in v1.0 and " "will be removed in version 1.2. Use 'squared_error'" " which is equivalent.", FutureWarning) + elif self.loss == "lad": + warnings.warn("The loss 'lad' was deprecated in v1.0 and " + "will be removed in version 1.2. 
Use " + "'absolute_error' which is equivalent.", + FutureWarning) if self.loss == 'deviance': loss_class = (_gb_losses.MultinomialDeviance @@ -403,7 +409,7 @@ def fit(self, X, y, sample_weight=None, monitor=None): ------- self : object """ - if self.criterion == 'mae': + if self.criterion in ('absolute_error', 'mae'): # TODO: This should raise an error from 1.1 self._warn_mae_for_criterion() @@ -1340,19 +1346,22 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): Parameters ---------- - loss : {'squared_error', 'ls', 'lad', 'huber', 'quantile'}, \ - default='squared_error' + loss : {'squared_error', 'ls', 'absolute_error', 'lad', 'huber', \ + 'quantile'}, default='squared_error' Loss function to be optimized. 'squared_error' refers to the squared - error for regression. - 'lad' (least absolute deviation) is a highly robust - loss function solely based on order information of the input - variables. 'huber' is a combination of the two. 'quantile' - allows quantile regression (use `alpha` to specify the quantile). + error for regression. 'absolute_error' refers to the absolute error of + regression and is a robust loss function. 'huber' is a + combination of the two. 'quantile' allows quantile regression (use + `alpha` to specify the quantile). .. deprecated:: 1.0 The loss 'ls' was deprecated in v1.0 and will be removed in version 1.2. Use `loss='squared_error'` which is equivalent. + .. deprecated:: 1.0 + The loss 'lad' was deprecated in v1.0 and will be removed in + version 1.2. Use `loss='absolute_error'` which is equivalent. + learning_rate : float, default=0.1 Learning rate shrinks the contribution of each tree by `learning_rate`. There is a trade-off between learning_rate and n_estimators. @@ -1383,7 +1392,7 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): .. deprecated:: 0.24 `criterion='mae'` is deprecated and will be removed in version 1.1 (renaming of 0.26). 
The correct way of minimizing the absolute - error is to use `loss='lad'` instead. + error is to use `loss='absolute_error'` instead. .. deprecated:: 1.0 Criterion 'mse' was deprecated in v1.0 and will be removed in @@ -1644,7 +1653,8 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting): """ # TODO: remove "ls" in verion 1.2 - _SUPPORTED_LOSS = ("squared_error", 'ls', 'lad', 'huber', 'quantile') + _SUPPORTED_LOSS = ("squared_error", 'ls', "absolute_error", 'lad', 'huber', + 'quantile') @_deprecate_positional_args def __init__(self, *, loss="squared_error", learning_rate=0.1, @@ -1681,7 +1691,7 @@ def _warn_mae_for_criterion(self): warnings.warn("criterion='mae' was deprecated in version 0.24 and " "will be removed in version 1.1 (renaming of 0.26). The " "correct way of minimizing the absolute error is to use " - " loss='lad' instead.", FutureWarning) + " loss='absolute_error' instead.", FutureWarning) def predict(self, X): """Predict regression target for X. diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py index f33c7086b596b..67a3b1b364f47 100644 --- a/sklearn/ensemble/_gb_losses.py +++ b/sklearn/ensemble/_gb_losses.py @@ -856,10 +856,11 @@ def get_init_raw_predictions(self, X, estimator): return raw_predictions.reshape(-1, 1).astype(np.float64) -# TODO: Remove entry 'ls' in version 1.2. +# TODO: Remove entry 'ls' and 'lad' in version 1.2. 
LOSS_FUNCTIONS = { "squared_error": LeastSquaresError, 'ls': LeastSquaresError, + "absolute_error": LeastAbsoluteError, 'lad': LeastAbsoluteError, 'huber': HuberLossFunction, 'quantile': QuantileLossFunction, diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index d3b62a5df784a..6d5de978add9b 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -893,8 +893,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): Parameters ---------- - loss : {'squared_error', 'least_squares', 'least_absolute_deviation', \ - 'poisson'}, default='squared_error' + loss : {'squared_error', 'least_squares', 'absolute_error', \ + 'least_absolute_deviation', 'poisson'}, default='squared_error' The loss function to use in the boosting process. Note that the "least squares" and "poisson" losses actually implement "half least squares loss" and "half poisson deviance" to simplify the @@ -908,6 +908,11 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): The loss 'least_squares' was deprecated in v1.0 and will be removed in version 1.2. Use `loss='squared_error'` which is equivalent. + .. deprecated:: 1.0 + The loss 'least_absolute_deviation' was deprecated in v1.0 and will + be removed in version 1.2. Use `loss='absolute_error'` which is + equivalent. + learning_rate : float, default=0.1 The learning rate, also known as *shrinkage*. This is used as a multiplicative factor for the leaves values. Use ``1`` for no @@ -1037,7 +1042,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): 0.92... 
""" - _VALID_LOSSES = ('squared_error', 'least_squares', + _VALID_LOSSES = ('squared_error', 'least_squares', 'absolute_error', 'least_absolute_deviation', 'poisson') @_deprecate_positional_args @@ -1113,6 +1118,7 @@ def _encode_y(self, y): return y def _get_loss(self, sample_weight): + # TODO: Remove in v1.2 if self.loss == "least_squares": warnings.warn( "The loss 'least_squares' was deprecated in v1.0 and will be " @@ -1120,6 +1126,13 @@ def _get_loss(self, sample_weight): "equivalent.", FutureWarning) return _LOSSES["squared_error"](sample_weight=sample_weight) + elif self.loss == "least_absolute_deviation": + warnings.warn( + "The loss 'least_absolute_deviation' was deprecated in v1.0 " + " and will be removed in version 1.2. Use 'absolute_error' " + "which is equivalent.", + FutureWarning) + return _LOSSES["absolute_error"](sample_weight=sample_weight) return _LOSSES[self.loss](sample_weight=sample_weight) diff --git a/sklearn/ensemble/_hist_gradient_boosting/loss.py b/sklearn/ensemble/_hist_gradient_boosting/loss.py index c336bd347e4cf..036f075bdabd8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/loss.py +++ b/sklearn/ensemble/_hist_gradient_boosting/loss.py @@ -420,7 +420,7 @@ def predict_proba(self, raw_predictions): _LOSSES = { 'squared_error': LeastSquares, - 'least_absolute_deviation': LeastAbsoluteDeviation, + 'absolute_error': LeastAbsoluteDeviation, 'binary_crossentropy': BinaryCrossEntropy, 'categorical_crossentropy': CategoricalCrossEntropy, 'poisson': Poisson, diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py index f34dffab2671c..ac58f39422687 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py @@ -34,7 +34,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, # and max_leaf_nodes is low enough. 
# - To ignore discrepancies caused by small differences the binning # strategy, data is pre-binned if n_samples > 255. - # - We don't check the least_absolute_deviation loss here. This is because + # - We don't check the absolute_error loss here. This is because # LightGBM's computation of the median (used for the initial value of # raw_prediction) is a bit off (they'll e.g. return midpoints when there # is no need to.). Since these tests only run 1 iteration, the diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index b2322f29f85d1..213d46cf58f04 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -192,26 +192,26 @@ def test_should_stop(scores, n_iter_no_change, tol, stopping): assert gbdt._should_stop(scores) == stopping -def test_least_absolute_deviation(): +def test_absolute_error(): # For coverage only. 
X, y = make_regression(n_samples=500, random_state=0) - gbdt = HistGradientBoostingRegressor(loss='least_absolute_deviation', + gbdt = HistGradientBoostingRegressor(loss='absolute_error', random_state=0) gbdt.fit(X, y) assert gbdt.score(X, y) > .9 -def test_least_absolute_deviation_sample_weight(): +def test_absolute_error_sample_weight(): # non regression test for issue #19400 # make sure no error is thrown during fit of - # HistGradientBoostingRegressor with least_absolute_deviation loss function + # HistGradientBoostingRegressor with absolute_error loss function # and passing sample_weight rng = np.random.RandomState(0) n_samples = 100 X = rng.uniform(-1, 1, size=(n_samples, 2)) y = rng.uniform(-1, 1, size=n_samples) sample_weight = rng.uniform(0, 1, size=n_samples) - gbdt = HistGradientBoostingRegressor(loss='least_absolute_deviation') + gbdt = HistGradientBoostingRegressor(loss='absolute_error') gbdt.fit(X, y, sample_weight=sample_weight) @@ -650,8 +650,7 @@ def test_sample_weight_effect(problem, duplication): est_dup._raw_predict(X_dup)) -@pytest.mark.parametrize('loss_name', ('squared_error', - 'least_absolute_deviation')) +@pytest.mark.parametrize('loss_name', ('squared_error', 'absolute_error')) def test_sum_hessians_are_sample_weight(loss_name): # For losses with constant hessians, the sum_hessians field of the # histograms must be equal to the sum of the sample weight of samples at @@ -993,14 +992,18 @@ def test_uint8_predict(Est): # TODO: Remove in v1.2 -def test_loss_least_squares_deprecated(): +@pytest.mark.parametrize("old_loss, new_loss", [ + ("least_squares", "squared_error"), + ("least_absolute_deviation", "absolute_error"), +]) +def test_loss_deprecated(old_loss, new_loss): X, y = make_regression(n_samples=50, random_state=0) - est1 = HistGradientBoostingRegressor(loss="least_squares", random_state=0) + est1 = HistGradientBoostingRegressor(loss=old_loss, random_state=0) with pytest.warns(FutureWarning, - match="The loss 'least_squares' was 
deprecated"): + match=f"The loss '{old_loss}' was deprecated"): est1.fit(X, y) - est2 = HistGradientBoostingRegressor(loss="squared_error", random_state=0) + est2 = HistGradientBoostingRegressor(loss=new_loss, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py index ce7b4acedbae5..345e72c642668 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py @@ -103,7 +103,7 @@ def fprime2(x: np.ndarray) -> np.ndarray: @pytest.mark.parametrize('loss, n_classes, prediction_dim', [ ("squared_error", 0, 1), - ('least_absolute_deviation', 0, 1), + ("absolute_error", 0, 1), ('binary_crossentropy', 2, 1), ('categorical_crossentropy', 3, 3), ('poisson', 0, 1), @@ -118,7 +118,7 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0): rng = np.random.RandomState(seed) n_samples = 100 - if loss in ("squared_error", 'least_absolute_deviation'): + if loss in ("squared_error", "absolute_error"): y_true = rng.normal(size=n_samples).astype(Y_DTYPE) elif loss in ('poisson'): y_true = rng.poisson(size=n_samples).astype(Y_DTYPE) @@ -172,10 +172,10 @@ def test_baseline_least_squares(): baseline_prediction) -def test_baseline_least_absolute_deviation(): +def test_baseline_absolute_error(): rng = np.random.RandomState(0) - loss = _LOSSES['least_absolute_deviation'](sample_weight=None) + loss = _LOSSES["absolute_error"](sample_weight=None) y_train = rng.normal(size=100) baseline_prediction = loss.get_baseline_prediction(y_train, None, 1) assert baseline_prediction.shape == tuple() # scalar @@ -256,7 +256,7 @@ def test_baseline_categorical_crossentropy(): @pytest.mark.parametrize('loss, problem', [ ("squared_error", 'regression'), - ('least_absolute_deviation', 'regression'), + ("absolute_error", 'regression'), ('binary_crossentropy', 
'classification'), ('categorical_crossentropy', 'classification'), ('poisson', 'poisson_regression'), diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx index d1168acf94835..3b323b3e298b8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx @@ -43,7 +43,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'): lightgbm_loss_mapping = { 'squared_error': 'regression_l2', - 'least_absolute_deviation': 'regression_l1', + 'absolute_error': 'regression_l1', 'binary_crossentropy': 'binary', 'categorical_crossentropy': 'multiclass' } @@ -76,7 +76,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'): # XGB xgboost_loss_mapping = { 'squared_error': 'reg:linear', - 'least_absolute_deviation': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED', + 'absolute_error': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED', 'binary_crossentropy': 'reg:logistic', 'categorical_crossentropy': 'multi:softmax' } @@ -101,7 +101,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'): catboost_loss_mapping = { 'squared_error': 'RMSE', # catboost does not support MAE when leaf_estimation_method is Newton - 'least_absolute_deviation': 'LEAST_ASBOLUTE_DEV_NOT_SUPPORTED', + 'absolute_error': 'LEAST_ASBOLUTE_DEV_NOT_SUPPORTED', 'binary_crossentropy': 'Logloss', 'categorical_crossentropy': 'MultiClass' } diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index b6c1fea0e2f29..c74a1ca0c603e 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -176,7 +176,9 @@ def check_regression_criterion(name, criterion): @pytest.mark.parametrize('name', FOREST_REGRESSORS) -@pytest.mark.parametrize('criterion', ("squared_error", "mae", "friedman_mse")) +@pytest.mark.parametrize('criterion', ( + "squared_error", "absolute_error", "friedman_mse" +)) def test_regression(name, criterion): 
check_regression_criterion(name, criterion) @@ -261,10 +263,14 @@ def check_importances(name, criterion, dtype, tolerance): itertools.chain(product(FOREST_CLASSIFIERS, ["gini", "entropy"]), product(FOREST_REGRESSORS, - ["squared_error", "friedman_mse", "mae"]))) + [ + "squared_error", + "friedman_mse", + "absolute_error" + ]))) def test_importances(dtype, name, criterion): tolerance = 0.01 - if name in FOREST_REGRESSORS and criterion == "mae": + if name in FOREST_REGRESSORS and criterion == "absolute_error": tolerance = 0.05 check_importances(name, criterion, dtype, tolerance) @@ -1498,14 +1504,18 @@ def test_n_features_deprecation(Estimator): # TODO: Remove in v1.2 -def test_mse_deprecated(): - est1 = RandomForestRegressor(criterion="mse", random_state=0) +@pytest.mark.parametrize("old_criterion, new_criterion", [ + ("mse", "squared_error"), + ("mae", "absolute_error"), +]) +def test_criterion_deprecated(old_criterion, new_criterion): + est1 = RandomForestRegressor(criterion=old_criterion, random_state=0) with pytest.warns(FutureWarning, - match="Criterion 'mse' was deprecated"): + match=f"Criterion '{old_criterion}' was deprecated"): est1.fit(X, y) - est2 = RandomForestRegressor(criterion="squared_error", random_state=0) + est2 = RandomForestRegressor(criterion=new_criterion, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 166d6bdfc5c11..30c0cdc0cc8fd 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -133,7 +133,7 @@ def test_gbdt_loss_alpha_error(params, err_msg): @pytest.mark.parametrize( "GradientBoosting, loss", [(GradientBoostingClassifier, "ls"), - (GradientBoostingClassifier, "lad"), + (GradientBoostingClassifier, "absolute_error"), (GradientBoostingClassifier, "quantile"), (GradientBoostingClassifier, "huber"), 
(GradientBoostingRegressor, "deviance"), @@ -171,7 +171,7 @@ def test_classification_synthetic(loss): assert error_rate < 0.08 -@pytest.mark.parametrize('loss', ('squared_error', 'lad', 'huber')) +@pytest.mark.parametrize('loss', ('squared_error', 'absolute_error', 'huber')) @pytest.mark.parametrize('subsample', (1.0, 0.5)) def test_regression_dataset(loss, subsample): # Check consistency on regression dataset with least squares @@ -508,7 +508,7 @@ def test_degenerate_targets(): def test_quantile_loss(): - # Check if quantile loss with alpha=0.5 equals lad. + # Check if quantile loss with alpha=0.5 equals absolute_error. clf_quantile = GradientBoostingRegressor(n_estimators=100, loss='quantile', max_depth=4, alpha=0.5, random_state=7) @@ -516,12 +516,12 @@ def test_quantile_loss(): clf_quantile.fit(X_reg, y_reg) y_quantile = clf_quantile.predict(X_reg) - clf_lad = GradientBoostingRegressor(n_estimators=100, loss='lad', - max_depth=4, random_state=7) + clf_ae = GradientBoostingRegressor(n_estimators=100, loss='absolute_error', + max_depth=4, random_state=7) - clf_lad.fit(X_reg, y_reg) - y_lad = clf_lad.predict(X_reg) - assert_array_almost_equal(y_quantile, y_lad, decimal=4) + clf_ae.fit(X_reg, y_reg) + y_ae = clf_ae.predict(X_reg) + assert_array_almost_equal(y_quantile, y_ae, decimal=4) def test_symbol_labels(): @@ -1067,7 +1067,7 @@ def test_non_uniform_weights_toy_edge_case_reg(): y = [0, 0, 1, 0] # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] - for loss in ('huber', 'squared_error', 'lad', 'quantile'): + for loss in ('huber', 'squared_error', 'absolute_error', 'quantile'): gb = GradientBoostingRegressor(learning_rate=1.0, n_estimators=2, loss=loss) gb.fit(X, y, sample_weight=sample_weight) @@ -1390,13 +1390,17 @@ def test_criterion_mse_deprecated(Estimator): # TODO: Remove in v1.2 -def test_loss_ls_deprecated(): - est1 = GradientBoostingRegressor(loss="ls", random_state=0) +@pytest.mark.parametrize("old_loss, 
new_loss", [ + ("ls", "squared_error"), + ("lad", "absolute_error"), +]) +def test_loss_deprecated(old_loss, new_loss): + est1 = GradientBoostingRegressor(loss=old_loss, random_state=0) with pytest.warns(FutureWarning, - match="The loss 'ls' was deprecated"): + match=f"The loss '{old_loss}' was deprecated"): est1.fit(X, y) - est2 = GradientBoostingRegressor(loss="squared_error", random_state=0) + est2 = GradientBoostingRegressor(loss=new_loss, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index 2fc8143f432c8..3cde1f1235ec8 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -137,9 +137,9 @@ class RANSACRegressor(MetaEstimatorMixin, RegressorMixin, as 0.99 (the default) and e is the current fraction of inliers w.r.t. the total number of samples. - loss : string, callable, default='absolute_loss' - String inputs, 'absolute_loss' and 'squared_error' are supported which - find the absolute loss and squared error per sample respectively. + loss : string, callable, default='absolute_error' + String inputs, 'absolute_error' and 'squared_error' are supported which + find the absolute error and squared error per sample respectively. If ``loss`` is a callable, then it should be a function that takes two arrays as inputs, the true and predicted value and returns a 1-D @@ -155,6 +155,10 @@ class RANSACRegressor(MetaEstimatorMixin, RegressorMixin, The loss 'squared_loss' was deprecated in v1.0 and will be removed in version 1.2. Use `loss='squared_error'` which is equivalent. + .. deprecated:: 1.0 + The loss 'absolute_loss' was deprecated in v1.0 and will be removed + in version 1.2. Use `loss='absolute_error'` which is equivalent. + random_state : int, RandomState instance, default=None The generator used to initialize the centers. Pass an int for reproducible output across multiple function calls. 
@@ -212,7 +216,7 @@ def __init__(self, base_estimator=None, *, min_samples=None, residual_threshold=None, is_data_valid=None, is_model_valid=None, max_trials=100, max_skips=np.inf, stop_n_inliers=np.inf, stop_score=np.inf, - stop_probability=0.99, loss='absolute_loss', + stop_probability=0.99, loss='absolute_error', random_state=None): self.base_estimator = base_estimator @@ -293,7 +297,15 @@ def fit(self, X, y, sample_weight=None): else: residual_threshold = self.residual_threshold - if self.loss == "absolute_loss": + # TODO: Remove absolute_loss in v1.2. + if self.loss in ("absolute_error", "absolute_loss"): + if self.loss == "absolute_loss": + warnings.warn( + "The loss 'absolute_loss' was deprecated in v1.0 and will " + "be removed in version 1.2. Use `loss='absolute_error'` " + "which is equivalent.", + FutureWarning + ) if y.ndim == 1: loss_function = lambda y_true, y_pred: np.abs(y_true - y_pred) else: @@ -319,7 +331,7 @@ def fit(self, X, y, sample_weight=None): else: raise ValueError( - "loss should be 'absolute_loss', 'squared_error' or a " + "loss should be 'absolute_error', 'squared_error' or a " "callable. Got %s. 
" % self.loss) random_state = check_random_state(self.random_state) diff --git a/sklearn/linear_model/tests/test_ransac.py b/sklearn/linear_model/tests/test_ransac.py index 857696bf387d5..071a67efcf28f 100644 --- a/sklearn/linear_model/tests/test_ransac.py +++ b/sklearn/linear_model/tests/test_ransac.py @@ -539,13 +539,17 @@ def test_ransac_final_model_fit_sample_weight(): # TODO: Remove in v1.2 -def test_loss_squared_loss_deprecated(): - est1 = RANSACRegressor(loss="squared_loss", random_state=0) +@pytest.mark.parametrize("old_loss, new_loss", [ + ("absolute_loss", "absolute_error"), + ("squared_loss", "squared_error"), +]) +def test_loss_deprecated(old_loss, new_loss): + est1 = RANSACRegressor(loss=old_loss, random_state=0) with pytest.warns(FutureWarning, - match="The loss 'squared_loss' was deprecated"): + match=f"The loss '{old_loss}' was deprecated"): est1.fit(X, y) - est2 = RANSACRegressor(loss="squared_error", random_state=0) + est2 = RANSACRegressor(loss=new_loss, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 420292881f7db..de5aebfa8a6e3 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -62,10 +62,11 @@ CRITERIA_CLF = {"gini": _criterion.Gini, "entropy": _criterion.Entropy} -# TODO: Remove "mse" in version 1.2. +# TODO: Remove "mse" and "mae" in version 1.2. CRITERIA_REG = {"squared_error": _criterion.MSE, "mse": _criterion.MSE, "friedman_mse": _criterion.FriedmanMSE, + "absolute_error": _criterion.MAE, "mae": _criterion.MAE, "poisson": _criterion.Poisson} @@ -360,6 +361,13 @@ def fit(self, X, y, sample_weight=None, check_input=True, "which is equivalent.", FutureWarning ) + elif self.criterion == "mae": + warnings.warn( + "Criterion 'mae' was deprecated in v1.0 and will be " + "removed in version 1.2. 
Use `criterion='absolute_error'` " + "which is equivalent.", + FutureWarning + ) else: # Make a deepcopy in case the criterion has mutable attributes that # might be shared and modified concurrently during parallel fitting @@ -1001,16 +1009,16 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): Parameters ---------- - criterion : {"squared_error", "mse", "friedman_mse", "mae", "poisson"}, \ - default="squared_error" + criterion : {"squared_error", "mse", "friedman_mse", "absolute_error", \ + "mae", "poisson"}, default="squared_error" The function to measure the quality of a split. Supported criteria are "squared_error" for the mean squared error, which is equal to variance reduction as feature selection criterion and minimizes the L2 loss using the mean of each terminal node, "friedman_mse", which uses mean squared error with Friedman's improvement score for potential - splits, "mae" for the mean absolute error, which minimizes the L1 loss - using the median of each terminal node, and "poisson" which uses - reduction in Poisson deviance to find splits. + splits, "absolute_error" for the mean absolute error, which minimizes + the L1 loss using the median of each terminal node, and "poisson" which + uses reduction in Poisson deviance to find splits. .. versionadded:: 0.18 Mean Absolute Error (MAE) criterion. @@ -1022,6 +1030,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): Criterion "mse" was deprecated in v1.0 and will be removed in version 1.2. Use `criterion="squared_error"` which is equivalent. + .. deprecated:: 1.0 + Criterion "mae" was deprecated in v1.0 and will be removed in + version 1.2. Use `criterion="absolute_error"` which is equivalent. + splitter : {"best", "random"}, default="best" The strategy used to choose the split at each node. 
Supported strategies are "best" to choose the best split and "random" to choose @@ -1577,6 +1589,10 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Criterion "mse" was deprecated in v1.0 and will be removed in version 1.2. Use `criterion="squared_error"` which is equivalent. + .. deprecated:: 1.0 + Criterion "mae" was deprecated in v1.0 and will be removed in + version 1.2. Use `criterion="absolute_error"` which is equivalent. + splitter : {"random", "best"}, default="random" The strategy used to choose the split at each node. Supported strategies are "best" to choose the best split and "random" to choose diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 2a1da1e2bfce0..a6e30a9941756 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -51,7 +51,7 @@ from sklearn.utils import compute_sample_weight CLF_CRITERIONS = ("gini", "entropy") -REG_CRITERIONS = ("squared_error", "mae", "friedman_mse", "poisson") +REG_CRITERIONS = ("squared_error", "absolute_error", "friedman_mse", "poisson") CLF_TREES = { "DecisionTreeClassifier": DecisionTreeClassifier, @@ -294,7 +294,7 @@ def test_diabetes_overfit(name, Tree, criterion): @pytest.mark.parametrize( "criterion, max_depth, metric, max_loss", [("squared_error", 15, mean_squared_error, 60), - ("mae", 20, mean_squared_error, 60), + ("absolute_error", 20, mean_squared_error, 60), ("friedman_mse", 15, mean_squared_error, 60), ("poisson", 15, mean_poisson_deviance, 30)] ) @@ -1772,7 +1772,7 @@ def test_mae(): = 0.75 ------ """ - dt_mae = DecisionTreeRegressor(random_state=0, criterion="mae", + dt_mae = DecisionTreeRegressor(random_state=0, criterion="absolute_error", max_leaf_nodes=2) # Test MAE where sample weights are non-uniform (as illustrated above): @@ -2121,12 +2121,16 @@ def test_X_idx_sorted_deprecated(TreeEstimator): # TODO: Remove in v1.2 @pytest.mark.parametrize("Tree", REG_TREES.values()) -def test_mse_deprecated(Tree): - tree = Tree(criterion="mse") 
+@pytest.mark.parametrize("old_criterion, new_criterion", [ + ("mse", "squared_error"), + ("mae", "absolute_error"), +]) +def test_criterion_deprecated(Tree, old_criterion, new_criterion): + tree = Tree(criterion=old_criterion) with pytest.warns(FutureWarning, - match="Criterion 'mse' was deprecated"): + match=f"Criterion '{old_criterion}' was deprecated"): tree.fit(X, y) - tree_sqer = Tree(criterion="squared_error").fit(X, y) - assert_allclose(tree.predict(X), tree_sqer.predict(X)) + tree_new = Tree(criterion=new_criterion).fit(X, y) + assert_allclose(tree.predict(X), tree_new.predict(X))