
ENH add criterion log_loss as alternative to entropy in trees and forests #23047


Merged · 9 commits · Apr 14, 2022
16 changes: 8 additions & 8 deletions doc/modules/grid_search.rst
@@ -507,27 +507,27 @@ additional information related to the successive halving process.

Here is an example with some of the columns of a (truncated) dataframe:

==== ====== =============== ================= =======================================================================================
==== ====== =============== ================= ========================================================================================
.. iter n_resources mean_test_score params
==== ====== =============== ================= =======================================================================================
0 0 125 0.983667 {'criterion': 'entropy', 'max_depth': None, 'max_features': 9, 'min_samples_split': 5}
==== ====== =============== ================= ========================================================================================
0 0 125 0.983667 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 5}
1 0 125 0.983667 {'criterion': 'gini', 'max_depth': None, 'max_features': 8, 'min_samples_split': 7}
2 0 125 0.983667 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 10}
3 0 125 0.983667 {'criterion': 'entropy', 'max_depth': None, 'max_features': 6, 'min_samples_split': 6}
3 0 125 0.983667 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 6, 'min_samples_split': 6}
... ... ... ... ...
15 2 500 0.951958 {'criterion': 'entropy', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}
15 2 500 0.951958 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}
16 2 500 0.947958 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 10}
17 2 500 0.951958 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 4}
18 3 1000 0.961009 {'criterion': 'entropy', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}
18 3 1000 0.961009 {'criterion': 'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}
19 3 1000 0.955989 {'criterion': 'gini', 'max_depth': None, 'max_features': 10, 'min_samples_split': 4}
==== ====== =============== ================= =======================================================================================
==== ====== =============== ================= ========================================================================================

Each row corresponds to a given parameter combination (a candidate) and a given
iteration. The iteration is given by the ``iter`` column. The ``n_resources``
column tells you how many resources were used.

In the example above, the best parameter combination is ``{'criterion':
'entropy', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}``
'log_loss', 'max_depth': None, 'max_features': 9, 'min_samples_split': 10}``
since it has reached the last iteration (3) with the highest score:
0.96.
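
A minimal, illustrative sketch of running such a halving search with the new criterion value; the dataset and parameter values below are assumptions for demonstration, not the settings behind the table above:

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.experimental import enable_halving_search_cv  # noqa: F401
    from sklearn.model_selection import HalvingGridSearchCV

    X, y = make_classification(n_samples=1000, n_features=10, random_state=0)
    param_grid = {
        "criterion": ["gini", "log_loss"],  # "log_loss" is the new alias for "entropy"
        "max_features": [6, 8, 9, 10],
        "min_samples_split": [4, 6, 10],
    }
    search = HalvingGridSearchCV(
        RandomForestClassifier(random_state=0), param_grid, factor=2, random_state=0
    )
    search.fit(X, y)
    print(search.best_params_)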

2 changes: 1 addition & 1 deletion doc/modules/permutation_importance.rst
@@ -145,7 +145,7 @@ Relation to impurity-based importance in trees
Tree-based models provide an alternative measure of :ref:`feature importances
based on the mean decrease in impurity <random_forest_feature_importance>`
(MDI). Impurity is quantified by the splitting criterion of the decision trees
(Gini, Entropy or Mean Squared Error). However, this method can give high
(Gini, Log Loss or Mean Squared Error). However, this method can give high
importance to features that may not be predictive on unseen data when the model
is overfitting. Permutation-based feature importance, on the other hand, avoids
this issue, since it can be computed on unseen data.
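
A short sketch, for illustration only, contrasting the two importance measures on held-out data (dataset and settings are arbitrary assumptions):

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=500, n_features=8, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    clf = RandomForestClassifier(criterion="log_loss", random_state=0)
    clf.fit(X_train, y_train)
    print(clf.feature_importances_)  # MDI, derived from the training-time splits
    result = permutation_importance(clf, X_test, y_test, n_repeats=10, random_state=0)
    print(result.importances_mean)   # permutation importance on unseen data
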
5 changes: 3 additions & 2 deletions doc/modules/tree.rst
@@ -324,7 +324,8 @@ In general, the run time cost to construct a balanced binary tree is
to generate balanced trees, they will not always be balanced. Assuming that the
subtrees remain approximately balanced, the cost at each node consists of
searching through :math:`O(n_{features})` to find the feature that offers the
largest reduction in entropy. This has a cost of
largest reduction in the impurity criterion, e.g. log loss (which is equivalent to an
information gain). This has a cost of
:math:`O(n_{features}n_{samples}\log(n_{samples}))` at each node, leading to a
total cost over the entire trees (by summing the cost at each node) of
:math:`O(n_{features}n_{samples}^{2}\log(n_{samples}))`.
@@ -494,7 +495,7 @@ Gini:

H(Q_m) = \sum_k p_{mk} (1 - p_{mk})

Entropy:
Log Loss or Entropy:

.. math::

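For reference, a small NumPy sketch of the two impurities for a single node, using hypothetical class proportions p_mk:

    import numpy as np

    p = np.array([0.5, 0.25, 0.25])   # hypothetical class proportions p_mk at node m
    gini = np.sum(p * (1 - p))        # H(Q_m) = sum_k p_mk (1 - p_mk)
    entropy = -np.sum(p * np.log(p))  # shared by criterion="entropy" and "log_loss"
    print(gini, entropy)
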
8 changes: 8 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -476,6 +476,10 @@ Changelog
in version 1.3.
:pr:`23079` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Enhancement| :class:`RandomForestClassifier` and :class:`ExtraTreesClassifier` have
the new `criterion="log_loss"`, which is equivalent to `criterion="entropy"`.
:pr:`23047` by :user:`Christian Lorentzen <lorentzenchr>`.
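
An illustrative check of that equivalence on a forest (dataset is arbitrary):

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier

    X, y = load_iris(return_X_y=True)
    rf_log_loss = RandomForestClassifier(criterion="log_loss", random_state=0).fit(X, y)
    rf_entropy = RandomForestClassifier(criterion="entropy", random_state=0).fit(X, y)
    assert (rf_log_loss.predict(X) == rf_entropy.predict(X)).all()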

:mod:`sklearn.feature_extraction`
.................................

@@ -1007,6 +1011,10 @@ Changelog
for :class:`tree.DecisionTreeClassifier` and :class:`DecisionTreeRegressor`.
:pr:`22476` by :user:`Zhehao Liu <MaxwellLZH>`.

- |Enhancement| :class:`DecisionTreeClassifier` and :class:`ExtraTreeClassifier` have
the new `criterion="log_loss"`, which is equivalent to `criterion="entropy"`.
:pr:`23047` by :user:`Christian Lorentzen <lorentzenchr>`.
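
Likewise for a single tree, a quick sketch showing the two spellings grow the same tree (illustrative data):

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier, export_text

    X, y = load_iris(return_X_y=True)
    t_log_loss = DecisionTreeClassifier(criterion="log_loss", random_state=0).fit(X, y)
    t_entropy = DecisionTreeClassifier(criterion="entropy", random_state=0).fit(X, y)
    assert export_text(t_log_loss) == export_text(t_entropy)  # identical structure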

:mod:`sklearn.utils`
....................

18 changes: 11 additions & 7 deletions sklearn/ensemble/_forest.py
@@ -394,7 +394,7 @@ def fit(self, X, y, sample_weight=None):

# Check parameters
self._validate_estimator()
# TODO: Remove in v1.2
# TODO(1.2): Remove "mse" and "mae"
if isinstance(self, (RandomForestRegressor, ExtraTreesRegressor)):
if self.criterion == "mse":
warn(
@@ -411,6 +411,7 @@ def fit(self, X, y, sample_weight=None):
FutureWarning,
)

# TODO(1.3): Remove "auto"
if self.max_features == "auto":
warn(
"`max_features='auto'` has been deprecated in 1.1 "
@@ -420,8 +421,8 @@ def fit(self, X, y, sample_weight=None):
"RandomForestRegressors and ExtraTreesRegressors.",
FutureWarning,
)

elif isinstance(self, (RandomForestClassifier, ExtraTreesClassifier)):
# TODO(1.3): Remove "auto"
if self.max_features == "auto":
warn(
"`max_features='auto'` has been deprecated in 1.1 "
@@ -1109,10 +1110,11 @@ class RandomForestClassifier(ForestClassifier):
The default value of ``n_estimators`` changed from 10 to 100
in 0.22.

criterion : {"gini", "entropy"}, default="gini"
criterion : {"gini", "entropy", "log_loss"}, default="gini"
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "entropy" for the information gain.
Note: this parameter is tree-specific.
"gini" for the Gini impurity and "log_loss" and "entropy" both for the
Shannon information gain, see :ref:`tree_mathematical_formulation`.
Note: This parameter is tree-specific.

max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
@@ -1780,9 +1782,11 @@ class ExtraTreesClassifier(ForestClassifier):
The default value of ``n_estimators`` changed from 10 to 100
in 0.22.

criterion : {"gini", "entropy"}, default="gini"
criterion : {"gini", "entropy", "log_loss"}, default="gini"
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "entropy" for the information gain.
"gini" for the Gini impurity and "log_loss" and "entropy" both for the
Shannon information gain, see :ref:`tree_mathematical_formulation`.
Note: This parameter is tree-specific.

max_depth : int, default=None
The maximum depth of the tree. If None, then nodes are expanded until
20 changes: 10 additions & 10 deletions sklearn/ensemble/tests/test_forest.py
@@ -157,7 +157,7 @@ def check_iris_criterion(name, criterion):


@pytest.mark.parametrize("name", FOREST_CLASSIFIERS)
@pytest.mark.parametrize("criterion", ("gini", "entropy"))
@pytest.mark.parametrize("criterion", ("gini", "log_loss"))
def test_iris(name, criterion):
check_iris_criterion(name, criterion)

@@ -347,7 +347,7 @@ def check_importances(name, criterion, dtype, tolerance):
@pytest.mark.parametrize(
"name, criterion",
itertools.chain(
product(FOREST_CLASSIFIERS, ["gini", "entropy"]),
product(FOREST_CLASSIFIERS, ["gini", "log_loss"]),
product(FOREST_REGRESSORS, ["squared_error", "friedman_mse", "absolute_error"]),
),
)
@@ -451,7 +451,7 @@ def mdi_importance(X_m, X, y):

# Estimate importances with totally randomized trees
clf = ExtraTreesClassifier(
n_estimators=500, max_features=1, criterion="entropy", random_state=0
n_estimators=500, max_features=1, criterion="log_loss", random_state=0
).fit(X, y)

importances = (
@@ -1807,23 +1807,23 @@ def test_max_features_deprecation(Estimator):
est.fit(X, y)


# TODO: Remove in v1.2
@pytest.mark.parametrize(
"old_criterion, new_criterion",
"old_criterion, new_criterion, Estimator",
[
("mse", "squared_error"),
("mae", "absolute_error"),
# TODO(1.2): Remove "mse" and "mae"
("mse", "squared_error", RandomForestRegressor),
("mae", "absolute_error", RandomForestRegressor),
],
)
def test_criterion_deprecated(old_criterion, new_criterion):
est1 = RandomForestRegressor(criterion=old_criterion, random_state=0)
def test_criterion_deprecated(old_criterion, new_criterion, Estimator):
est1 = Estimator(criterion=old_criterion, random_state=0)

with pytest.warns(
FutureWarning, match=f"Criterion '{old_criterion}' was deprecated"
):
est1.fit(X, y)

est2 = RandomForestRegressor(criterion=new_criterion, random_state=0)
est2 = Estimator(criterion=new_criterion, random_state=0)
est2.fit(X, y)
assert_allclose(est1.predict(X), est2.predict(X))

20 changes: 13 additions & 7 deletions sklearn/tree/_classes.py
@@ -63,8 +63,12 @@
DTYPE = _tree.DTYPE
DOUBLE = _tree.DOUBLE

CRITERIA_CLF = {"gini": _criterion.Gini, "entropy": _criterion.Entropy}
# TODO: Remove "mse" and "mae" in version 1.2.
CRITERIA_CLF = {
"gini": _criterion.Gini,
"log_loss": _criterion.Entropy,
"entropy": _criterion.Entropy,
}
# TODO(1.2): Remove "mse" and "mae".
CRITERIA_REG = {
"squared_error": _criterion.MSE,
"mse": _criterion.MSE,
@@ -388,7 +392,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
)
else:
criterion = CRITERIA_REG[self.criterion](self.n_outputs_, n_samples)
# TODO: Remove in v1.2
# TODO(1.2): Remove "mse" and "mae"
if self.criterion == "mse":
warnings.warn(
"Criterion 'mse' was deprecated in v1.0 and will be "
@@ -674,9 +678,10 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):

Parameters
----------
criterion : {"gini", "entropy"}, default="gini"
criterion : {"gini", "entropy", "log_loss"}, default="gini"
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "entropy" for the information gain.
"gini" for the Gini impurity and "log_loss" and "entropy" both for the
Shannon information gain, see :ref:`tree_mathematical_formulation`.

splitter : {"best", "random"}, default="best"
The strategy used to choose the split at each node. Supported
@@ -1394,9 +1399,10 @@ class ExtraTreeClassifier(DecisionTreeClassifier):

Parameters
----------
criterion : {"gini", "entropy"}, default="gini"
criterion : {"gini", "entropy", "log_loss"}, default="gini"
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "entropy" for the information gain.
"gini" for the Gini impurity and "log_loss" and "entropy" both for the
Shannon information gain, see :ref:`tree_mathematical_formulation`.

splitter : {"random", "best"}, default="random"
The strategy used to choose the split at each node. Supported
39 changes: 32 additions & 7 deletions sklearn/tree/tests/test_tree.py
@@ -61,7 +61,7 @@
from sklearn.utils import compute_sample_weight


CLF_CRITERIONS = ("gini", "entropy")
CLF_CRITERIONS = ("gini", "log_loss")
REG_CRITERIONS = ("squared_error", "absolute_error", "friedman_mse", "poisson")

CLF_TREES = {
@@ -2181,16 +2181,41 @@ def test_decision_tree_regressor_sample_weight_consistency(criterion):
assert_allclose(tree1.predict(X), tree2.predict(X))


# TODO: Remove in v1.2
@pytest.mark.parametrize("Tree", REG_TREES.values())
@pytest.mark.parametrize("Tree", [DecisionTreeClassifier, ExtraTreeClassifier])
@pytest.mark.parametrize("n_classes", [2, 4])
def test_criterion_entropy_same_as_log_loss(Tree, n_classes):
"""Test that criterion=entropy gives same as log_loss."""
n_samples, n_features = 50, 5
X, y = datasets.make_classification(
n_classes=n_classes,
n_samples=n_samples,
n_features=n_features,
n_informative=n_features,
n_redundant=0,
random_state=42,
)
tree_log_loss = Tree(criterion="log_loss", random_state=43).fit(X, y)
tree_entropy = Tree(criterion="entropy", random_state=43).fit(X, y)

assert_tree_equal(
tree_log_loss.tree_,
tree_entropy.tree_,
f"{Tree!r} with criterion 'entropy' and 'log_loss' gave different trees.",
)
assert_allclose(tree_log_loss.predict(X), tree_entropy.predict(X))


@pytest.mark.parametrize(
"old_criterion, new_criterion",
"old_criterion, new_criterion, Tree",
[
("mse", "squared_error"),
("mae", "absolute_error"),
# TODO(1.2): Remove "mse" and "mae"
("mse", "squared_error", DecisionTreeRegressor),
("mse", "squared_error", ExtraTreeRegressor),
("mae", "absolute_error", DecisionTreeRegressor),
("mae", "absolute_error", ExtraTreeRegressor),
],
)
def test_criterion_deprecated(Tree, old_criterion, new_criterion):
def test_criterion_deprecated(old_criterion, new_criterion, Tree):
tree = Tree(criterion=old_criterion)

with pytest.warns(