[MRG] TST use global_random_seed in sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py #23559
Conversation
Could you change the following test:

```python
@pytest.mark.parametrize("loss", ("log_loss", "exponential"))
def test_classification_synthetic(loss, global_random_seed):
    # Test GradientBoostingClassifier on the synthetic dataset used by
    # Hastie et al. in ESLII - Figure 10.9
    X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=global_random_seed)
    X_train, X_test = X[:2000], X[2000:]
    y_train, y_test = y[:2000], y[2000:]

    # Increasing the number of trees should decrease the test error
    common_params = {
        "max_depth": 1,
        "learning_rate": 1.0,
        "loss": loss,
        "random_state": global_random_seed,
    }
    gbrt_100_stumps = GradientBoostingClassifier(n_estimators=100, **common_params)
    gbrt_100_stumps.fit(X_train, y_train)

    gbrt_200_stumps = GradientBoostingClassifier(n_estimators=200, **common_params)
    gbrt_200_stumps.fit(X_train, y_train)

    assert gbrt_100_stumps.score(X_test, y_test) < gbrt_200_stumps.score(X_test, y_test)

    # Decision stumps are better suited for this dataset with a large number
    # of estimators.
    common_params = {
        "n_estimators": 200,
        "learning_rate": 1.0,
        "loss": loss,
        "random_state": global_random_seed,
    }
    gbrt_stumps = GradientBoostingClassifier(max_depth=1, **common_params)
    gbrt_stumps.fit(X_train, y_train)

    gbrt_10_nodes = GradientBoostingClassifier(max_leaf_nodes=10, **common_params)
    gbrt_10_nodes.fit(X_train, y_train)

    assert gbrt_stumps.score(X_test, y_test) > gbrt_10_nodes.score(X_test, y_test)
```

I checked the ESL book and it was not obvious what the original test was checking. When I changed the random state, the test failed. On closer reading, the book advocates that stumps are more suitable than deeper trees for this dataset (p. 363). So I modified the test accordingly and made sure that it passes with different random states.
You can add the global random seed to
Another test to modify, relaxing the MSE thresholds:

```python
def test_regression_synthetic(global_random_seed):
    # Test on the synthetic regression datasets used in Leo Breiman,
    # "Bagging Predictors". Machine Learning 24(2): 123-140 (1996).
    random_state = check_random_state(global_random_seed)
    regression_params = {
        "n_estimators": 100,
        "max_depth": 4,
        "min_samples_split": 2,
        "learning_rate": 0.1,
        "loss": "squared_error",
    }

    # Friedman1
    X, y = datasets.make_friedman1(n_samples=1200, random_state=random_state, noise=1.0)
    X_train, y_train = X[:200], y[:200]
    X_test, y_test = X[200:], y[200:]

    clf = GradientBoostingRegressor()
    clf.fit(X_train, y_train)
    mse = mean_squared_error(y_test, clf.predict(X_test))
    assert mse < 5.5

    # Friedman2
    X, y = datasets.make_friedman2(n_samples=1200, random_state=random_state)
    X_train, y_train = X[:200], y[:200]
    X_test, y_test = X[200:], y[200:]

    clf = GradientBoostingRegressor(**regression_params)
    clf.fit(X_train, y_train)
    mse = mean_squared_error(y_test, clf.predict(X_test))
    assert mse < 2500.0

    # Friedman3
    X, y = datasets.make_friedman3(n_samples=1200, random_state=random_state)
    X_train, y_train = X[:200], y[:200]
    X_test, y_test = X[200:], y[200:]

    clf = GradientBoostingRegressor(**regression_params)
    clf.fit(X_train, y_train)
    mse = mean_squared_error(y_test, clf.predict(X_test))
    assert mse < 0.025
```
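For context, `check_random_state` (used in the test above) normalizes `None`, an int, or an existing RNG into a usable random generator. A rough standard-library analogue (the function name is hypothetical, and it uses `random.Random` rather than NumPy's `RandomState`) could look like:

```python
import numbers
import random

def check_random_state_sketch(seed):
    # Accept None (fresh RNG), an int (seeded RNG), or an existing
    # random.Random instance (returned unchanged), mirroring the spirit
    # of sklearn.utils.check_random_state.
    if seed is None:
        return random.Random()
    if isinstance(seed, numbers.Integral):
        return random.Random(seed)
    if isinstance(seed, random.Random):
        return seed
    raise ValueError(f"{seed!r} cannot be used to seed a random generator")
```

Passing the same int seed yields reproducible streams, which is what lets the fixture make each test run deterministic per seed.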
You can add it to the test
Oops, I see that I did not put my comment on the right test file.
No problem, thank you
`test_multinomial_deviance`, `test_init_raw_predictions_values`, `test_lad_equals_quantiles`
Thanks for the PR @haochunchang. LGTM
Reference Issues/PRs
Towards #22827
What does this implement/fix? Explain your changes.
Add the `global_random_seed` fixture to 3 tests, including `test_lad_equals_quantiles`, and set the seed to 0-99 to cover all possible seeds.

Any other comments?
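To run a test suite over a seed range like 0-99, scikit-learn selects seeds through the `SKLEARN_TESTS_GLOBAL_RANDOM_SEED` environment variable. As a simplified sketch (the function name and the exact accepted forms here are assumptions for illustration, not scikit-learn's implementation), parsing such a seed specification could look like:

```python
def parse_seed_spec(spec: str) -> list[int]:
    # Accept a single seed ("42"), an inclusive range ("40-42"),
    # or "all" (interpreted here as the full 0-99 range).
    if spec == "all":
        return list(range(100))
    if "-" in spec:
        lo, hi = spec.split("-")
        return list(range(int(lo), int(hi) + 1))
    return [int(spec)]
```

The parametrized fixture then runs each seed-aware test once per returned seed, e.g. invoking pytest with the environment variable set to `"all"` exercises all 100 seeds.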