Merged
5 changes: 5 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -66,6 +66,11 @@ Classifiers and regressors
 - Added :class:`multioutput.RegressorChain` for multi-target
   regression. :issue:`9257` by :user:`Kumar Ashutosh <thechargedneutron>`.
 
+- :class:`ensemble.BaggingRegressor` and :class:`ensemble.BaggingClassifier` can now
+  be fit with missing/non-finite values in X and/or multi-output Y to support
+  wrapping pipelines that perform their own imputation.
+  :issue:`9707` by :user:`Jimmy Wan <jimmywan>`.
+
 Preprocessing
 
 - Added :class:`preprocessing.CategoricalEncoder`, which allows to encode
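Note: a minimal sketch of the usage pattern the changelog entry above enables (the data, estimator, and random_state here are illustrative, using the 0.20-era Imputer):

    import numpy as np
    from sklearn.ensemble import BaggingRegressor
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import Imputer
    from sklearn.tree import DecisionTreeRegressor

    # X contains a NaN; the inner pipeline imputes it before the tree is fit,
    # so the bagging wrapper only needs to pass the raw values through.
    X = np.array([[1.0, np.nan], [2.0, 4.0], [3.0, 5.0], [4.0, 6.0]])
    y = np.array([1.0, 2.0, 3.0, 4.0])
    model = BaggingRegressor(make_pipeline(Imputer(), DecisionTreeRegressor()),
                             random_state=0)
    print(model.fit(X, y).predict(X))

Before this change, BaggingRegressor.fit rejected the NaN during its own input validation, before the pipeline's Imputer ever saw the data.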
42 changes: 29 additions & 13 deletions sklearn/ensemble/bagging.py
@@ -8,23 +8,22 @@
 import itertools
 import numbers
 import numpy as np
-from warnings import warn
 from abc import ABCMeta, abstractmethod
+from warnings import warn
 
+from .base import BaseEnsemble, _partition_estimators
 from ..base import ClassifierMixin, RegressorMixin
 from ..externals.joblib import Parallel, delayed
 from ..externals.six import with_metaclass
 from ..externals.six.moves import zip
 from ..metrics import r2_score, accuracy_score
 from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
 from ..utils import check_random_state, check_X_y, check_array, column_or_1d
+from ..utils.random import sample_without_replacement
+from ..utils.validation import has_fit_parameter, check_is_fitted
 from ..utils import indices_to_mask, check_consistent_length
 from ..utils.metaestimators import if_delegate_has_method
 from ..utils.multiclass import check_classification_targets
 
-from .base import BaseEnsemble, _partition_estimators
-from ..utils.random import sample_without_replacement
-from ..utils.validation import has_fit_parameter, check_is_fitted
 
 
 __all__ = ["BaggingClassifier",
@@ -277,8 +276,11 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
"""
random_state = check_random_state(self.random_state)

# Convert data
X, y = check_X_y(X, y, ['csr', 'csc'])
# Convert data (X is required to be 2d and indexable)
X, y = check_X_y(
X, y, ['csr', 'csc'], dtype=None, force_all_finite=False,
multi_output=True
)
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
check_consistent_length(y, sample_weight)
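Note: a standalone sketch (not part of the diff) of what the three new check_X_y arguments change. dtype=None keeps the input's dtype rather than coercing to float (which is what lets the object-dtype X containing None in the tests below through), force_all_finite=False lets NaN/inf pass, and multi_output=True permits a 2-D y:

    import numpy as np
    from sklearn.utils import check_X_y

    X = np.array([[1.0, np.nan], [2.0, np.inf]])
    y = np.array([[1.0, 2.0], [3.0, 4.0]])  # multi-output target

    # The old call, check_X_y(X, y, ['csr', 'csc']), would raise a
    # ValueError on the non-finite values and again on the 2-D y.
    X_out, y_out = check_X_y(X, y, ['csr', 'csc'], dtype=None,
                             force_all_finite=False, multi_output=True)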
@@ -388,8 +390,10 @@ def _set_oob_score(self, X, y):
"""Calculate out of bag predictions and score."""

def _validate_y(self, y):
# Default implementation
return column_or_1d(y, warn=True)
if len(y.shape) == 1 or y.shape[1] == 1:
return column_or_1d(y, warn=True)
else:
return y

def _get_estimators_indices(self):
# Get drawn indices along both sample and feature axes
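Note: the new _validate_y keeps the historical behaviour for 1-D and single-column targets and passes genuinely multi-output y through unchanged. A small standalone illustration of the column_or_1d behaviour it branches on:

    import numpy as np
    from sklearn.utils import column_or_1d

    column_or_1d(np.array([1, 2, 3])).shape                   # (3,): already 1-D
    column_or_1d(np.array([[1], [2], [3]]), warn=True).shape  # (3,): column raveled
    # column_or_1d raises ValueError for a (n_samples, n_outputs) array with
    # n_outputs > 1, which is why such a y is now returned as-is.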
@@ -667,7 +671,10 @@ def predict_proba(self, X):
"""
check_is_fitted(self, "classes_")
# Check data
X = check_array(X, accept_sparse=['csr', 'csc'])
X = check_array(
X, accept_sparse=['csr', 'csc'], dtype=None,
force_all_finite=False
)

if self.n_features_ != X.shape[1]:
raise ValueError("Number of features of the model must "
@@ -714,7 +721,10 @@ def predict_log_proba(self, X):
check_is_fitted(self, "classes_")
if hasattr(self.base_estimator_, "predict_log_proba"):
# Check data
X = check_array(X, accept_sparse=['csr', 'csc'])
X = check_array(
X, accept_sparse=['csr', 'csc'], dtype=None,
force_all_finite=False
)

if self.n_features_ != X.shape[1]:
raise ValueError("Number of features of the model must "
@@ -769,7 +779,10 @@ def decision_function(self, X):
check_is_fitted(self, "classes_")

# Check data
X = check_array(X, accept_sparse=['csr', 'csc'])
X = check_array(
X, accept_sparse=['csr', 'csc'], dtype=None,
force_all_finite=False
)

if self.n_features_ != X.shape[1]:
raise ValueError("Number of features of the model must "
@@ -947,7 +960,10 @@ def predict(self, X):
"""
check_is_fitted(self, "estimators_features_")
# Check data
X = check_array(X, accept_sparse=['csr', 'csc'])
X = check_array(
X, accept_sparse=['csr', 'csc'], dtype=None,
force_all_finite=False
)

# Parallel loop
n_jobs, n_estimators, starts = _partition_estimators(self.n_estimators,
74 changes: 74 additions & 0 deletions sklearn/ensemble/tests/test_bagging.py
@@ -33,6 +33,7 @@
 from sklearn.model_selection import train_test_split
 from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
 from sklearn.utils import check_random_state
+from sklearn.preprocessing import Imputer
 
 from scipy.sparse import csc_matrix, csr_matrix
 
@@ -752,3 +753,76 @@ def test_set_oob_score_label_encoding():
     x3 = BaggingClassifier(oob_score=True,
                            random_state=random_state).fit(X, Y3).oob_score_
     assert_equal([x1, x2], [x3, x3])
+
+
+def test_bagging_regressor_with_missing_inputs():
+    # Check that BaggingRegressor can accept X with missing/infinite data
+    X = np.array([
+        [1, 3, 5],
+        [2, None, 6],
+        [2, np.nan, 6],
+        [2, np.inf, 6],
+        [2, np.NINF, 6],
+    ])
+    y_values = [
+        np.array([2, 3, 3, 3, 3]),
+        np.array([
+            [2, 1, 9],
+            [3, 6, 8],
+            [3, 6, 8],
+            [3, 6, 8],
+            [3, 6, 8],
+        ])
+    ]
+    for y in y_values:
+        regressor = DecisionTreeRegressor()
+        pipeline = make_pipeline(
+            Imputer(),
+            Imputer(missing_values=np.inf),
+            Imputer(missing_values=np.NINF),
+            regressor
+        )
+        pipeline.fit(X, y).predict(X)
+        bagging_regressor = BaggingRegressor(pipeline)
+        y_hat = bagging_regressor.fit(X, y).predict(X)
+        assert_equal(y.shape, y_hat.shape)
+
+    # Verify that exceptions can be raised by wrapper regressor
+    regressor = DecisionTreeRegressor()
+    pipeline = make_pipeline(regressor)
+    assert_raises(ValueError, pipeline.fit, X, y)
+    bagging_regressor = BaggingRegressor(pipeline)
+    assert_raises(ValueError, bagging_regressor.fit, X, y)
+
+
+def test_bagging_classifier_with_missing_inputs():
+    # Check that BaggingClassifier can accept X with missing/infinite data
+    X = np.array([
+        [1, 3, 5],
+        [2, None, 6],
+        [2, np.nan, 6],
+        [2, np.inf, 6],
+        [2, np.NINF, 6],
+    ])
+    y = np.array([3, 6, 6, 6, 6])
+    classifier = DecisionTreeClassifier()
+    pipeline = make_pipeline(
+        Imputer(),
+        Imputer(missing_values=np.inf),
+        Imputer(missing_values=np.NINF),
+        classifier
+    )
+    pipeline.fit(X, y).predict(X)
+    bagging_classifier = BaggingClassifier(pipeline)
+    bagging_classifier.fit(X, y)
+    y_hat = bagging_classifier.predict(X)
+    assert_equal(y.shape, y_hat.shape)
+    bagging_classifier.predict_log_proba(X)
+    bagging_classifier.predict_proba(X)
+
+    # Verify that exceptions can be raised by wrapper classifier
+    classifier = DecisionTreeClassifier()
+    pipeline = make_pipeline(classifier)
+    assert_raises(ValueError, pipeline.fit, X, y)
+    bagging_classifier = BaggingClassifier(pipeline)
+    assert_raises(ValueError, bagging_classifier.fit, X, y)