diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index ac488ac3af772..b51ef3189ae56 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -885,6 +885,13 @@ def jaccard_score( return np.average(jaccard, weights=weights) +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like"], + "sample_weight": ["array-like", None], + } +) def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): """Compute the Matthews correlation coefficient (MCC). @@ -905,10 +912,10 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): Parameters ---------- - y_true : array, shape = [n_samples] + y_true : array-like of shape (n_samples,) Ground truth (correct) target values. - y_pred : array, shape = [n_samples] + y_pred : array-like of shape (n_samples,) Estimated targets as returned by a classifier. sample_weight : array-like of shape (n_samples,), default=None diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index aa705c1599582..c2f0f18a7825e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -127,6 +127,7 @@ def _check_function_param_validation( "sklearn.metrics.hamming_loss", "sklearn.metrics.jaccard_score", "sklearn.metrics.log_loss", + "sklearn.metrics.matthews_corrcoef", "sklearn.metrics.max_error", "sklearn.metrics.mean_absolute_error", "sklearn.metrics.mean_absolute_percentage_error",