@@ -350,13 +350,39 @@ def test_mean_absolute_percentage_error():
350
350
assert mean_absolute_percentage_error (y_true , y_pred ) == pytest .approx (0.2 )
351
351
352
352
353
- def test_pinball_loss ():
354
- data = np .linspace (0 , 1 , 100 )
355
- th = 0.2
356
- for constant_pred in range (0 , 11 ):
357
- alpha = constant_pred / 10
358
- pbl = pinball_loss (data , alpha * np .ones (100 ), alpha = th )
359
- err = ((alpha - data [data < alpha ]).sum () * (1 - th ) +
360
- (data [data >= alpha ] - alpha ).sum () * th )
361
- err /= data .shape [0 ]
362
- assert_almost_equal (err , pbl )
353
+ @pytest .mark .parametrize ("distribution" ,
354
+ ["normal" , "lognormal" , "exponential" , "uniform" ])
355
+ @pytest .mark .parametrize ("target_quantile" , [0.05 , 0.5 , 0.75 ])
356
+ def test_pinball_loss_on_constant_predictions (
357
+ distribution ,
358
+ target_quantile
359
+ ):
360
+ # Check that the pinball loss is minimized
361
+ n_samples = 100
362
+ rng = np .random .RandomState (42 )
363
+ data = getattr (rng , distribution )(size = n_samples )
364
+
365
+ # Compute the best possible pinball loss for any constant predictor:
366
+ best_pred = np .quantile (data , target_quantile )
367
+ best_pred = np .full (n_samples , fill_value = best_pred )
368
+ best_pbl = pinball_loss (data , best_pred , alpha = target_quantile )
369
+
370
+ candidate_predictions = np .quantile (data , np .linspace (0 , 1 , 100 ))
371
+ for pred in candidate_predictions :
372
+ # Compute the pinball loss of a constant predictor:
373
+ constant_pred = np .full (n_samples , fill_value = pred )
374
+ pbl = pinball_loss (data , constant_pred , alpha = target_quantile )
375
+
376
+ # Check that the loss of this constant predictor is greater or equal
377
+ # than the loss of using the optimal quantile (up to machine
378
+ # precision):
379
+ assert pbl >= best_pbl - np .finfo (best_pbl .dtype ).eps
380
+
381
+ # Check that the value of the pinball loss matches the analytical
382
+ # formula.
383
+ expected_pbl = (
384
+ (pred - data [data < pred ]).sum () * (1 - target_quantile ) +
385
+ (data [data >= pred ] - pred ).sum () * target_quantile
386
+ )
387
+ expected_pbl /= n_samples
388
+ assert_almost_equal (expected_pbl , pbl )
0 commit comments