Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions sklearn/ensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,16 @@ def _validate_estimator(self, default=None):
)

if self.estimator is not None:
self._estimator = self.estimator
self.estimator_ = self.estimator
elif self.base_estimator not in [None, "deprecated"]:
warnings.warn(
"`base_estimator` was renamed to `estimator` in version 1.2 and "
"will be removed in 1.4.",
FutureWarning,
)
self._estimator = self.base_estimator
self.estimator_ = self.base_estimator
else:
self._estimator = default
self.estimator_ = default

# TODO(1.4): remove
# mypy error: Decorated property not supported
Expand All @@ -181,13 +181,7 @@ def _validate_estimator(self, default=None):
@property
def base_estimator_(self):
"""Estimator used to grow the ensemble."""
return self._estimator

# TODO(1.4): remove
@property
def estimator_(self):
"""Estimator used to grow the ensemble."""
return self._estimator
return self.estimator_

def _make_estimator(self, append=True, random_state=None):
"""Make and configure a copy of the `estimator_` attribute.
Expand Down
24 changes: 24 additions & 0 deletions sklearn/ensemble/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble._base import _set_random_states
from sklearn.linear_model import Perceptron
from sklearn.linear_model import Ridge, LogisticRegression
from collections import OrderedDict
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectFromModel
from sklearn import ensemble


def test_base():
Expand Down Expand Up @@ -117,3 +119,25 @@ def test_validate_estimator_value_error():
err_msg = "Both `estimator` and `base_estimator` were set. Only set `estimator`."
with pytest.raises(ValueError, match=err_msg):
model.fit(X, y)


# TODO(1.4): remove
@pytest.mark.parametrize(
"model",
[
ensemble.GradientBoostingClassifier(),
ensemble.GradientBoostingRegressor(),
ensemble.HistGradientBoostingClassifier(),
ensemble.HistGradientBoostingRegressor(),
ensemble.VotingClassifier(
[("a", LogisticRegression()), ("b", LogisticRegression())]
),
ensemble.VotingRegressor([("a", Ridge()), ("b", Ridge())]),
],
)
def test_estimator_attribute_error(model):
X = [[1], [2]]
y = [0, 1]
model.fit(X, y)

assert not hasattr(model, "estimator_")