
[MRG+2] Early stopping for Gradient Boosting Classifier/Regressor #7071


Merged: 69 commits, Aug 9, 2017
Commits
8b1ca1b
Added GradientBoostingClassifierCV with unit tests
vighneshbirodkar Nov 2, 2015
06e6fc9
ENH Allow using Randomized Search
raghavrv Jul 20, 2016
fdb2452
GradientBoostingRegressorCV with unit tests.
raghavrv Jul 20, 2016
1d6a893
ENH Permit disabling early stopping to exhaustively search thro max_i…
raghavrv Jul 29, 2016
93a57e9
TST The predict and predict_proba should be equal for clf
raghavrv Jul 29, 2016
c7e3581
ENH Preserve random state across warm starts?
raghavrv Jul 29, 2016
2e85db8
FIX include the n_stop_rounds in the final estimator count
raghavrv Aug 5, 2016
82acb91
Docfixes
raghavrv Aug 5, 2016
4774f47
Fix random state, precision
raghavrv Aug 7, 2016
bf58074
Comment out tests that are numpy/scipy dependent
raghavrv Aug 7, 2016
c03a5b6
Remove scaffolding
raghavrv Aug 8, 2016
bd29681
Switch to threading.
raghavrv Aug 8, 2016
687c916
Skip all n_estimator comparison tests
raghavrv Aug 8, 2016
bae46ea
Remove Randomized param search
raghavrv Aug 8, 2016
e78e3ed
Reword the example doc
raghavrv Aug 8, 2016
3441d36
Add licence and authors
raghavrv Aug 8, 2016
ec2ec77
[ci skip] Reword dostring
raghavrv Aug 8, 2016
66337df
[ci skip] n_stop_rounds --> n_iter_no_change
raghavrv Aug 8, 2016
aafe6f0
Minor fix
raghavrv Aug 16, 2016
e8a3465
Use _BaseGradientBoostingCV class
raghavrv Aug 23, 2016
68c2165
COSMIT
raghavrv Aug 23, 2016
d942fd9
remove pre_despatch; Add criterion and min_impurity_split
raghavrv Aug 23, 2016
5654791
Use cloned instance of estimator rather than re-initializing them
raghavrv Aug 31, 2016
4bf3819
selfnote: don't be stupid
raghavrv Aug 31, 2016
43fac20
PEP8
raghavrv Sep 1, 2016
a89ea13
stash
raghavrv Sep 1, 2016
163b9b1
Use GridSearchCV like cv_results_
raghavrv Sep 1, 2016
049d50a
Merge branch 'master' into gbcv
raghavrv Jan 18, 2017
17e3d44
Merge branch 'master' into gbcv
raghavrv Jan 23, 2017
5a34758
Merge branch 'master' into gbcv
raghavrv Jul 5, 2017
0202fa6
Remove gradient_boosting_cv.py + tests
raghavrv Jul 5, 2017
ffbd59c
squash
raghavrv Jul 7, 2017
2bdbc2c
Enable early stopping
raghavrv Jul 7, 2017
bd2800f
TST gradient boosting without early stopping
raghavrv Jul 7, 2017
f61fcf8
Rename example and fix it up for the changes
raghavrv Jul 9, 2017
9f7a131
Make example simpler
raghavrv Jul 10, 2017
41927a3
Add n_est=
raghavrv Jul 10, 2017
93cc6d6
Add plot for fit times too
raghavrv Jul 10, 2017
521b3be
Flake8 fix
raghavrv Jul 10, 2017
b599496
Merge branch 'master' into gbcv
raghavrv Jul 10, 2017
3cffae8
Merge branch 'master' into gbcv
raghavrv Jul 11, 2017
1315aba
Flake8
raghavrv Jul 11, 2017
ebedf80
use assert_equal to know n_estimators value in the failing builds
raghavrv Jul 11, 2017
89eb1ef
Flake8
raghavrv Jul 11, 2017
95aa54b
More flake*
raghavrv Jul 11, 2017
fcd0ad5
Use different dataset to see if that passes tests smoothly
raghavrv Jul 11, 2017
f607c68
Flake8 again
raghavrv Jul 11, 2017
9e0b4dd
Update doc
raghavrv Jul 11, 2017
ca424f3
Add comment to clarify the usage of generator for y_val_pred
raghavrv Jul 14, 2017
b8b5e64
Add whatsnew entry
raghavrv Jul 15, 2017
4c180b1
Merge branch 'master' into gbcv
raghavrv Jul 15, 2017
90d6acc
Remove spurious whatsnew entry
raghavrv Jul 15, 2017
4b5be91
DOC add other early stopping params in GradientBoostingRegressorCV
raghavrv Jul 16, 2017
72187a0
Address Joel's comments in example doc
raghavrv Jul 16, 2017
73b2dac
Address Joel's reviews
raghavrv Jul 18, 2017
3849cfd
TST test the validation_fraction and n_iter_no_change params
raghavrv Jul 24, 2017
e33e5b4
Double backticks
raghavrv Jul 24, 2017
9ab41e9
Address Joel's comments
raghavrv Jul 24, 2017
724ac33
Use clone to improve readability
raghavrv Jul 24, 2017
5c7ab8b
flake8
raghavrv Jul 25, 2017
6fc0851
DOC Move whatsnew entry into Classifiers and Regressors section
raghavrv Jul 26, 2017
434b5a4
Merge branch 'master' into gbcv
raghavrv Jul 26, 2017
baf2eb7
Fix indentation
raghavrv Jul 26, 2017
a68fcf6
fix indentation
raghavrv Jul 27, 2017
95695b8
Merge branch 'master' into gbcv
raghavrv Jul 27, 2017
8ef7b32
fix single backtick
raghavrv Jul 27, 2017
0e99698
Merge branch 'master' into gbcv
raghavrv Aug 9, 2017
0aa0a1a
Move the new entry under 0.20
raghavrv Aug 9, 2017
c0dc1d8
Update version information
raghavrv Aug 9, 2017
20 changes: 20 additions & 0 deletions doc/whats_new.rst
@@ -5,6 +5,26 @@
Release history
===============

Version 0.20 (under development)
================================

Changed models
--------------

Changelog
---------

New features
............

Classifiers and regressors

- :class:`ensemble.GradientBoostingClassifier` and
:class:`ensemble.GradientBoostingRegressor` now support early stopping
via ``n_iter_no_change``, ``validation_fraction`` and ``tol``. :issue:`7071`
by `Raghav RV`_
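
A minimal sketch of how the new options can be used (the dataset and the
specific parameter values below are purely illustrative)::

    from sklearn.datasets import make_hastie_10_2
    from sklearn.ensemble import GradientBoostingClassifier

    X, y = make_hastie_10_2(random_state=0)
    # Hold out 10% of the training data for validation and stop adding
    # stages once the validation score has not improved by at least tol
    # for 5 consecutive stages.
    clf = GradientBoostingClassifier(n_estimators=1000,
                                     validation_fraction=0.1,
                                     n_iter_no_change=5, tol=1e-4,
                                     random_state=0)
    clf.fit(X, y)
    print(clf.n_estimators_)  # number of stages actually fitted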


Version 0.19
============

160 changes: 160 additions & 0 deletions examples/ensemble/plot_gradient_boosting_early_stopping.py
@@ -0,0 +1,160 @@
"""
===================================
Early stopping of Gradient Boosting
===================================

Gradient boosting is an ensemble technique in which several weak learners
(regression trees) are combined in an iterative fashion to yield a powerful
single model.

Early stopping support in Gradient Boosting enables us to find the smallest
number of iterations that is sufficient to build a model that generalizes
well to unseen data.

The concept of early stopping is simple. We specify a ``validation_fraction``
which denotes the fraction of the whole dataset that will be kept aside from
training to assess the validation loss of the model. The gradient boosting
model is trained using the training set and evaluated using the validation set.
Each time an additional stage of regression tree is added, the validation set
is used to score the model. This continues until the scores of the model in
the last ``n_iter_no_change`` stages have not improved by at least ``tol``
(a rough sketch of this rule appears just after this docstring). At that
point the model is considered to have converged and further addition of
stages is "stopped early".

The number of stages of the final model is available in the attribute
``n_estimators_``.

This example illustrates how early stopping can be used in the
:class:`sklearn.ensemble.GradientBoostingClassifier` model to achieve
almost the same accuracy as a model built without early stopping, while
using many fewer estimators. This can significantly reduce training time,
memory usage and prediction latency.
"""

# Authors: Vighnesh Birodkar <vighneshbirodkar@nyu.edu>
# Raghav RV <rvraghav93@gmail.com>
# License: BSD 3 clause

import time

import numpy as np
import matplotlib.pyplot as plt

from sklearn import ensemble
from sklearn import datasets
from sklearn.model_selection import train_test_split

print(__doc__)

data_list = [datasets.load_iris(), datasets.load_digits()]
data_list = [(d.data, d.target) for d in data_list]
data_list += [datasets.make_hastie_10_2()]
names = ['Iris Data', 'Digits Data', 'Hastie Data']

n_gb = []
score_gb = []
time_gb = []
n_gbes = []
score_gbes = []
time_gbes = []

n_estimators = 500

for X, y in data_list:
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
                                                         random_state=0)

    # We specify that if the scores do not improve by at least 0.01 for the
    # last 5 stages, fitting of additional stages is stopped
    gbes = ensemble.GradientBoostingClassifier(n_estimators=n_estimators,
                                               validation_fraction=0.2,
                                               n_iter_no_change=5, tol=0.01,
                                               random_state=0)
    gb = ensemble.GradientBoostingClassifier(n_estimators=n_estimators,
                                             random_state=0)
    start = time.time()
    gb.fit(X_train, y_train)
    time_gb.append(time.time() - start)

    start = time.time()
    gbes.fit(X_train, y_train)
    time_gbes.append(time.time() - start)

    score_gb.append(gb.score(X_test, y_test))
    score_gbes.append(gbes.score(X_test, y_test))

    n_gb.append(gb.n_estimators_)
    n_gbes.append(gbes.n_estimators_)

bar_width = 0.2
n = len(data_list)
index = np.arange(0, n * bar_width, bar_width) * 2.5
index = index[0:n]

#######################################################################
# Compare scores with and without early stopping
# ----------------------------------------------

plt.figure(figsize=(9, 5))

bar1 = plt.bar(index, score_gb, bar_width, label='Without early stopping',
color='crimson')
bar2 = plt.bar(index + bar_width, score_gbes, bar_width,
label='With early stopping', color='coral')

max_y = np.amax(np.maximum(score_gb, score_gbes))

plt.xticks(index + bar_width, names)
plt.yticks(np.arange(0, 1.3, 0.1))


def autolabel(rects, n_estimators):
    """
    Attach a text label above each bar displaying n_estimators of each model
    """
    for i, rect in enumerate(rects):
        plt.text(rect.get_x() + rect.get_width() / 2.,
                 1.05 * rect.get_height(), 'n_est=%d' % n_estimators[i],
                 ha='center', va='bottom')


autolabel(bar1, n_gb)
autolabel(bar2, n_gbes)

plt.ylim([0, 1.3])
plt.legend(loc='best')
plt.grid(True)

plt.xlabel('Datasets')
plt.ylabel('Test score')

plt.show()


#######################################################################
# Compare fit times with and without early stopping
# --------------------------------------------------

plt.figure(figsize=(9, 5))

bar1 = plt.bar(index, time_gb, bar_width, label='Without early stopping',
color='crimson')
bar2 = plt.bar(index + bar_width, time_gbes, bar_width,
label='With early stopping', color='coral')

max_y = np.amax(np.maximum(time_gb, time_gbes))

plt.xticks(index + bar_width, names)
plt.yticks(np.linspace(0, 1.3 * max_y, 13))

autolabel(bar1, n_gb)
autolabel(bar2, n_gbes)

plt.ylim([0, 1.3 * max_y])
plt.legend(loc='best')
plt.grid(True)

plt.xlabel('Datasets')
plt.ylabel('Fit time (s)')

plt.show()