Skip to content

Commit 4408dfb

Browse files
authored
FIX Removes warning in HGBT when fitting on dataframes (#22908)
1 parent 67eb4e5 commit 4408dfb

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,11 @@ Changelog
391391
for instance using cgroups quota in a docker container. :pr:`22566` by
392392
:user:`Jérémie du Boisberranger <jeremiedbb>`.
393393

394+
- |Fix| :class:`ensemble.HistGradientBoostingClassifier` and
395+
:class:`ensemble.HistGradientBoostingRegressor` no longer warn when
396+
fitting on a pandas DataFrame with a non-default `scoring` parameter and
397+
`early_stopping` enabled. :pr:`22908` by `Thomas Fan`_.
398+
394399
:mod:`sklearn.feature_extraction`
395400
.................................
396401

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ...model_selection import train_test_split
3131
from ...preprocessing import LabelEncoder
3232
from ._gradient_boosting import _update_raw_predictions
33-
from .common import Y_DTYPE, X_DTYPE, X_BINNED_DTYPE, G_H_DTYPE
33+
from .common import Y_DTYPE, X_DTYPE, G_H_DTYPE
3434

3535
from .binning import _BinMapper
3636
from .grower import TreeGrower
@@ -936,8 +936,10 @@ def _raw_predict(self, X, n_threads=None):
936936
The raw predicted values.
937937
"""
938938
is_binned = getattr(self, "_in_fit", False)
939-
dtype = X_BINNED_DTYPE if is_binned else X_DTYPE
940-
X = self._validate_data(X, dtype=dtype, force_all_finite=False, reset=False)
939+
if not is_binned:
940+
X = self._validate_data(
941+
X, dtype=X_DTYPE, force_all_finite=False, reset=False
942+
)
941943
check_is_fitted(self)
942944
if X.shape[1] != self._n_features:
943945
raise ValueError(

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
import numpy as np
24
import pytest
35
from numpy.testing import assert_allclose, assert_array_equal
@@ -1141,3 +1143,20 @@ def test_loss_deprecated(old_loss, new_loss):
11411143
est2 = HistGradientBoostingRegressor(loss=new_loss, random_state=0)
11421144
est2.fit(X, y)
11431145
assert_allclose(est1.predict(X), est2.predict(X))
1146+
1147+
1148+
def test_no_user_warning_with_scoring():
1149+
"""Check that no UserWarning is raised when scoring is set.
1150+
1151+
Non-regression test for #22907.
1152+
"""
1153+
pd = pytest.importorskip("pandas")
1154+
X, y = make_regression(n_samples=50, random_state=0)
1155+
X_df = pd.DataFrame(X, columns=[f"col{i}" for i in range(X.shape[1])])
1156+
1157+
est = HistGradientBoostingRegressor(
1158+
random_state=0, scoring="neg_mean_absolute_error", early_stopping=True
1159+
)
1160+
with warnings.catch_warnings():
1161+
warnings.simplefilter("error", UserWarning)
1162+
est.fit(X_df, y)

0 commit comments

Comments
 (0)