From d2982eb93d11b54684ba80e3fd66d7c3d7aa65fd Mon Sep 17 00:00:00 2001 From: Bharat Raghunathan Date: Thu, 16 Mar 2023 21:48:39 +0530 Subject: [PATCH 1/3] MAINT Parameters validation for metrics.hinge_loss --- sklearn/metrics/_classification.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 187863e44515f..83dc88aa76134 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2864,6 +2864,14 @@ def log_loss( return _weighted_sum(loss, sample_weight, normalize) +@validate_params( + { + "y_true": ["array-like"], + "pred_decision": ["array-like"], + "labels": ["array-like", None], + "sample_weight": ["array-like", None], + } +) def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None): """Average hinge loss (non-regularized). From dac05c0fc6984b3af034ee332837f299690663cc Mon Sep 17 00:00:00 2001 From: Bharat Raghunathan Date: Thu, 16 Mar 2023 21:54:45 +0530 Subject: [PATCH 2/3] MAINT Add to list of public test functions --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 2fa93fdfb6adf..fb18013d092a1 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -163,6 +163,7 @@ def _check_function_param_validation( "sklearn.metrics.fbeta_score", "sklearn.metrics.get_scorer", "sklearn.metrics.hamming_loss", + "sklearn.metrics.hinge_loss", "sklearn.metrics.jaccard_score", "sklearn.metrics.label_ranking_average_precision_score", "sklearn.metrics.label_ranking_loss", From 20b295597e2797856c56a058949571153719453f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Fri, 17 Mar 2023 10:44:28 +0100 Subject: [PATCH 3/3] Update _classification.py --- sklearn/metrics/_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 83dc88aa76134..50425bb082d39 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -2891,11 +2891,11 @@ def hinge_loss(y_true, pred_decision, *, labels=None, sample_weight=None): Parameters ---------- - y_true : array of shape (n_samples,) + y_true : array-like of shape (n_samples,) True target, consisting of integers of two values. The positive label must be greater than the negative label. - pred_decision : array of shape (n_samples,) or (n_samples, n_classes) + pred_decision : array-like of shape (n_samples,) or (n_samples, n_classes) Predicted decisions, as output by decision_function (floats). labels : array-like, default=None