From 8ee276f3603c3f838fed1628c75987d8f75b1294 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Fri, 10 Mar 2023 22:36:58 +0100 Subject: [PATCH] add parameter validation to metrics.recall_score --- sklearn/metrics/_classification.py | 17 +++++++++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 18 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 68659d251cef7..88b8af7944ecc 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2070,6 +2070,23 @@ def precision_score( return p +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"micro", "macro", "samples", "weighted", "binary"}), + None, + ], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + } +) def recall_score( y_true, y_pred, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 7c04e5ce44319..f6d400d805c73 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -157,6 +157,7 @@ def _check_function_param_validation( "sklearn.metrics.precision_recall_fscore_support", "sklearn.metrics.precision_score", "sklearn.metrics.r2_score", + "sklearn.metrics.recall_score", "sklearn.metrics.roc_curve", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split",