From b41ba0a5c59714cb53f8ea520c8e3eef016f04a4 Mon Sep 17 00:00:00 2001 From: zeeshan Date: Tue, 14 Mar 2023 22:02:09 +0530 Subject: [PATCH] MAINT Added Parameter Validation for metrics.mean_gamma_deviance --- sklearn/metrics/_regression.py | 7 +++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 8 insertions(+) diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index 70a3303a4770d..d4337cad59984 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -1186,6 +1186,13 @@ def mean_poisson_deviance(y_true, y_pred, *, sample_weight=None): return mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=1) +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like"], + "sample_weight": ["array-like", None], + } +) def mean_gamma_deviance(y_true, y_pred, *, sample_weight=None): """Mean Gamma deviance regression loss. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 31aeb37c5e536..6ebab2246b6ae 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -148,6 +148,7 @@ def _check_function_param_validation( "sklearn.metrics.max_error", "sklearn.metrics.mean_absolute_error", "sklearn.metrics.mean_absolute_percentage_error", + "sklearn.metrics.mean_gamma_deviance", "sklearn.metrics.mean_pinball_loss", "sklearn.metrics.mean_squared_error", "sklearn.metrics.mean_tweedie_deviance",