Merged
4 changes: 4 additions & 0 deletions doc/whats_new.rst
@@ -44,6 +44,10 @@ New features
Enhancements
............

- :class:`multioutput.MultiOutputRegressor` and :class:`multioutput.MultiOutputClassifier`
  now support online learning using `partial_fit`.
  :issue:`8053` by :user:`Peng Yu <yupbank>`.

- :class:`decomposition.PCA`, :class:`decomposition.IncrementalPCA` and
:class:`decomposition.TruncatedSVD` now expose the singular values
from the underlying SVD. They are stored in the attribute
116 changes: 110 additions & 6 deletions sklearn/multioutput.py
@@ -22,6 +22,7 @@
from .utils import check_array, check_X_y
from .utils.fixes import parallel_helper
from .utils.validation import check_is_fitted, has_fit_parameter
from .utils.metaestimators import if_delegate_has_method
from .externals.joblib import Parallel, delayed
from .externals import six

@@ -37,12 +38,86 @@ def _fit_estimator(estimator, X, y, sample_weight=None):
return estimator


def _partial_fit_estimator(estimator, X, y, classes=None, sample_weight=None,
first_time=True):
if first_time:
estimator = clone(estimator)

if sample_weight is not None:
Contributor: I was wondering if

    estimator.partial_fit(X, y, classes=classes, sample_weight=sample_weight)

would suffice, since the partial_fit of the base estimator would handle the arguments appropriately. Let me know if that sounds okay!

Contributor (author): Since the regressors and classifiers share the same interface, the choice of whether to pass the classes argument is not negligible; for sample_weight I need to check whether it is supported universally across the SGD estimators.

Contributor (author): fit in BaseSGD only takes X and y, so I think I will keep this branching to make sure all the SGD estimators work with this code.

Contributor: Hi, thanks for looking into it. I understand that the classes argument is not applicable to regressors, but if someone inadvertently passed classes to a regressor, the estimator would throw an error, right? Perhaps we could have separate partial_fit functions in the subclasses MultiOutputRegressor and MultiOutputClassifier? Just my 2c.

Also, BaseSGD has an abstract definition of fit, which is implemented in the subclasses BaseSGDClassifier and BaseSGDRegressor, right?

Please let me know what you think. Thanks.

Contributor: Hi, sorry, I missed the partial_fit function you added to MultiOutputRegressor. And since this is a helper function, I see the reason for separate code paths based on classes. Sorry for the noise.

Member: classes is required for classifiers' partial_fit, but I think we have no requirement that sample_weight be supported.

if classes is not None:
estimator.partial_fit(X, y, classes=classes,
sample_weight=sample_weight)
else:
estimator.partial_fit(X, y, sample_weight=sample_weight)
else:
if classes is not None:
estimator.partial_fit(X, y, classes=classes)
else:
estimator.partial_fit(X, y)
return estimator
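
A minimal sketch (not part of the diff; it assumes scikit-learn's standard SGD estimators) of why the helper branches instead of always forwarding both keywords: a regressor's partial_fit rejects the classes keyword outright, so it can only be passed when it is actually set.

    import numpy as np
    from sklearn.linear_model import SGDRegressor

    X = np.array([[0.], [1.]])
    y = np.array([0., 1.])
    SGDRegressor().partial_fit(X, y)                 # fine
    SGDRegressor().partial_fit(X, y, classes=None)   # TypeError: unexpected
                                                     # keyword argument 'classes'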


class MultiOutputEstimator(six.with_metaclass(ABCMeta, BaseEstimator)):

def __init__(self, estimator, n_jobs=1):
self.estimator = estimator
self.n_jobs = n_jobs

@if_delegate_has_method('estimator')
Contributor: Hi, while you are at it, can you please add the @if_delegate_has_method decorator to the fit and predict functions too? Thanks.

Contributor (author): Sure thing.

Member: No, we do not need if_delegate_has_method for fit; it is required for every estimator. And in this context we are dealing with predictors, so we do not need if_delegate_has_method for predict either.

We also do not need the scope of the PR to be expanded unnecessarily. Throwing in a decorator means that the decoration needs testing, for instance.

Sorry for the unnecessary work, @yupbank
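
A hedged illustration (not part of the diff) of what the decorator provides: the delegated partial_fit is only exposed on the meta-estimator when the wrapped estimator itself has one, so hasattr probing works as expected.

    from sklearn.multioutput import MultiOutputClassifier
    from sklearn.linear_model import SGDClassifier
    from sklearn.svm import LinearSVC

    hasattr(MultiOutputClassifier(SGDClassifier()), 'partial_fit')  # True
    hasattr(MultiOutputClassifier(LinearSVC()), 'partial_fit')      # False:
    # LinearSVC has no partial_fit, so the delegated method is hidden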

def partial_fit(self, X, y, classes=None, sample_weight=None):
"""Incrementally fit the model to data.
Fit a separate model for each output variable.

Parameters
----------
X : (sparse) array-like, shape (n_samples, n_features)
Data.

y : (sparse) array-like, shape (n_samples, n_outputs)
Multi-output targets.

classes : list of numpy arrays, shape (n_outputs)
            Each array holds the unique classes for one output, as str or int.
            Can be obtained via
``[np.unique(y[:, i]) for i in range(y.shape[1])]``, where y is the
target matrix of the entire dataset.
This argument is required for the first call to partial_fit
and can be omitted in the subsequent calls.
Note that y doesn't need to contain all labels in `classes`.

sample_weight : array-like, shape = (n_samples) or None
Sample weights. If None, then samples are equally weighted.
            Only supported if the underlying estimator supports sample
weights.

Returns
-------
self : object
Returns self.
"""
X, y = check_X_y(X, y,
multi_output=True,
accept_sparse=True)

if y.ndim == 1:
raise ValueError("y must have at least two dimensions for "
"multi-output regression but has only one.")

if (sample_weight is not None and
not has_fit_parameter(self.estimator, 'sample_weight')):
raise ValueError("Underlying estimator does not support"
" sample weights.")

first_time = not hasattr(self, 'estimators_')

self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_partial_fit_estimator)(
self.estimators_[i] if not first_time else self.estimator,
X, y[:, i],
classes[i] if classes is not None else None,
sample_weight, first_time) for i in range(y.shape[1]))
return self
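
A short usage sketch of the flow documented above (hedged; synthetic data, not part of the diff): ``classes`` is computed over the full target matrix and is only required on the first call.

    import numpy as np
    from sklearn.linear_model import SGDClassifier
    from sklearn.multioutput import MultiOutputClassifier

    rng = np.random.RandomState(0)
    X = rng.randn(100, 4)
    y = np.column_stack([rng.randint(0, 3, 100), rng.randint(0, 2, 100)])
    classes = [np.unique(y[:, i]) for i in range(y.shape[1])]

    clf = MultiOutputClassifier(SGDClassifier(random_state=0))
    clf.partial_fit(X[:50], y[:50], classes=classes)  # first call needs classes
    clf.partial_fit(X[50:], y[50:])                   # later calls may omit it
    clf.predict(X).shape                              # (100, 2)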

def fit(self, X, y, sample_weight=None):
""" Fit the model to data.
Fit a separate model for each output variable.
@@ -76,15 +151,17 @@ def fit(self, X, y, sample_weight=None):

if y.ndim == 1:
raise ValueError("y must have at least two dimensions for "
"multi target regression but has only one.")
"multi-output regression but has only one.")

if (sample_weight is not None and
not has_fit_parameter(self.estimator, 'sample_weight')):
raise ValueError("Underlying regressor does not support"
raise ValueError("Underlying estimator does not support"
Member: Thanks for the catch!

" sample weights.")

self.estimators_ = Parallel(n_jobs=self.n_jobs)(delayed(_fit_estimator)(
self.estimator, X, y[:, i], sample_weight) for i in range(y.shape[1]))
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_estimator)(
self.estimator, X, y[:, i], sample_weight)
for i in range(y.shape[1]))
return self

def predict(self, X):
@@ -108,8 +185,9 @@ def predict(self, X):

X = check_array(X, accept_sparse=True)

y = Parallel(n_jobs=self.n_jobs)(delayed(parallel_helper)(e, 'predict', X)
for e in self.estimators_)
y = Parallel(n_jobs=self.n_jobs)(
delayed(parallel_helper)(e, 'predict', X)
for e in self.estimators_)

return np.asarray(y).T

@@ -133,9 +211,35 @@ class MultiOutputRegressor(MultiOutputEstimator, RegressorMixin):
using `n_jobs>1` can result in slower performance due
to the overhead of spawning processes.
"""

def __init__(self, estimator, n_jobs=1):
super(MultiOutputRegressor, self).__init__(estimator, n_jobs)

def partial_fit(self, X, y, sample_weight=None):
"""Incrementally fit the model to data.
Fit a separate model for each output variable.

Parameters
----------
X : (sparse) array-like, shape (n_samples, n_features)
Data.

y : (sparse) array-like, shape (n_samples, n_outputs)
Multi-output targets.

sample_weight : array-like, shape = (n_samples) or None
Sample weights. If None, then samples are equally weighted.
Only supported if the underlying regressor supports sample
weights.

Returns
-------
self : object
Returns self.
"""
        return super(MultiOutputRegressor, self).partial_fit(
            X, y, sample_weight=sample_weight)
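
A matching regression sketch (again hedged; synthetic data): no classes argument is involved, and mini-batches are simply fed in sequence.

    import numpy as np
    from sklearn.linear_model import SGDRegressor
    from sklearn.multioutput import MultiOutputRegressor

    rng = np.random.RandomState(0)
    X = rng.randn(200, 5)
    y = np.dot(X, rng.randn(5, 2))              # two linear targets

    reg = MultiOutputRegressor(SGDRegressor(random_state=0))
    for start in range(0, 200, 50):             # four mini-batches of 50
        reg.partial_fit(X[start:start + 50], y[start:start + 50])
    reg.predict(X[:3]).shape                    # (3, 2)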

def score(self, X, y, sample_weight=None):
"""Returns the coefficient of determination R^2 of the prediction.

137 changes: 126 additions & 11 deletions sklearn/tests/test_multioutput.py
@@ -1,16 +1,23 @@
from __future__ import division
import numpy as np
import scipy.sparse as sp
from sklearn.utils import shuffle
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_false
from sklearn.utils.testing import assert_raises_regex
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.exceptions import NotFittedError
from sklearn import datasets
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingRegressor, RandomForestClassifier
from sklearn.linear_model import Lasso, LogisticRegression
from sklearn.linear_model import Lasso
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import SGDRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
@@ -25,7 +32,7 @@ def test_multi_target_regression():
for n in range(3):
rgr = GradientBoostingRegressor(random_state=0)
rgr.fit(X_train, y_train[:, n])
references[:,n] = rgr.predict(X_test)
references[:, n] = rgr.predict(X_test)

rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
rgr.fit(X_train, y_train)
@@ -34,20 +41,40 @@
assert_almost_equal(references, y_pred)


def test_multi_target_regression_partial_fit():
X, y = datasets.make_regression(n_targets=3)
X_train, y_train = X[:50], y[:50]
X_test, y_test = X[50:], y[50:]

references = np.zeros_like(y_test)
half_index = 25
for n in range(3):
sgr = SGDRegressor(random_state=0)
sgr.partial_fit(X_train[:half_index], y_train[:half_index, n])
Member: You also need to test sample_weight.

sgr.partial_fit(X_train[half_index:], y_train[half_index:, n])
references[:, n] = sgr.predict(X_test)

sgr = MultiOutputRegressor(SGDRegressor(random_state=0))

sgr.partial_fit(X_train[:half_index], y_train[:half_index])
sgr.partial_fit(X_train[half_index:], y_train[half_index:])

y_pred = sgr.predict(X_test)
assert_almost_equal(references, y_pred)


def test_multi_target_regression_one_target():
    # Test that multi-output regression raises when y has only one target
X, y = datasets.make_regression(n_targets=1)
X_train, y_train = X[:50], y[:50]
X_test, y_test = X[50:], y[50:]

rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
assert_raises(ValueError, rgr.fit, X_train, y_train)
assert_raises(ValueError, rgr.fit, X, y)


def test_multi_target_sparse_regression():
X, y = datasets.make_regression(n_targets=3)
X_train, y_train = X[:50], y[:50]
X_test, y_test = X[50:], y[50:]
X_test = X[50:]

for sparse in [sp.csr_matrix, sp.csc_matrix, sp.coo_matrix, sp.dok_matrix,
sp.lil_matrix]:
@@ -57,11 +84,12 @@ def test_multi_target_sparse_regression():
rgr.fit(X_train, y_train)
rgr_sparse.fit(sparse(X_train), y_train)

assert_almost_equal(rgr.predict(X_test), rgr_sparse.predict(sparse(X_test)))
assert_almost_equal(rgr.predict(X_test),
rgr_sparse.predict(sparse(X_test)))


def test_multi_target_sample_weights_api():
X = [[1,2,3], [4,5,6]]
X = [[1, 2, 3], [4, 5, 6]]
y = [[3.141, 2.718], [2.718, 3.141]]
w = [0.8, 0.6]

@@ -74,23 +102,40 @@ def test_multi_target_sample_weights_api():
rgr.fit(X, y, w)


def test_multi_target_sample_weight_partial_fit():
# weighted regressor
X = [[1, 2, 3], [4, 5, 6]]
y = [[3.141, 2.718], [2.718, 3.141]]
w = [2., 1.]
rgr_w = MultiOutputRegressor(SGDRegressor(random_state=0))
rgr_w.partial_fit(X, y, w)

# weighted with different weights
w = [2., 2.]
rgr = MultiOutputRegressor(SGDRegressor(random_state=0))
rgr.partial_fit(X, y, w)

assert_not_equal(rgr.predict(X)[0][0], rgr_w.predict(X)[0][0])


def test_multi_target_sample_weights():
# weighted regressor
Xw = [[1,2,3], [4,5,6]]
Xw = [[1, 2, 3], [4, 5, 6]]
yw = [[3.141, 2.718], [2.718, 3.141]]
w = [2., 1.]
rgr_w = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
rgr_w.fit(Xw, yw, w)

# unweighted, but with repeated samples
X = [[1,2,3], [1,2,3], [4,5,6]]
X = [[1, 2, 3], [1, 2, 3], [4, 5, 6]]
y = [[3.141, 2.718], [3.141, 2.718], [2.718, 3.141]]
rgr = MultiOutputRegressor(GradientBoostingRegressor(random_state=0))
rgr.fit(X, y)

X_test = [[1.5,2.5,3.5], [3.5,4.5,5.5]]
X_test = [[1.5, 2.5, 3.5], [3.5, 4.5, 5.5]]
assert_almost_equal(rgr.predict(X_test), rgr_w.predict(X_test))


# Import the data
iris = datasets.load_iris()
# create a multiple targets by randomized shuffling and concatenating y.
@@ -102,6 +147,57 @@ def test_multi_target_sample_weights():
n_samples, n_features = X.shape
n_outputs = y.shape[1]
n_classes = len(np.unique(y1))
classes = list(map(np.unique, (y1, y2, y3)))


def test_multi_output_classification_partial_fit_parallelism():
sgd_linear_clf = SGDClassifier(loss='log', random_state=1)
mor = MultiOutputClassifier(sgd_linear_clf, n_jobs=-1)
mor.partial_fit(X, y, classes)
est1 = mor.estimators_[0]
mor.partial_fit(X, y)
est2 = mor.estimators_[0]
    # a sane parallel implementation must rebind the fitted estimators
    # returned by the workers rather than mutate the originals in place
assert_false(est1 is est2)


def test_multi_output_classification_partial_fit():
# test if multi_target initializes correctly with base estimator and fit
# assert predictions work as expected for predict

sgd_linear_clf = SGDClassifier(loss='log', random_state=1)
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)

# train the multi_target_linear and also get the predictions.
half_index = X.shape[0] // 2
multi_target_linear.partial_fit(
X[:half_index], y[:half_index], classes=classes)
Member: Should also check:

- sample weight
- passing classes=None raises an appropriate error


first_predictions = multi_target_linear.predict(X)
assert_equal((n_samples, n_outputs), first_predictions.shape)

multi_target_linear.partial_fit(X[half_index:], y[half_index:])
second_predictions = multi_target_linear.predict(X)
assert_equal((n_samples, n_outputs), second_predictions.shape)

# train the linear classification with each column and assert that
# predictions are equal after first partial_fit and second partial_fit
for i in range(3):
# create a clone with the same state
sgd_linear_clf = clone(sgd_linear_clf)
sgd_linear_clf.partial_fit(
X[:half_index], y[:half_index, i], classes=classes[i])
assert_array_equal(sgd_linear_clf.predict(X), first_predictions[:, i])
sgd_linear_clf.partial_fit(X[half_index:], y[half_index:, i])
assert_array_equal(sgd_linear_clf.predict(X), second_predictions[:, i])


def test_multi_output_classification_partial_fit_no_first_classes_exception():
sgd_linear_clf = SGDClassifier(loss='log', random_state=1)
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
assert_raises_regex(ValueError, "classes must be passed on the first call "
"to partial_fit.",
multi_target_linear.partial_fit, X, y)


def test_multi_output_classification():
@@ -209,6 +305,25 @@ def test_multi_output_classification_sample_weights():
assert_almost_equal(clf.predict(X_test), clf_w.predict(X_test))


def test_multi_output_classification_partial_fit_sample_weights():
# weighted classifier
Xw = [[1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
yw = [[3, 2], [2, 3], [3, 2]]
w = np.asarray([2., 1., 1.])
sgd_linear_clf = SGDClassifier(random_state=1)
clf_w = MultiOutputClassifier(sgd_linear_clf)
clf_w.fit(Xw, yw, w)

# unweighted, but with repeated samples
X = [[1, 2, 3], [1, 2, 3], [4, 5, 6], [1.5, 2.5, 3.5]]
y = [[3, 2], [3, 2], [2, 3], [3, 2]]
sgd_linear_clf = SGDClassifier(random_state=1)
clf = MultiOutputClassifier(sgd_linear_clf)
clf.fit(X, y)
X_test = [[1.5, 2.5, 3.5]]
assert_array_almost_equal(clf.predict(X_test), clf_w.predict(X_test))


def test_multi_output_exceptions():
    # NotFittedError when fit is not done but score, predict and
    # predict_proba are called