
ENH Add chain_method to ClassifierChain #27700


Merged: 20 commits, Feb 23, 2024
6 changes: 6 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -158,6 +158,12 @@ Changelog
- |Enhancement| :term:`CV splitters <CV splitter>` that ignore the group parameter now
raise a warning when groups are passed in to :term:`split`. :pr:`28210` by

:mod:`sklearn.multioutput`
..........................

- |Enhancement| `chain_method` parameter added to :class:`multioutput.ClassifierChain`.
:pr:`27700` by :user:`Lucy Liu <lucyleeow>`.

:mod:`sklearn.pipeline`
.......................

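As a quick illustration of the changelog entry above, a minimal usage sketch of the new parameter. Only `ClassifierChain`, `chain_method`, and the fitted `chain_method_` attribute come from this PR; the dataset and base estimator are illustrative.

```python
from sklearn.datasets import make_multilabel_classification
from sklearn.multioutput import ClassifierChain
from sklearn.svm import LinearSVC

X, Y = make_multilabel_classification(n_samples=100, n_classes=3, random_state=0)

# LinearSVC has no predict_proba, so use decision_function values (margins)
# as the chain features instead of hard 0/1 predictions.
chain = ClassifierChain(LinearSVC(dual="auto"), chain_method="decision_function")
chain.fit(X, Y)

print(chain.chain_method_)     # "decision_function"
print(chain.predict(X).shape)  # (100, 3)
```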
164 changes: 110 additions & 54 deletions sklearn/multioutput.py
@@ -33,6 +33,7 @@
from .model_selection import cross_val_predict
from .utils import Bunch, _print_elapsed_time, check_random_state
from .utils._param_validation import HasMethods, StrOptions
from .utils._response import _get_response_values
from .utils.metadata_routing import (
MetadataRouter,
MethodMapping,
@@ -43,7 +44,12 @@
from .utils.metaestimators import available_if
from .utils.multiclass import check_classification_targets
from .utils.parallel import Parallel, delayed
from .utils.validation import _check_method_params, check_is_fitted, has_fit_parameter
from .utils.validation import (
_check_method_params,
_check_response_method,
check_is_fitted,
has_fit_parameter,
)

__all__ = [
"MultiOutputRegressor",
@@ -650,6 +656,41 @@ def _log_message(self, *, estimator_idx, n_estimators, processing_msg):
return None
return f"({estimator_idx} of {n_estimators}) {processing_msg}"

def _get_predictions(self, X, *, output_method):
"""Get predictions for each model in the chain."""
check_is_fitted(self)
X = self._validate_data(X, accept_sparse=True, reset=False)
Y_output_chain = np.zeros((X.shape[0], len(self.estimators_)))
Y_feature_chain = np.zeros((X.shape[0], len(self.estimators_)))

# `RegressorChain` does not have a `chain_method_` attribute, so we
# default to "predict"
chain_method = getattr(self, "chain_method_", "predict")
hstack = sp.hstack if sp.issparse(X) else np.hstack
for chain_idx, estimator in enumerate(self.estimators_):
previous_predictions = Y_feature_chain[:, :chain_idx]
X_aug = hstack((X, previous_predictions))

feature_predictions, _ = _get_response_values(
estimator,
X_aug,
response_method=chain_method,
)
Y_feature_chain[:, chain_idx] = feature_predictions

output_predictions, _ = _get_response_values(
estimator,
X_aug,
response_method=output_method,
)
Y_output_chain[:, chain_idx] = output_predictions

inv_order = np.empty_like(self.order_)
inv_order[self.order_] = np.arange(len(self.order_))
Y_output = Y_output_chain[:, inv_order]

return Y_output
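The final reordering above maps the chain-order columns back to the original label order. A small self-contained numpy sketch of that inverse-permutation trick (the values are illustrative only):

```python
import numpy as np

order = np.array([2, 0, 1])            # chain column i predicts label order[i]
Y_chain = np.array([[0.9, 0.1, 0.4]])  # one row of predictions, in chain order

inv_order = np.empty_like(order)
inv_order[order] = np.arange(len(order))

print(Y_chain[:, inv_order])  # [[0.1 0.4 0.9]] -- back in original label order
```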

@abstractmethod
def fit(self, X, Y, **fit_params):
"""Fit the model to data matrix X and targets Y.
@@ -712,6 +753,16 @@ def fit(self, X, Y, **fit_params):
else:
routed_params = Bunch(estimator=Bunch(fit=fit_params))

if hasattr(self, "chain_method"):
chain_method = _check_response_method(
self.base_estimator,
self.chain_method,
).__name__
self.chain_method_ = chain_method
else:
# `RegressorChain` does not have a `chain_method` parameter
chain_method = "predict"

for chain_idx, estimator in enumerate(self.estimators_):
message = self._log_message(
estimator_idx=chain_idx + 1,
@@ -729,8 +780,15 @@
if self.cv is not None and chain_idx < len(self.estimators_) - 1:
col_idx = X.shape[1] + chain_idx
cv_result = cross_val_predict(
self.base_estimator, X_aug[:, :col_idx], y=y, cv=self.cv
self.base_estimator,
X_aug[:, :col_idx],
y=y,
cv=self.cv,
method=chain_method,
)
# `predict_proba` output is 2D; we keep only the column for classes[-1]
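# (`predict` and, for binary labels, `decision_function` give 1D output,
# so only the 2D probability-style outputs are sliced here)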
if cv_result.ndim > 1:
cv_result = cv_result[:, 1]
if sp.issparse(X_aug):
X_aug[:, col_idx] = np.expand_dims(cv_result, 1)
else:
@@ -751,25 +809,7 @@ def predict(self, X):
Y_pred : array-like of shape (n_samples, n_classes)
The predicted values.
"""
check_is_fitted(self)
X = self._validate_data(X, accept_sparse=True, reset=False)
Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
for chain_idx, estimator in enumerate(self.estimators_):
previous_predictions = Y_pred_chain[:, :chain_idx]
if sp.issparse(X):
if chain_idx == 0:
X_aug = X
else:
X_aug = sp.hstack((X, previous_predictions))
else:
X_aug = np.hstack((X, previous_predictions))
Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)

inv_order = np.empty_like(self.order_)
inv_order[self.order_] = np.arange(len(self.order_))
Y_pred = Y_pred_chain[:, inv_order]

return Y_pred
return self._get_predictions(X, output_method="predict")


class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain):
@@ -820,6 +860,19 @@ class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain):
- :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.

chain_method : {'predict', 'predict_proba', 'predict_log_proba', \
'decision_function'} or list of such str's, default='predict'

Prediction method to be used by estimators in the chain for
the 'prediction' features of previous estimators in the chain.

- if `str`, name of the method;
- if a list of `str`, provides the method names in order of
preference. The method used corresponds to the first method in
the list that is implemented by `base_estimator`.

.. versionadded:: 1.5

random_state : int, RandomState instance or None, optional (default=None)
If ``order='random'``, determines random number generation for the
chain order.
@@ -846,6 +899,10 @@ class labels for each estimator in the chain.
order_ : list
The order of labels in the classifier chain.

chain_method_ : str
[Review comment (Member)] This would be provided when using `_return_response_method_used` in `_get_response_values`.

Prediction method used by estimators in the chain for the prediction
features.

n_features_in_ : int
Number of features seen during :term:`fit`. Only defined if the
underlying `base_estimator` exposes such an attribute when fit.
@@ -893,6 +950,36 @@ class labels for each estimator in the chain.
[0.0321..., 0.9935..., 0.0626...]])
"""

_parameter_constraints: dict = {
**_BaseChain._parameter_constraints,
"chain_method": [
list,
tuple,
StrOptions(
{"predict", "predict_proba", "predict_log_proba", "decision_function"}
),
],
}
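The list/tuple options above allow passing `chain_method` as a preference list; `fit` resolves it with the private `_check_response_method` helper imported in this diff. A minimal sketch of that resolution, assuming `LinearSVC` as the base estimator (private helpers may change between releases):

```python
from sklearn.svm import LinearSVC
from sklearn.utils.validation import _check_response_method

# LinearSVC implements no predict_proba, so the first method in the
# preference list that the estimator actually provides is returned.
method = _check_response_method(LinearSVC(), ["predict_proba", "decision_function"])
print(method.__name__)  # "decision_function"
```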

def __init__(
self,
base_estimator,
*,
order=None,
cv=None,
chain_method="predict",
random_state=None,
verbose=False,
):
super().__init__(
base_estimator,
order=order,
cv=cv,
random_state=random_state,
verbose=verbose,
)
self.chain_method = chain_method

@_fit_context(
# ClassifierChain.base_estimator is not validated yet
prefer_skip_nested_validation=False
@@ -941,22 +1028,7 @@ def predict_proba(self, X):
Y_prob : array-like of shape (n_samples, n_classes)
The predicted probabilities.
"""
X = self._validate_data(X, accept_sparse=True, reset=False)
Y_prob_chain = np.zeros((X.shape[0], len(self.estimators_)))
Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
for chain_idx, estimator in enumerate(self.estimators_):
previous_predictions = Y_pred_chain[:, :chain_idx]
if sp.issparse(X):
X_aug = sp.hstack((X, previous_predictions))
else:
X_aug = np.hstack((X, previous_predictions))
Y_prob_chain[:, chain_idx] = estimator.predict_proba(X_aug)[:, 1]
Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
inv_order = np.empty_like(self.order_)
inv_order[self.order_] = np.arange(len(self.order_))
Y_prob = Y_prob_chain[:, inv_order]

return Y_prob
return self._get_predictions(X, output_method="predict_proba")

def predict_log_proba(self, X):
"""Predict logarithm of probability estimates.
@@ -988,23 +1060,7 @@ def decision_function(self, X):
Returns the decision function of the sample for each model
in the chain.
"""
X = self._validate_data(X, accept_sparse=True, reset=False)
Y_decision_chain = np.zeros((X.shape[0], len(self.estimators_)))
Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_)))
for chain_idx, estimator in enumerate(self.estimators_):
previous_predictions = Y_pred_chain[:, :chain_idx]
if sp.issparse(X):
X_aug = sp.hstack((X, previous_predictions))
else:
X_aug = np.hstack((X, previous_predictions))
Y_decision_chain[:, chain_idx] = estimator.decision_function(X_aug)
Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)

inv_order = np.empty_like(self.order_)
inv_order[self.order_] = np.arange(len(self.order_))
Y_decision = Y_decision_chain[:, inv_order]

return Y_decision
return self._get_predictions(X, output_method="decision_function")

def get_metadata_routing(self):
"""Get metadata routing of this object.
95 changes: 62 additions & 33 deletions sklearn/tests/test_multioutput.py
@@ -508,11 +508,14 @@ def generate_multilabel_dataset_with_correlations():
return X, Y_multi


def test_classifier_chain_fit_and_predict_with_linear_svc():
@pytest.mark.parametrize("chain_method", ["predict", "decision_function"])
def test_classifier_chain_fit_and_predict_with_linear_svc(chain_method):
# Fit classifier chain and verify predict performance using LinearSVC
X, Y = generate_multilabel_dataset_with_correlations()
classifier_chain = ClassifierChain(LinearSVC(dual="auto"))
classifier_chain.fit(X, Y)
classifier_chain = ClassifierChain(
LinearSVC(dual="auto"),
chain_method=chain_method,
).fit(X, Y)

Y_pred = classifier_chain.predict(X)
assert Y_pred.shape == Y.shape
@@ -530,12 +533,10 @@ def test_classifier_chain_fit_and_predict_with_sparse_data(csr_container):
X, Y = generate_multilabel_dataset_with_correlations()
X_sparse = csr_container(X)

classifier_chain = ClassifierChain(LogisticRegression())
classifier_chain.fit(X_sparse, Y)
classifier_chain = ClassifierChain(LogisticRegression()).fit(X_sparse, Y)
Y_pred_sparse = classifier_chain.predict(X_sparse)

classifier_chain = ClassifierChain(LogisticRegression())
classifier_chain.fit(X, Y)
classifier_chain = ClassifierChain(LogisticRegression()).fit(X, Y)
Y_pred_dense = classifier_chain.predict(X)

assert_array_equal(Y_pred_sparse, Y_pred_dense)
@@ -564,26 +565,41 @@ def test_classifier_chain_vs_independent_models():
)


@pytest.mark.parametrize(
"chain_method",
["predict", "predict_proba", "predict_log_proba", "decision_function"],
)
@pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"])
def test_base_chain_fit_and_predict(response_method):
# Fit base chain and verify predict performance
def test_classifier_chain_fit_and_predict(chain_method, response_method):
# Fit classifier chain and verify predict performance
X, Y = generate_multilabel_dataset_with_correlations()
chains = [RegressorChain(Ridge()), ClassifierChain(LogisticRegression())]
for chain in chains:
chain.fit(X, Y)
Y_pred = chain.predict(X)
assert Y_pred.shape == Y.shape
assert [c.coef_.size for c in chain.estimators_] == list(
range(X.shape[1], X.shape[1] + Y.shape[1])
)
chain = ClassifierChain(LogisticRegression(), chain_method=chain_method)
chain.fit(X, Y)
Y_pred = chain.predict(X)
assert Y_pred.shape == Y.shape
assert [c.coef_.size for c in chain.estimators_] == list(
range(X.shape[1], X.shape[1] + Y.shape[1])
)

Y_prob = getattr(chains[1], response_method)(X)
Y_prob = getattr(chain, response_method)(X)
if response_method == "predict_log_proba":
Y_prob = np.exp(Y_prob)
Y_binary = Y_prob >= 0.5
assert_array_equal(Y_binary, Y_pred)

assert isinstance(chains[1], ClassifierMixin)
assert isinstance(chain, ClassifierMixin)


def test_regressor_chain_fit_and_predict():
# Fit regressor chain and verify Y and estimator coefficients shape
X, Y = generate_multilabel_dataset_with_correlations()
chain = RegressorChain(Ridge())
chain.fit(X, Y)
Y_pred = chain.predict(X)
assert Y_pred.shape == Y.shape
assert [c.coef_.size for c in chain.estimators_] == list(
range(X.shape[1], X.shape[1] + Y.shape[1])
)


@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
@@ -619,24 +635,37 @@ def test_base_chain_random_order():
assert_array_almost_equal(est1.coef_, est2.coef_)


def test_base_chain_crossval_fit_and_predict():
@pytest.mark.parametrize(
"chain_type, chain_method",
[
("classifier", "predict"),
("classifier", "predict_proba"),
("classifier", "predict_log_proba"),
("classifier", "decision_function"),
("regressor", ""),
],
)
def test_base_chain_crossval_fit_and_predict(chain_type, chain_method):
# Fit chain with cross_val_predict and verify predict
# performance
X, Y = generate_multilabel_dataset_with_correlations()

for chain in [ClassifierChain(LogisticRegression()), RegressorChain(Ridge())]:
chain.fit(X, Y)
chain_cv = clone(chain).set_params(cv=3)
chain_cv.fit(X, Y)
Y_pred_cv = chain_cv.predict(X)
Y_pred = chain.predict(X)

assert Y_pred_cv.shape == Y_pred.shape
assert not np.all(Y_pred == Y_pred_cv)
if isinstance(chain, ClassifierChain):
assert jaccard_score(Y, Y_pred_cv, average="samples") > 0.4
else:
assert mean_squared_error(Y, Y_pred_cv) < 0.25
if chain_type == "classifier":
chain = ClassifierChain(LogisticRegression(), chain_method=chain_method)
else:
chain = RegressorChain(Ridge())
chain.fit(X, Y)
chain_cv = clone(chain).set_params(cv=3)
chain_cv.fit(X, Y)
Y_pred_cv = chain_cv.predict(X)
Y_pred = chain.predict(X)

assert Y_pred_cv.shape == Y_pred.shape
assert not np.all(Y_pred == Y_pred_cv)
if isinstance(chain, ClassifierChain):
assert jaccard_score(Y, Y_pred_cv, average="samples") > 0.4
else:
assert mean_squared_error(Y, Y_pred_cv) < 0.25


@pytest.mark.parametrize(