Skip to content

Commit a631424

Browse files
NicolasHugadrinjalali
authored andcommitted
FIX fix bug when warm starting with early stopping in Hist GBDT (#15624)
1 parent 1236388 commit a631424

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

+2
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def fit(self, X, y):
270270

271271
# Compute raw predictions
272272
raw_predictions = self._raw_predict(X_binned_train)
273+
if self.do_early_stopping_ and self._use_validation_data:
274+
raw_predictions_val = self._raw_predict(X_binned_val)
273275

274276
if self.do_early_stopping_ and self.scoring != 'loss':
275277
# Compute the subsample set

sklearn/ensemble/_hist_gradient_boosting/tests/test_warm_start.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,15 @@ def test_warm_start_max_depth(GradientBoosting, X, y):
9393
(HistGradientBoostingClassifier, X_classification, y_classification),
9494
(HistGradientBoostingRegressor, X_regression, y_regression)
9595
])
96-
def test_warm_start_early_stopping(GradientBoosting, X, y):
96+
@pytest.mark.parametrize('scoring', (None, 'loss'))
97+
def test_warm_start_early_stopping(GradientBoosting, X, y, scoring):
9798
# Make sure that early stopping occurs after a small number of iterations
9899
# when fitting a second time with warm starting.
99100

100101
n_iter_no_change = 5
101102
gb = GradientBoosting(
102103
n_iter_no_change=n_iter_no_change, max_iter=10000,
103-
random_state=42, warm_start=True, tol=1e-3
104+
random_state=42, warm_start=True, tol=1e-3, scoring=scoring,
104105
)
105106
gb.fit(X, y)
106107
n_iter_first_fit = gb.n_iter_

0 commit comments

Comments
 (0)