Description
In lines 300-310 of ensemble/partial_dependence, an error message is incorrectly specified:
```python
try:
    for fxs in features:
        l = []
        # explicit loop so "i" is bound for exception below
        for i in fxs:
            l.append(feature_names[i])
        names.append(l)
except IndexError:
    raise ValueError('features[i] must be in [0, n_features) '
                     'but was %d' % i)
```
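The loop can be run on its own, without scikit-learn, to confirm where the exception comes from. This is a minimal sketch that wraps the excerpt above in a hypothetical function `collect_names`. It assumes `features` has already been converted to sequences of integer indices, as the surrounding library code does, and it reuses the values from the reproduction below.

```python
def collect_names(features, feature_names):
    # Hypothetical wrapper around the excerpt above. `features` is assumed to
    # be a list of index sequences, e.g. [[1], [2]] for features=[1, 2].
    names = []
    try:
        for fxs in features:
            l = []
            # explicit loop so "i" is bound for the exception below
            for i in fxs:
                # This lookup is the only statement in the try block that can
                # raise IndexError; `features` itself is never indexed here.
                l.append(feature_names[i])
            names.append(l)
    except IndexError:
        raise ValueError('features[i] must be in [0, n_features) '
                         'but was %d' % i)
    return names


try:
    collect_names([[1], [2]], ['IQ', 'Athletic Ability'])
except ValueError as exc:
    print(exc)  # features[i] must be in [0, n_features) but was 2
```

Index 1 resolves to `'Athletic Ability'`, and index 2 then fails on the two-element `feature_names` list, even though 2 is a valid column index for a three-column `X`.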
Note that the `IndexError` can only be raised by `feature_names[i]`, not by `features[i]` as the message claims. I had initially expected the `feature_names` array to correspond to the `features` array (in particular, that they would have the same length). Since I only wanted partial dependence plots for 10 of the 60 features, both arrays I passed had length 10. I ran into the above error because `feature_names` did not have length 60. The wording of the error led me to look at `features[i]`, which did not help resolve the issue.

The error should clearly indicate that the index error comes from `feature_names[i]`, to save debugging time.
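One possible shape for a clearer message is sketched below. This is a hypothetical variant (the function name `collect_names_clearer` and the exact wording are illustrative, not scikit-learn code). It keeps the original `try`/`except IndexError` structure and assumes `features` is already a list of integer index sequences. The only change is that the message names `feature_names` and reports its length.

```python
def collect_names_clearer(features, feature_names):
    # Hypothetical variant of the excerpt above; not scikit-learn's code.
    names = []
    try:
        for fxs in features:
            l = []
            # explicit loop so "i" is bound for the exception below
            for i in fxs:
                l.append(feature_names[i])
            names.append(l)
    except IndexError:
        # Name the argument that was actually indexed, and report its length
        raise ValueError('feature_names must have one entry per feature of X; '
                         'index %d is out of range for feature_names of '
                         'length %d' % (i, len(feature_names)))
    return names
```

With the inputs from the reproduction, this would report index 2 against a `feature_names` of length 2, which points directly at the argument that needs fixing.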
Steps/Code to Reproduce
You can run the following code to see the error:
```python
import numpy as np
from sklearn.ensemble.partial_dependence import plot_partial_dependence
from sklearn.ensemble import GradientBoostingClassifier

# 8 samples, 3 features, so n_features == X.shape[1] == 3
X = np.arange(0, 24, 1).reshape((8, 3))
y = np.ones(8)
y[0] = 0  # make sure both classes are present
model = GradientBoostingClassifier()
model.fit(X, y)

features = [1, 2]
# Only 2 names for the 2 plotted features, so feature_names[2] raises IndexError
feature_names = ['IQ', 'Athletic Ability']
pdp_fig = plot_partial_dependence(model, X, features=features,
                                  feature_names=feature_names)
```
Again, note that the problem is caused by the length of the feature_names array, not the features array.
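Until the message is fixed, a workaround consistent with the analysis above is to pass `feature_names` with one entry per column of `X`, so that the integer indices in `features` select the intended names. The helper below, `expand_feature_names`, is hypothetical and not part of scikit-learn. It assumes `features` contains only integer column indices, and it fills every unselected column with a placeholder name.

```python
def expand_feature_names(n_features, features, selected_names):
    # Hypothetical helper, not a scikit-learn API. Returns a list of length
    # n_features where position features[k] holds selected_names[k] and every
    # other column gets a placeholder name "x<column>".
    if len(features) != len(selected_names):
        raise ValueError('features and selected_names must have the same length')
    full = ['x%d' % j for j in range(n_features)]
    for idx, name in zip(features, selected_names):
        full[idx] = name
    return full


# For the reproduction above: X has 3 columns and features = [1, 2]
print(expand_feature_names(3, [1, 2], ['IQ', 'Athletic Ability']))
# ['x0', 'IQ', 'Athletic Ability']
```

The resulting list could then be passed as `feature_names=` in place of the two-element list. For the 10-of-60 case described above, `n_features` would be 60.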
Expected Results
I'd expect the error to correctly indicate which parameter is causing the index error.
Actual Results
Instead, the error reads `features[i] must be in [0, n_features) but was 2`. In this case `n_features == X.shape[1]`, which is 3, so the message claims that 2 is not in [0, 3), making users question their sanity.
Versions
Darwin-15.5.0-x86_64-i386-64bit
('Python', '2.7.12 |Anaconda custom (x86_64)| (default, Jul 2 2016, 17:43:17) \n[GCC 4.2.1 (Based on Apple Inc. build 5658) (LLVM build 2336.11.00)]')
('NumPy', '1.11.1')
('SciPy', '0.17.1')
('Scikit-Learn', '0.17.1')