Closed
Description
The validation of the `ax` parameter to `plot_partial_dependence` is weirdly inconsistent:

- `plot_partial_dependence` only validates the length of `ax` if it's a list. But `plt.subplots` typically returns an array.
- The Display object checks `len(ax)`, while it should probably check `ax.size`. For example, if I pass in an `ax` with shape (2, 2) and only plot 2 PDPs, the Display will not complain. However, if I ravel this `ax` to get shape (4,), the Display raises an error.
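To make the `len` vs `size` distinction concrete, here is a minimal NumPy-only illustration. It assumes that an `np.empty(..., dtype=object)` array is a fair stand-in for the array of Axes returned by `plt.subplots`; `n_features` is an illustrative name for the number of requested PDPs.

```python
import numpy as np

n_features = 2  # two PDPs requested, as in the reproducer below

# Stand-in for the 2 x 2 array of Axes returned by plt.subplots(2, 2)
axes_2d = np.empty((2, 2), dtype=object)
axes_1d = np.ravel(axes_2d)

# len() only counts the first dimension; .size counts every Axes
print(len(axes_2d), axes_2d.size)  # 2 4
print(len(axes_1d), axes_1d.size)  # 4 4

# So a len()-based check accepts the 2-D grid but rejects its ravelled form,
# even though both hold exactly the same four Axes.
print(len(axes_2d) == n_features)  # True
print(len(axes_1d) == n_features)  # False
```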
I think we need some consistency in the validation.
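One possible shape for consistent validation, sketched under assumptions: `check_axes` is a hypothetical helper (not an existing scikit-learn function) that treats lists, nested lists and arrays the same way by converting them to an object array and comparing the total number of Axes to the number of features. Requiring strict equality is itself an assumption; a looser rule (e.g. `size >= n_features`) would be an equally valid design choice.

```python
import numpy as np


def check_axes(ax, n_features):
    """Hypothetical helper: validate ``ax`` by its total number of Axes.

    Lists, nested lists and ndarrays of any shape are handled identically,
    so a (2, 2) grid and its ravelled (4,) form give the same result.
    """
    axes = np.asarray(ax, dtype=object)
    if axes.size != n_features:
        raise ValueError(
            f"Expected ax to contain {n_features} Axes, got {axes.size} "
            f"(shape {axes.shape})."
        )
    # Return a flat array so downstream code can index Axes one per feature
    return axes.ravel()
```

With this rule, passing the (2, 2) grid from the reproducer for 2 PDPs fails in the same way as passing its ravelled (4,) version, which is the consistency asked for above.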
CC @thomasjpfan
```python
# Enables the (then experimental) HistGradientBoosting estimators
from sklearn.experimental import enable_hist_gradient_boosting  # noqa: F401
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.datasets import make_regression
from sklearn.inspection import plot_partial_dependence
import matplotlib.pyplot as plt
import numpy as np

# %matplotlib inline  # Jupyter/IPython magic; keep commented when running as a plain script

X, y = make_regression()
est = HistGradientBoostingRegressor()
est.fit(X, y)

# 2 x 2 grid of Axes, but only 2 features are plotted below
fig, axes = plt.subplots(nrows=2, ncols=2, squeeze=False)
# uncomment this to get an error (len(axes) becomes 4 instead of 2)
# axes = np.ravel(axes)
plot_partial_dependence(est, X, features=[1, 2], ax=axes)
```