Skip to content

Commit e3acfa4

Browse files
committed
ENH: multi-output support for BaggingRegressor
1 parent b185c4e commit e3acfa4

File tree

4 files changed

+32
-6
lines changed

4 files changed

+32
-6
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,9 @@ Trees and ensembles
823823
:class:`ensemble.VotingClassifier` to fit underlying estimators in parallel.
824824
:issue:`5805` by :user:`Ibraim Ganiev <olologin>`.
825825

826+
- :class:`ensemble.BaggingRegressor` now supports multi-output targets.
827+
By :user:`Elvis Dohmatob <dohmatob>`.
828+
826829
Linear, kernelized and related models
827830

828831
- In :class:`linear_model.LogisticRegression`, the SAG solver is now

sklearn/ensemble/bagging.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
280280
random_state = check_random_state(self.random_state)
281281

282282
# Convert data
283-
X, y = check_X_y(X, y, ['csr', 'csc'])
283+
X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True)
284284
if sample_weight is not None:
285285
sample_weight = check_array(sample_weight, ensure_2d=False)
286286
check_consistent_length(y, sample_weight)
@@ -390,8 +390,9 @@ def _set_oob_score(self, X, y):
390390
"""Calculate out of bag predictions and score."""
391391

392392
def _validate_y(self, y):
393-
# Default implementation
394-
return column_or_1d(y, warn=True)
393+
# Default implementation. We skip column_or_1d and similar checks
394+
# in order to make the code support multi-output targets.
395+
return y
395396

396397
def _get_estimators_indices(self):
397398
# Get drawn indices along both sample and feature axes

sklearn/ensemble/tests/test_bagging.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99

10-
from sklearn.base import BaseEstimator
10+
from sklearn.base import BaseEstimator, clone
1111

1212
from sklearn.utils.testing import assert_array_equal
1313
from sklearn.utils.testing import assert_array_almost_equal
@@ -23,7 +23,7 @@
2323
from sklearn.dummy import DummyClassifier, DummyRegressor
2424
from sklearn.model_selection import GridSearchCV, ParameterGrid
2525
from sklearn.ensemble import BaggingClassifier, BaggingRegressor
26-
from sklearn.linear_model import Perceptron, LogisticRegression
26+
from sklearn.linear_model import Perceptron, LogisticRegression, Ridge
2727
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
2828
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2929
from sklearn.svm import SVC, SVR
@@ -740,3 +740,24 @@ def test_set_oob_score_label_encoding():
740740
x3 = BaggingClassifier(oob_score=True,
741741
random_state=random_state).fit(X, Y3).oob_score_
742742
assert_equal([x1, x2], [x3, x3])
743+
744+
745+
def test_multi_output_regressor():
746+
# Check singleton ensembles.
747+
rng = check_random_state(0)
748+
X_train, X_test, y_train, y_test = train_test_split(boston.data,
749+
boston.target,
750+
random_state=rng)
751+
752+
reg1 = BaggingRegressor(base_estimator=Ridge(), n_estimators=10,
753+
bootstrap=False, bootstrap_features=False,
754+
random_state=rng)
755+
assert_false(reg1.multi_output)
756+
reg2 = clone(reg1)
757+
reg1.fit(X_train, y_train)
758+
for n_targets in [1, 2]:
759+
y_train_ = np.ndarray((len(y_train), n_targets))
760+
y_train_.T[:] = y_train.copy()
761+
reg2.fit(X_train, y_train_)
762+
assert_array_almost_equal(reg1.predict(X_test),
763+
reg2.predict(X_test)[:, 0], decimal=10)

sklearn/utils/estimator_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@
6464
'MultiTaskElasticNetCV', 'MultiTaskLasso', 'MultiTaskLassoCV',
6565
'OrthogonalMatchingPursuit', 'PLSCanonical', 'PLSRegression',
6666
'RANSACRegressor', 'RadiusNeighborsRegressor',
67-
'RandomForestRegressor', 'Ridge', 'RidgeCV']
67+
'RandomForestRegressor', 'Ridge', 'RidgeCV',
68+
"BaggingRegressor"]
6869

6970

7071
def _yield_non_meta_checks(name, Estimator):

0 commit comments

Comments
 (0)