Skip to content

Commit eb22059

Browse files
committed
Add integration test
1 parent 728d632 commit eb22059

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

sklearn/metrics/tests/test_regression.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sklearn.utils._testing import assert_almost_equal
99
from sklearn.utils._testing import assert_array_equal
1010
from sklearn.utils._testing import assert_array_almost_equal
11+
from sklearn.dummy import DummyRegressor
12+
from sklearn.model_selection import GridSearchCV
1113

1214
from sklearn.metrics import explained_variance_score
1315
from sklearn.metrics import mean_absolute_error
@@ -19,6 +21,7 @@
1921
from sklearn.metrics import mean_pinball_loss
2022
from sklearn.metrics import r2_score
2123
from sklearn.metrics import mean_tweedie_deviance
24+
from sklearn.metrics import make_scorer
2225

2326
from sklearn.metrics._regression import _check_reg_targets
2427

@@ -409,3 +412,30 @@ def objective_func(x):
409412

410413
# The minimum is not unique with limited data, hence the tolerance.
411414
assert result.x == pytest.approx(target_quantile, abs=1e-3)
415+
416+
417+
def test_dummy_quantile_parameter_tuning():
418+
# Integration test to check that it is possible to use the pinball loss to
419+
# tune the hyperparameter of a quantile regressor. This is conceptually
420+
# similar to the previous test but using the scikit-learn estimator and
421+
# scoring API instead.
422+
n_samples = 1000
423+
rng = np.random.RandomState(0)
424+
X = rng.normal(size=(n_samples, 5)) # Ignored
425+
y = rng.exponential(size=n_samples)
426+
427+
all_quantiles = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]
428+
for alpha in all_quantiles:
429+
neg_mean_pinball_loss = make_scorer(
430+
mean_pinball_loss,
431+
alpha=alpha,
432+
greater_is_better=False,
433+
)
434+
regressor = DummyRegressor(strategy="quantile", quantile=0.25)
435+
grid_search = GridSearchCV(
436+
regressor,
437+
param_grid=dict(quantile=all_quantiles),
438+
scoring=neg_mean_pinball_loss,
439+
).fit(X, y)
440+
441+
assert grid_search.best_params_["quantile"] == pytest.approx(alpha)

0 commit comments

Comments
 (0)