diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 53f0fbd8a74e8..d70c9cc2f1f23 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -158,6 +158,12 @@ Changelog - |Enhancement| :term:`CV splitters ` that ignores the group parameter now raises a warning when groups are passed in to :term:`split`. :pr:`28210` by +:mod:`sklearn.multioutput` +.......................... + +- |Enhancement| `chain_method` parameter added to `:class:``multioutput.ClassifierChain`. + :pr:`27700` by :user:`Lucy Liu `. + :mod:`sklearn.pipeline` ....................... diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index bfb83884399ef..64649007d6f24 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -33,6 +33,7 @@ from .model_selection import cross_val_predict from .utils import Bunch, _print_elapsed_time, check_random_state from .utils._param_validation import HasMethods, StrOptions +from .utils._response import _get_response_values from .utils.metadata_routing import ( MetadataRouter, MethodMapping, @@ -43,7 +44,12 @@ from .utils.metaestimators import available_if from .utils.multiclass import check_classification_targets from .utils.parallel import Parallel, delayed -from .utils.validation import _check_method_params, check_is_fitted, has_fit_parameter +from .utils.validation import ( + _check_method_params, + _check_response_method, + check_is_fitted, + has_fit_parameter, +) __all__ = [ "MultiOutputRegressor", @@ -650,6 +656,41 @@ def _log_message(self, *, estimator_idx, n_estimators, processing_msg): return None return f"({estimator_idx} of {n_estimators}) {processing_msg}" + def _get_predictions(self, X, *, output_method): + """Get predictions for each model in the chain.""" + check_is_fitted(self) + X = self._validate_data(X, accept_sparse=True, reset=False) + Y_output_chain = np.zeros((X.shape[0], len(self.estimators_))) + Y_feature_chain = np.zeros((X.shape[0], len(self.estimators_))) + + # `RegressorChain` does not have a `chain_method_` parameter so we + # default to "predict" + chain_method = getattr(self, "chain_method_", "predict") + hstack = sp.hstack if sp.issparse(X) else np.hstack + for chain_idx, estimator in enumerate(self.estimators_): + previous_predictions = Y_feature_chain[:, :chain_idx] + X_aug = hstack((X, previous_predictions)) + + feature_predictions, _ = _get_response_values( + estimator, + X_aug, + response_method=chain_method, + ) + Y_feature_chain[:, chain_idx] = feature_predictions + + output_predictions, _ = _get_response_values( + estimator, + X_aug, + response_method=output_method, + ) + Y_output_chain[:, chain_idx] = output_predictions + + inv_order = np.empty_like(self.order_) + inv_order[self.order_] = np.arange(len(self.order_)) + Y_output = Y_output_chain[:, inv_order] + + return Y_output + @abstractmethod def fit(self, X, Y, **fit_params): """Fit the model to data matrix X and targets Y. @@ -712,6 +753,16 @@ def fit(self, X, Y, **fit_params): else: routed_params = Bunch(estimator=Bunch(fit=fit_params)) + if hasattr(self, "chain_method"): + chain_method = _check_response_method( + self.base_estimator, + self.chain_method, + ).__name__ + self.chain_method_ = chain_method + else: + # `RegressorChain` does not have a `chain_method` parameter + chain_method = "predict" + for chain_idx, estimator in enumerate(self.estimators_): message = self._log_message( estimator_idx=chain_idx + 1, @@ -729,8 +780,15 @@ def fit(self, X, Y, **fit_params): if self.cv is not None and chain_idx < len(self.estimators_) - 1: col_idx = X.shape[1] + chain_idx cv_result = cross_val_predict( - self.base_estimator, X_aug[:, :col_idx], y=y, cv=self.cv + self.base_estimator, + X_aug[:, :col_idx], + y=y, + cv=self.cv, + method=chain_method, ) + # `predict_proba` output is 2D, we use only output for classes[-1] + if cv_result.ndim > 1: + cv_result = cv_result[:, 1] if sp.issparse(X_aug): X_aug[:, col_idx] = np.expand_dims(cv_result, 1) else: @@ -751,25 +809,7 @@ def predict(self, X): Y_pred : array-like of shape (n_samples, n_classes) The predicted values. """ - check_is_fitted(self) - X = self._validate_data(X, accept_sparse=True, reset=False) - Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_))) - for chain_idx, estimator in enumerate(self.estimators_): - previous_predictions = Y_pred_chain[:, :chain_idx] - if sp.issparse(X): - if chain_idx == 0: - X_aug = X - else: - X_aug = sp.hstack((X, previous_predictions)) - else: - X_aug = np.hstack((X, previous_predictions)) - Y_pred_chain[:, chain_idx] = estimator.predict(X_aug) - - inv_order = np.empty_like(self.order_) - inv_order[self.order_] = np.arange(len(self.order_)) - Y_pred = Y_pred_chain[:, inv_order] - - return Y_pred + return self._get_predictions(X, output_method="predict") class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain): @@ -820,6 +860,19 @@ class ClassifierChain(MetaEstimatorMixin, ClassifierMixin, _BaseChain): - :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. + chain_method : {'predict', 'predict_proba', 'predict_log_proba', \ + 'decision_function'} or list of such str's, default='predict' + + Prediction method to be used by estimators in the chain for + the 'prediction' features of previous estimators in the chain. + + - if `str`, name of the method; + - if a list of `str`, provides the method names in order of + preference. The method used corresponds to the first method in + the list that is implemented by `base_estimator`. + + .. versionadded:: 1.5 + random_state : int, RandomState instance or None, optional (default=None) If ``order='random'``, determines random number generation for the chain order. @@ -846,6 +899,10 @@ class labels for each estimator in the chain. order_ : list The order of labels in the classifier chain. + chain_method_ : str + Prediction method used by estimators in the chain for the prediction + features. + n_features_in_ : int Number of features seen during :term:`fit`. Only defined if the underlying `base_estimator` exposes such an attribute when fit. @@ -893,6 +950,36 @@ class labels for each estimator in the chain. [0.0321..., 0.9935..., 0.0626...]]) """ + _parameter_constraints: dict = { + **_BaseChain._parameter_constraints, + "chain_method": [ + list, + tuple, + StrOptions( + {"predict", "predict_proba", "predict_log_proba", "decision_function"} + ), + ], + } + + def __init__( + self, + base_estimator, + *, + order=None, + cv=None, + chain_method="predict", + random_state=None, + verbose=False, + ): + super().__init__( + base_estimator, + order=order, + cv=cv, + random_state=random_state, + verbose=verbose, + ) + self.chain_method = chain_method + @_fit_context( # ClassifierChain.base_estimator is not validated yet prefer_skip_nested_validation=False @@ -941,22 +1028,7 @@ def predict_proba(self, X): Y_prob : array-like of shape (n_samples, n_classes) The predicted probabilities. """ - X = self._validate_data(X, accept_sparse=True, reset=False) - Y_prob_chain = np.zeros((X.shape[0], len(self.estimators_))) - Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_))) - for chain_idx, estimator in enumerate(self.estimators_): - previous_predictions = Y_pred_chain[:, :chain_idx] - if sp.issparse(X): - X_aug = sp.hstack((X, previous_predictions)) - else: - X_aug = np.hstack((X, previous_predictions)) - Y_prob_chain[:, chain_idx] = estimator.predict_proba(X_aug)[:, 1] - Y_pred_chain[:, chain_idx] = estimator.predict(X_aug) - inv_order = np.empty_like(self.order_) - inv_order[self.order_] = np.arange(len(self.order_)) - Y_prob = Y_prob_chain[:, inv_order] - - return Y_prob + return self._get_predictions(X, output_method="predict_proba") def predict_log_proba(self, X): """Predict logarithm of probability estimates. @@ -988,23 +1060,7 @@ def decision_function(self, X): Returns the decision function of the sample for each model in the chain. """ - X = self._validate_data(X, accept_sparse=True, reset=False) - Y_decision_chain = np.zeros((X.shape[0], len(self.estimators_))) - Y_pred_chain = np.zeros((X.shape[0], len(self.estimators_))) - for chain_idx, estimator in enumerate(self.estimators_): - previous_predictions = Y_pred_chain[:, :chain_idx] - if sp.issparse(X): - X_aug = sp.hstack((X, previous_predictions)) - else: - X_aug = np.hstack((X, previous_predictions)) - Y_decision_chain[:, chain_idx] = estimator.decision_function(X_aug) - Y_pred_chain[:, chain_idx] = estimator.predict(X_aug) - - inv_order = np.empty_like(self.order_) - inv_order[self.order_] = np.arange(len(self.order_)) - Y_decision = Y_decision_chain[:, inv_order] - - return Y_decision + return self._get_predictions(X, output_method="decision_function") def get_metadata_routing(self): """Get metadata routing of this object. diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index c42938229d5a6..6048c7c500cb8 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -508,11 +508,14 @@ def generate_multilabel_dataset_with_correlations(): return X, Y_multi -def test_classifier_chain_fit_and_predict_with_linear_svc(): +@pytest.mark.parametrize("chain_method", ["predict", "decision_function"]) +def test_classifier_chain_fit_and_predict_with_linear_svc(chain_method): # Fit classifier chain and verify predict performance using LinearSVC X, Y = generate_multilabel_dataset_with_correlations() - classifier_chain = ClassifierChain(LinearSVC(dual="auto")) - classifier_chain.fit(X, Y) + classifier_chain = ClassifierChain( + LinearSVC(dual="auto"), + chain_method=chain_method, + ).fit(X, Y) Y_pred = classifier_chain.predict(X) assert Y_pred.shape == Y.shape @@ -530,12 +533,10 @@ def test_classifier_chain_fit_and_predict_with_sparse_data(csr_container): X, Y = generate_multilabel_dataset_with_correlations() X_sparse = csr_container(X) - classifier_chain = ClassifierChain(LogisticRegression()) - classifier_chain.fit(X_sparse, Y) + classifier_chain = ClassifierChain(LogisticRegression()).fit(X_sparse, Y) Y_pred_sparse = classifier_chain.predict(X_sparse) - classifier_chain = ClassifierChain(LogisticRegression()) - classifier_chain.fit(X, Y) + classifier_chain = ClassifierChain(LogisticRegression()).fit(X, Y) Y_pred_dense = classifier_chain.predict(X) assert_array_equal(Y_pred_sparse, Y_pred_dense) @@ -564,26 +565,41 @@ def test_classifier_chain_vs_independent_models(): ) +@pytest.mark.parametrize( + "chain_method", + ["predict", "predict_proba", "predict_log_proba", "decision_function"], +) @pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"]) -def test_base_chain_fit_and_predict(response_method): - # Fit base chain and verify predict performance +def test_classifier_chain_fit_and_predict(chain_method, response_method): + # Fit classifier chain and verify predict performance X, Y = generate_multilabel_dataset_with_correlations() - chains = [RegressorChain(Ridge()), ClassifierChain(LogisticRegression())] - for chain in chains: - chain.fit(X, Y) - Y_pred = chain.predict(X) - assert Y_pred.shape == Y.shape - assert [c.coef_.size for c in chain.estimators_] == list( - range(X.shape[1], X.shape[1] + Y.shape[1]) - ) + chain = ClassifierChain(LogisticRegression(), chain_method=chain_method) + chain.fit(X, Y) + Y_pred = chain.predict(X) + assert Y_pred.shape == Y.shape + assert [c.coef_.size for c in chain.estimators_] == list( + range(X.shape[1], X.shape[1] + Y.shape[1]) + ) - Y_prob = getattr(chains[1], response_method)(X) + Y_prob = getattr(chain, response_method)(X) if response_method == "predict_log_proba": Y_prob = np.exp(Y_prob) Y_binary = Y_prob >= 0.5 assert_array_equal(Y_binary, Y_pred) - assert isinstance(chains[1], ClassifierMixin) + assert isinstance(chain, ClassifierMixin) + + +def test_regressor_chain_fit_and_predict(): + # Fit regressor chain and verify Y and estimator coefficients shape + X, Y = generate_multilabel_dataset_with_correlations() + chain = RegressorChain(Ridge()) + chain.fit(X, Y) + Y_pred = chain.predict(X) + assert Y_pred.shape == Y.shape + assert [c.coef_.size for c in chain.estimators_] == list( + range(X.shape[1], X.shape[1] + Y.shape[1]) + ) @pytest.mark.parametrize("csr_container", CSR_CONTAINERS) @@ -619,24 +635,37 @@ def test_base_chain_random_order(): assert_array_almost_equal(est1.coef_, est2.coef_) -def test_base_chain_crossval_fit_and_predict(): +@pytest.mark.parametrize( + "chain_type, chain_method", + [ + ("classifier", "predict"), + ("classifier", "predict_proba"), + ("classifier", "predict_log_proba"), + ("classifier", "decision_function"), + ("regressor", ""), + ], +) +def test_base_chain_crossval_fit_and_predict(chain_type, chain_method): # Fit chain with cross_val_predict and verify predict # performance X, Y = generate_multilabel_dataset_with_correlations() - for chain in [ClassifierChain(LogisticRegression()), RegressorChain(Ridge())]: - chain.fit(X, Y) - chain_cv = clone(chain).set_params(cv=3) - chain_cv.fit(X, Y) - Y_pred_cv = chain_cv.predict(X) - Y_pred = chain.predict(X) - - assert Y_pred_cv.shape == Y_pred.shape - assert not np.all(Y_pred == Y_pred_cv) - if isinstance(chain, ClassifierChain): - assert jaccard_score(Y, Y_pred_cv, average="samples") > 0.4 - else: - assert mean_squared_error(Y, Y_pred_cv) < 0.25 + if chain_type == "classifier": + chain = ClassifierChain(LogisticRegression(), chain_method=chain_method) + else: + chain = RegressorChain(Ridge()) + chain.fit(X, Y) + chain_cv = clone(chain).set_params(cv=3) + chain_cv.fit(X, Y) + Y_pred_cv = chain_cv.predict(X) + Y_pred = chain.predict(X) + + assert Y_pred_cv.shape == Y_pred.shape + assert not np.all(Y_pred == Y_pred_cv) + if isinstance(chain, ClassifierChain): + assert jaccard_score(Y, Y_pred_cv, average="samples") > 0.4 + else: + assert mean_squared_error(Y, Y_pred_cv) < 0.25 @pytest.mark.parametrize(