|
8 | 8 | from sklearn.utils._testing import assert_almost_equal
|
9 | 9 | from sklearn.utils._testing import assert_array_equal
|
10 | 10 | from sklearn.utils._testing import assert_array_almost_equal
|
| 11 | +from sklearn.dummy import DummyRegressor |
| 12 | +from sklearn.model_selection import GridSearchCV |
11 | 13 |
|
12 | 14 | from sklearn.metrics import explained_variance_score
|
13 | 15 | from sklearn.metrics import mean_absolute_error
|
|
19 | 21 | from sklearn.metrics import mean_pinball_loss
|
20 | 22 | from sklearn.metrics import r2_score
|
21 | 23 | from sklearn.metrics import mean_tweedie_deviance
|
| 24 | +from sklearn.metrics import make_scorer |
22 | 25 |
|
23 | 26 | from sklearn.metrics._regression import _check_reg_targets
|
24 | 27 |
|
@@ -409,3 +412,30 @@ def objective_func(x):
|
409 | 412 |
|
410 | 413 | # The minimum is not unique with limited data, hence the tolerance.
|
411 | 414 | assert result.x == pytest.approx(target_quantile, abs=1e-3)
|
| 415 | + |
| 416 | + |
def test_dummy_quantile_parameter_tuning():
    # End-to-end check that the pinball loss, wrapped as a scorer, can drive
    # hyperparameter selection for a quantile regressor: for every target
    # level alpha, a grid search over the candidate quantiles, scored with
    # the pinball loss at that alpha, is expected to pick quantile == alpha.
    # Same idea as the preceding test, but exercised through the
    # scikit-learn estimator and scoring API.
    rng = np.random.RandomState(0)
    n_samples = 1000
    # Draw order matters for reproducibility: features first, then targets.
    features = rng.normal(size=(n_samples, 5))  # Ignored
    target = rng.exponential(size=n_samples)

    candidate_quantiles = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]
    for alpha in candidate_quantiles:
        # The pinball loss is a loss (lower is better), hence
        # greater_is_better=False when turning it into a scorer.
        pinball_scorer = make_scorer(
            mean_pinball_loss, alpha=alpha, greater_is_better=False
        )

        # The quantile given here is only a starting value; the grid below
        # supplies the values that are actually evaluated.
        base_model = DummyRegressor(strategy="quantile", quantile=0.25)
        search = GridSearchCV(
            base_model,
            param_grid={"quantile": candidate_quantiles},
            scoring=pinball_scorer,
        )
        search.fit(features, target)

        selected_quantile = search.best_params_["quantile"]
        assert selected_quantile == pytest.approx(alpha)
0 commit comments