Issues with the validation of ax in plot_partial_dependence #15757

@NicolasHug

The validation of the ax parameter to plot_partial_dependence is weirdly inconsistent.

  • plot_partial_dependence only validates the length of ax if it's a list. But plt.subplots typically returns an array.
  • the Display object checks len, while it should probably check size. Typically, if I pass in an ax with shape (2, 2) and only plot 2 PDPs, the display will not complain. However, if I ravel this ax to get a shape of (4,), the Display will error.
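
To make the len vs. size difference concrete, here is a small NumPy-only sketch. It uses an empty object array as a stand-in for the array of Axes that plt.subplots returns (an assumption made for illustration; no plotting happens), and n_features is just an illustrative name for the number of PDPs requested.

```python
import numpy as np

# Stand-in for the (2, 2) array of Axes returned by
# plt.subplots(nrows=2, ncols=2, squeeze=False)
axes = np.empty((2, 2), dtype=object)

n_features = 2  # number of PDPs requested (illustrative name)

# len() only counts the first dimension, so a length-based check passes
len_matches = len(axes) == n_features

# size counts every element, so a size-based check flags the mismatch
size_matches = axes.size == n_features

# After raveling, len() sees all 4 axes and the length check fails
raveled = np.ravel(axes)
raveled_len_matches = len(raveled) == n_features

print(len_matches, size_matches, raveled_len_matches)  # True False False
```

So the same four axes pass or fail a len-based check depending only on their shape.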

I think we need a little consistency on the validation.
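
One possible shape for a consistent check, as a rough sketch only: convert whatever is passed (list or array) to an object array and compare its size to the number of features. check_axes_size is a hypothetical helper name, not an existing scikit-learn function, and the error message is made up.

```python
import numpy as np


def check_axes_size(ax, n_features):
    """Hypothetical helper: validate lists and arrays of axes the same way."""
    # Lists and ndarrays both become an object array, so one rule covers both
    axes = np.asarray(ax, dtype=object)
    if axes.size != n_features:
        raise ValueError(
            f"Expected ax to contain {n_features} axes, got {axes.size}."
        )
    # Return a flat view so downstream code does not depend on the shape
    return axes.ravel()
```

With a check like this, a (2, 2) array and its raveled (4,) version would be treated identically when only 2 PDPs are requested.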

CC @thomasjpfan

from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.datasets import make_regression
from sklearn.inspection import plot_partial_dependence
import matplotlib.pyplot as plt
import numpy as np
# IPython/Jupyter magic; remove this line when running as a plain script
%matplotlib inline


X, y = make_regression()
est = HistGradientBoostingRegressor()
est.fit(X, y)

fig, axes = plt.subplots(nrows=2, ncols=2, squeeze=False)

# uncomment this to get an error
# axes = np.ravel(axes)

plot_partial_dependence(est, X, features=[1, 2], ax=axes)
