Skip to content

[MRG] ENH add support for multiclass-multioutput to ClassifierChain #14654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
8fe643a
fix predict_proba and decision_function
agamemnonc Aug 14, 2019
86b263e
fix tests
agamemnonc Aug 14, 2019
32965fa
Merge branch 'master' into classifier_chain_multiclass
agamemnonc Aug 14, 2019
b4babca
remove multioutput_only tag
agamemnonc Aug 14, 2019
d6b35a7
update whats_new
agamemnonc Aug 14, 2019
cc632f3
Merge branch 'master' into classifier_chain_multiclass
agamemnonc Aug 19, 2019
eb3276a
Merge branch 'master' into classifier_chain_multiclass
agamemnonc Aug 23, 2019
22a856d
merge and resolve conflicts
agamemnonc Nov 15, 2019
dccbf6d
nitpicks
agamemnonc Nov 15, 2019
11f09eb
Merge branch 'master' into classifier_chain_multiclass
agamemnonc Jan 9, 2020
408d595
update docs
agamemnonc Jan 10, 2020
95c4bf8
fix whatsnew
agamemnonc Jan 10, 2020
52a559d
Merge scikit-learn:master
agamemnonc Aug 20, 2020
c271cf4
Merge branch 'master' into classifier_chain_multiclass
agamemnonc Dec 15, 2020
6d8cf2a
Revert "fix whatsnew"
agamemnonc Dec 15, 2020
3b95a3b
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
agamemnonc Jan 30, 2021
0ce1468
merge master
agamemnonc Jan 30, 2021
4750557
checkout whatsnew from master
agamemnonc Jan 30, 2021
0616429
Fix issue
agamemnonc Jan 31, 2021
41e8ffa
Restore changed files
agamemnonc Jan 31, 2021
4d34d08
Update ClassifierChain doc example output
agamemnonc Jan 31, 2021
d3677ad
Nitpick
agamemnonc Jan 31, 2021
bd8493f
Update whatsnew
agamemnonc Jan 31, 2021
d561756
Fix whatsnew formatting issue
agamemnonc Jan 31, 2021
f019810
Update in 0.24 vs. 1.0
agamemnonc Jan 31, 2021
54c3fdd
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
agamemnonc Feb 1, 2021
67f4986
Merge branch 'main' into classifier_chain_multiclass
agamemnonc Feb 1, 2021
e49c043
Move changelog from v0.24.rst to v1.0.rst
agamemnonc Feb 1, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ Changelog
class methods and will be removed in 1.2.
:pr:`18543` by `Guillaume Lemaitre`_.

:mod:`sklearn.multioutput`
..........................

- |Fix| :meth:`multioutput.ClassifierChain.decision_function` and
:meth:`multioutput.ClassifierChain.predict_proba` now both return a list of
``n_outputs`` arrays of shape `(n_samples, n_classes)`. :pr:`14654` by
:user:`Agamemnon Krasoulis <agamemnonc>`.

:mod:`sklearn.naive_bayes`
..........................

Expand Down
39 changes: 24 additions & 15 deletions sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,9 +638,15 @@ class labels for each estimator in the chain.
[1., 0., 0.],
[0., 1., 0.]])
>>> chain.predict_proba(X_test)
array([[0.8387..., 0.9431..., 0.4576...],
[0.8878..., 0.3684..., 0.2640...],
[0.0321..., 0.9935..., 0.0625...]])
[array([[0.16126878, 0.83873122],
[0.11218344, 0.88781656],
[0.96786386, 0.03213614]]),
array([[0.05685769, 0.94314231],
[0.6315953 , 0.3684047 ],
[0.00640331, 0.99359669]]),
array([[0.5423851 , 0.4576149 ],
[0.73590132, 0.26409868],
[0.93742079, 0.06257921]])]

See Also
--------
Expand Down Expand Up @@ -684,22 +690,25 @@ def predict_proba(self, X):

Returns
-------
Y_prob : array-like of shape (n_samples, n_classes)
Y_prob : list of n_outputs ndarray of shape (n_samples, n_classes)
The class probabilities of the input samples. The order of the
classes for each output corresponds to the respective entry of
the attribute `classes_`.
"""
X = check_array(X, accept_sparse=True)
Y_prob_chain = np.zeros((X.shape[0], len(self.estimators_)))
Y_prob_chain = []
Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
for chain_idx, estimator in enumerate(self.estimators_):
previous_predictions = Y_pred_chain[:, :chain_idx]
if sp.issparse(X):
X_aug = sp.hstack((X, previous_predictions))
else:
X_aug = np.hstack((X, previous_predictions))
Y_prob_chain[:, chain_idx] = estimator.predict_proba(X_aug)[:, 1]
Y_prob_chain.append(estimator.predict_proba(X_aug))
Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
inv_order = np.empty_like(self.order_)
inv_order[self.order_] = np.arange(len(self.order_))
Y_prob = Y_prob_chain[:, inv_order]
Y_prob = [Y_prob_chain[i] for i in inv_order]

return Y_prob

Expand All @@ -713,30 +722,30 @@ def decision_function(self, X):

Returns
-------
Y_decision : array-like of shape (n_samples, n_classes)
Returns the decision function of the sample for each model
in the chain.
Y_decision : list of n_outputs ndarray of shape (n_samples, n_classes)
Decision function of the input samples for each model
in the chain. The order of the classes for each output corresponds
to the respective entry of the attribute `classes_`.
"""
Y_decision_chain = np.zeros((X.shape[0], len(self.estimators_)))
Y_decision_chain = []
Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
for chain_idx, estimator in enumerate(self.estimators_):
previous_predictions = Y_pred_chain[:, :chain_idx]
if sp.issparse(X):
X_aug = sp.hstack((X, previous_predictions))
else:
X_aug = np.hstack((X, previous_predictions))
Y_decision_chain[:, chain_idx] = estimator.decision_function(X_aug)
Y_decision_chain.append(estimator.decision_function(X_aug))
Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)

inv_order = np.empty_like(self.order_)
inv_order[self.order_] = np.arange(len(self.order_))
Y_decision = Y_decision_chain[:, inv_order]
Y_decision = [Y_decision_chain[i] for i in inv_order]

return Y_decision

def _more_tags(self):
return {'_skip_test': True,
'multioutput_only': True}
return {'_skip_test': True}


class RegressorChain(MetaEstimatorMixin, RegressorMixin, _BaseChain):
Expand Down
6 changes: 4 additions & 2 deletions sklearn/tests/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,9 @@ def test_classifier_chain_fit_and_predict_with_linear_svc():
assert Y_pred.shape == Y.shape

Y_decision = classifier_chain.decision_function(X)
Y_binary = [Y_decision[i] >= 0 for i in range(Y.shape[1])]
Y_binary = np.asarray(Y_binary).T

Y_binary = (Y_decision >= 0)
assert_array_equal(Y_binary, Y_pred)
assert not hasattr(classifier_chain, 'predict_proba')

Expand Down Expand Up @@ -481,7 +482,8 @@ def test_base_chain_fit_and_predict():
list(range(X.shape[1], X.shape[1] + Y.shape[1])))

Y_prob = chains[1].predict_proba(X)
Y_binary = (Y_prob >= .5)
Y_binary = [np.argmax(Y_prob[i], axis=1) for i in range(Y.shape[1])]
Y_binary = np.asarray(Y_binary).T
assert_array_equal(Y_binary, Y_pred)

assert isinstance(chains[1], ClassifierMixin)
Expand Down