From 71470186611d70bc8702b94ac56ae7ac23f00440 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 3 Nov 2023 22:40:07 +0100 Subject: [PATCH 1/2] ENH add predict_log_proba to ClassifierChain --- doc/whats_new/v1.4.rst | 6 ++++++ sklearn/multioutput.py | 15 +++++++++++++++ sklearn/tests/test_multioutput.py | 7 +++++-- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 69549e6527b69..0b59dea95541c 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -417,6 +417,12 @@ Changelog object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin Jalali`_. +:mod:`sklearn.multioutput` +.......................... + +- |Enhancement| Add method `predict_log_proba` to :class:`multioutput.ClassifierChain`. + :pr:`xxxx` by :user:`Guillaume Lemaitre <glemaitre>`. + :mod:`sklearn.neighbors` ........................ diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index b7859006ac215..1e1c2b646ca82 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -956,6 +956,21 @@ def predict_proba(self, X): return Y_prob + def predict_log_proba(self, X): + """Predict logarithm of probability estimates. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The input data. + + Returns + ------- + Y_log_prob : array-like of shape (n_samples, n_classes) + The predicted logarithm of the probabilities. + """ + return np.log(self.predict_proba(X)) + @_available_if_base_estimator_has("decision_function") def decision_function(self, X): """Evaluate the decision_function of the models in the chain. 
diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index 493d0fc7dc8b5..9d5accac21040 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -548,7 +548,8 @@ def test_classifier_chain_vs_independent_models(): ) -def test_base_chain_fit_and_predict(): +@pytest.mark.parametrize("response_method", ["predict_proba", "predict_log_proba"]) +def test_base_chain_fit_and_predict(response_method): # Fit base chain and verify predict performance X, Y = generate_multilabel_dataset_with_correlations() chains = [RegressorChain(Ridge()), ClassifierChain(LogisticRegression())] @@ -560,7 +561,9 @@ def test_base_chain_fit_and_predict(): range(X.shape[1], X.shape[1] + Y.shape[1]) ) - Y_prob = chains[1].predict_proba(X) + Y_prob = getattr(chains[1], response_method)(X) + if response_method == "predict_log_proba": + Y_prob = np.exp(Y_prob) Y_binary = Y_prob >= 0.5 assert_array_equal(Y_binary, Y_pred) From c2eb71c275aaff3bf9f6bb70859b6d3f8bf363d0 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 3 Nov 2023 22:42:00 +0100 Subject: [PATCH 2/2] update pr number --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 0b59dea95541c..17c78043290ec 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -421,7 +421,7 @@ Changelog .......................... - |Enhancement| Add method `predict_log_proba` to :class:`multioutput.ClassifierChain`. - :pr:`xxxx` by :user:`Guillaume Lemaitre <glemaitre>`. + :pr:`27720` by :user:`Guillaume Lemaitre <glemaitre>`. :mod:`sklearn.neighbors` ........................