From a7f466495c356bbd65b91cd383f3b78fc47ec631 Mon Sep 17 00:00:00 2001 From: aymeric basset Date: Mon, 13 Mar 2023 23:20:03 +0100 Subject: [PATCH] add parameter validation to metrics.fbeta_score --- sklearn/metrics/_classification.py | 18 ++++++++++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 19 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 88b8af7944ecc..9aba3f57a54e6 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1227,6 +1227,24 @@ def f1_score( ) +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "beta": [Interval(Real, 0.0, None, closed="both")], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"micro", "macro", "samples", "weighted", "binary"}), + None, + ], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + } +) def fbeta_score( y_true, y_pred, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 31aeb37c5e536..913950625ba72 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -137,6 +137,7 @@ def _check_function_param_validation( "sklearn.metrics.dcg_score", "sklearn.metrics.det_curve", "sklearn.metrics.f1_score", + "sklearn.metrics.fbeta_score", "sklearn.metrics.get_scorer", "sklearn.metrics.hamming_loss", "sklearn.metrics.jaccard_score",