diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py
index 21f37f3e28f98..15e8e15e87569 100644
--- a/sklearn/inspection/_plot/decision_boundary.py
+++ b/sklearn/inspection/_plot/decision_boundary.py
@@ -294,9 +294,16 @@ def from_estimator(
             np.linspace(x0_min, x0_max, grid_resolution),
             np.linspace(x1_min, x1_max, grid_resolution),
         )
+        if hasattr(X, "iloc"):
+            # we need to preserve the feature names and therefore get an empty dataframe
+            X_grid = X.iloc[[], :].copy()
+            X_grid.iloc[:, 0] = xx0.ravel()
+            X_grid.iloc[:, 1] = xx1.ravel()
+        else:
+            X_grid = np.c_[xx0.ravel(), xx1.ravel()]
 
         pred_func = _check_boundary_response_method(estimator, response_method)
-        response = pred_func(np.c_[xx0.ravel(), xx1.ravel()])
+        response = pred_func(X_grid)
 
         # convert classes predictions into integers
         if pred_func.__name__ == "predict" and hasattr(estimator, "classes_"):
diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
index 955deb33331d6..786c57571864f 100644
--- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
+++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
@@ -1,3 +1,5 @@
+import warnings
+
 import pytest
 import numpy as np
 from numpy.testing import assert_allclose
@@ -265,6 +267,11 @@ def test_multioutput_regressor_error(pyplot):
         DecisionBoundaryDisplay.from_estimator(tree, X)
 
 
+@pytest.mark.filterwarnings(
+    # We expect to raise the following warning because the classifier is fit on a
+    # NumPy array
+    "ignore:X has feature names, but LogisticRegression was fitted without"
+)
 def test_dataframe_labels_used(pyplot, fitted_clf):
     """Check that column names are used for pandas."""
     pd = pytest.importorskip("pandas")
@@ -319,3 +326,20 @@ def test_string_target(pyplot):
         grid_resolution=5,
         response_method="predict",
     )
+
+
+def test_dataframe_support(pyplot):
+    """Check that passing a dataframe at fit and to the Display does not
+    raise warnings.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/23311
+    """
+    pd = pytest.importorskip("pandas")
+    df = pd.DataFrame(X, columns=["col_x", "col_y"])
+    estimator = LogisticRegression().fit(df, y)
+
+    with warnings.catch_warnings():
+        # no warnings linked to feature names validation should be raised
+        warnings.simplefilter("error", UserWarning)
+        DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict")