Skip to content

Commit f343f3f

Browse files
committed
More comprehensive test for the pinball loss with constant predictions
1 parent 95c5709 commit f343f3f

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

sklearn/metrics/tests/test_regression.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,39 @@ def test_mean_absolute_percentage_error():
350350
assert mean_absolute_percentage_error(y_true, y_pred) == pytest.approx(0.2)
351351

352352

353-
def test_pinball_loss():
354-
data = np.linspace(0, 1, 100)
355-
th = 0.2
356-
for constant_pred in range(0, 11):
357-
alpha = constant_pred / 10
358-
pbl = pinball_loss(data, alpha * np.ones(100), alpha=th)
359-
err = ((alpha - data[data < alpha]).sum() * (1 - th) +
360-
(data[data >= alpha] - alpha).sum() * th)
361-
err /= data.shape[0]
362-
assert_almost_equal(err, pbl)
353+
@pytest.mark.parametrize("distribution",
354+
["normal", "lognormal", "exponential", "uniform"])
355+
@pytest.mark.parametrize("target_quantile", [0.05, 0.5, 0.75])
356+
def test_pinball_loss_on_constant_predictions(
357+
distribution,
358+
target_quantile
359+
):
360+
# Check that the pinball loss is minimized by the empirical quantile.
361+
n_samples = 100
362+
rng = np.random.RandomState(42)
363+
data = getattr(rng, distribution)(size=n_samples)
364+
365+
# Compute the best possible pinball loss for any constant predictor:
366+
best_pred = np.quantile(data, target_quantile)
367+
best_pred = np.full(n_samples, fill_value=best_pred)
368+
best_pbl = pinball_loss(data, best_pred, alpha=target_quantile)
369+
370+
candidate_predictions = np.quantile(data, np.linspace(0, 1, 100))
371+
for pred in candidate_predictions:
372+
# Compute the pinball loss of a constant predictor:
373+
constant_pred = np.full(n_samples, fill_value=pred)
374+
pbl = pinball_loss(data, constant_pred, alpha=target_quantile)
375+
376+
# Check that the loss of this constant predictor is greater than or equal
377+
# to the loss of using the optimal quantile (up to machine
378+
# precision):
379+
assert pbl >= best_pbl - np.finfo(best_pbl.dtype).eps
380+
381+
# Check that the value of the pinball loss matches the analytical
382+
# formula.
383+
expected_pbl = (
384+
(pred - data[data < pred]).sum() * (1 - target_quantile) +
385+
(data[data >= pred] - pred).sum() * target_quantile
386+
)
387+
expected_pbl /= n_samples
388+
assert_almost_equal(expected_pbl, pbl)

0 commit comments

Comments
 (0)