diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 88b8af7944ecc..8beb29467cd2c 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -1227,6 +1227,25 @@ def f1_score( ) +@validate_params( + { + "y_true": ["array-like", "sparse matrix"], + "y_pred": ["array-like", "sparse matrix"], + "beta": [Interval(Real, 0.0, None, closed="both")], + "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], + "average": [ + StrOptions({"micro", "macro", "samples", "weighted", "binary"}), + None, + ], + "warn_for": [list, tuple, set], + "sample_weight": ["array-like", None], + "zero_division": [ + Options(Real, {0, 1}), + StrOptions({"warn"}), + ], + } +) def fbeta_score( y_true, y_pred, @@ -2743,9 +2762,11 @@ def log_loss( else: # TODO: Remove user defined eps in 1.5 warnings.warn( - "Setting the eps parameter is deprecated and will " - "be removed in 1.5. Instead eps will always have" - "a default value of `np.finfo(y_pred.dtype).eps`.", + ( + "Setting the eps parameter is deprecated and will " + "be removed in 1.5. Instead eps will always have" + "a default value of `np.finfo(y_pred.dtype).eps`." + ), FutureWarning, ) @@ -2812,8 +2833,10 @@ def log_loss( y_pred_sum = y_pred.sum(axis=1) if not np.isclose(y_pred_sum, 1, rtol=1e-15, atol=5 * eps).all(): warnings.warn( - "The y_pred values do not sum to one. Starting from 1.5 this" - "will result in an error.", + ( + "The y_pred values do not sum to one. Starting from 1.5 this" + "will result in an error." + ), UserWarning, ) y_pred = y_pred / y_pred_sum[:, np.newaxis] diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 31aeb37c5e536..913950625ba72 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -137,6 +137,7 @@ def _check_function_param_validation( "sklearn.metrics.dcg_score", "sklearn.metrics.det_curve", "sklearn.metrics.f1_score", + "sklearn.metrics.fbeta_score", "sklearn.metrics.get_scorer", "sklearn.metrics.hamming_loss", "sklearn.metrics.jaccard_score",