-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] ENH: multi-output support for BaggingRegressor #8547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,7 +280,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None): | |
random_state = check_random_state(self.random_state) | ||
|
||
# Convert data | ||
X, y = check_X_y(X, y, ['csr', 'csc']) | ||
X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True) | ||
if sample_weight is not None: | ||
sample_weight = check_array(sample_weight, ensure_2d=False) | ||
check_consistent_length(y, sample_weight) | ||
|
@@ -390,8 +390,9 @@ def _set_oob_score(self, X, y): | |
"""Calculate out of bag predictions and score.""" | ||
|
||
def _validate_y(self, y): | ||
# Default implementation | ||
return column_or_1d(y, warn=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So there should still be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, this might be right, since the validation is done by the downstream estimator. |
||
# Default implementation. We skip column_or_1d and similar checks | ||
# in order to make the code support multi-output targets. | ||
return y | ||
|
||
def _get_estimators_indices(self): | ||
# Get drawn indices along both sample and feature axes | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
|
||
import numpy as np | ||
|
||
from sklearn.base import BaseEstimator | ||
from sklearn.base import BaseEstimator, clone | ||
|
||
from sklearn.utils.testing import assert_array_equal | ||
from sklearn.utils.testing import assert_array_almost_equal | ||
|
@@ -23,7 +23,7 @@ | |
from sklearn.dummy import DummyClassifier, DummyRegressor | ||
from sklearn.model_selection import GridSearchCV, ParameterGrid | ||
from sklearn.ensemble import BaggingClassifier, BaggingRegressor | ||
from sklearn.linear_model import Perceptron, LogisticRegression | ||
from sklearn.linear_model import Perceptron, LogisticRegression, Ridge | ||
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor | ||
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor | ||
from sklearn.svm import SVC, SVR | ||
|
@@ -740,3 +740,23 @@ def test_set_oob_score_label_encoding(): | |
x3 = BaggingClassifier(oob_score=True, | ||
random_state=random_state).fit(X, Y3).oob_score_ | ||
assert_equal([x1, x2], [x3, x3]) | ||
|
||
|
||
def test_multi_output_regressor(): | ||
# Check singleton ensembles. | ||
rng = check_random_state(0) | ||
X_train, X_test, y_train, y_test = train_test_split(boston.data, | ||
boston.target, | ||
random_state=rng) | ||
|
||
reg1 = BaggingRegressor(base_estimator=Ridge(), n_estimators=10, | ||
bootstrap=False, bootstrap_features=False, | ||
random_state=rng) | ||
reg2 = clone(reg1) | ||
reg1.fit(X_train, y_train) | ||
for n_targets in [1, 2]: | ||
y_train_ = np.ndarray((len(y_train), n_targets)) | ||
y_train_.T[:] = y_train.copy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm finding this hard to understand. |
||
reg2.fit(X_train, y_train_) | ||
assert_array_almost_equal(reg1.predict(X_test), | ||
reg2.predict(X_test)[:, 0], decimal=10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to move to the v0.19 section.