Skip to content

Commit 8edbed1

Browse files
committed
Change optimization test to make it run faster
1 parent eb22059 commit 8edbed1

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

sklearn/metrics/tests/test_regression.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def test_mean_pinball_loss_on_constant_predictions(
366366
"with support for np.quantile.")
367367

368368
# Check that the pinball loss is minimized by the empirical quantile.
369-
n_samples = 1000
369+
n_samples = 3000
370370
rng = np.random.RandomState(42)
371371
data = getattr(rng, distribution)(size=n_samples)
372372

@@ -400,18 +400,15 @@ def test_mean_pinball_loss_on_constant_predictions(
400400
# Check that we can actually recover the target_quantile by minimizing the
401401
# pinball loss w.r.t. the constant prediction quantile.
402402
def objective_func(x):
403-
if x < 0 or x > 1:
404-
return np.inf
405-
pred = np.quantile(data, x)
406-
constant_pred = np.full(n_samples, fill_value=pred)
403+
constant_pred = np.full(n_samples, fill_value=x)
407404
return mean_pinball_loss(data, constant_pred, alpha=target_quantile)
408405

409-
result = optimize.minimize(objective_func, 0.5, method="Nelder-Mead")
406+
result = optimize.minimize(objective_func, data.mean(),
407+
method="Nelder-Mead")
410408
assert result.success
411-
assert result.fun == pytest.approx(best_pbl)
412-
413409
# The minimum is not unique with limited data, hence the tolerance.
414-
assert result.x == pytest.approx(target_quantile, abs=1e-3)
410+
assert result.x == pytest.approx(best_pred, rel=1e-2)
411+
assert result.fun == pytest.approx(best_pbl)
415412

416413

417414
def test_dummy_quantile_parameter_tuning():

0 commit comments

Comments (0)