diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py index fc8ccaac82c62..6e8dbff10fbe8 100644 --- a/sklearn/metrics/_regression.py +++ b/sklearn/metrics/_regression.py @@ -216,6 +216,15 @@ def mean_absolute_error( return np.average(output_errors, weights=multioutput) +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like"], + "sample_weight": ["array-like", None], + "alpha": [Interval(Real, 0, 1, closed="both")], + "multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"], + } +) def mean_pinball_loss( y_true, y_pred, *, sample_weight=None, alpha=0.5, multioutput="uniform_average" ): @@ -285,18 +294,13 @@ def mean_pinball_loss( sign = (diff >= 0).astype(diff.dtype) loss = alpha * sign * diff - (1 - alpha) * (1 - sign) * diff output_errors = np.average(loss, weights=sample_weight, axis=0) - if isinstance(multioutput, str): - if multioutput == "raw_values": - return output_errors - elif multioutput == "uniform_average": - # pass None as weights to np.average: uniform mean - multioutput = None - else: - raise ValueError( - "multioutput is expected to be 'raw_values' " - "or 'uniform_average' but we got %r" - " instead." % multioutput - ) + + if isinstance(multioutput, str) and multioutput == "raw_values": + return output_errors + + if isinstance(multioutput, str) and multioutput == "uniform_average": + # pass None as weights to np.average: uniform mean + multioutput = None return np.average(output_errors, weights=multioutput) diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py index 241a9ba4f2855..d9223401cec9c 100644 --- a/sklearn/metrics/tests/test_regression.py +++ b/sklearn/metrics/tests/test_regression.py @@ -344,12 +344,6 @@ def test_regression_multioutput_array(): mse = mean_squared_error(y_true, y_pred, multioutput="raw_values") mae = mean_absolute_error(y_true, y_pred, multioutput="raw_values") - err_msg = ( - "multioutput is expected to be 'raw_values' " - "or 'uniform_average' but we got 'variance_weighted' instead." - ) - with pytest.raises(ValueError, match=err_msg): - mean_pinball_loss(y_true, y_pred, multioutput="variance_weighted") pbl = mean_pinball_loss(y_true, y_pred, multioutput="raw_values") mape = mean_absolute_percentage_error(y_true, y_pred, multioutput="raw_values") diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 9b2b56cdb3eb8..0a862c027e6a0 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -124,6 +124,7 @@ def _check_function_param_validation( "sklearn.metrics.hamming_loss", "sklearn.metrics.log_loss", "sklearn.metrics.mean_absolute_error", + "sklearn.metrics.mean_pinball_loss", "sklearn.metrics.mean_squared_error", "sklearn.metrics.mean_tweedie_deviance", "sklearn.metrics.median_absolute_error",