GridSearchCV with SGDClassifier as estimator throws an error when calling predict_proba #7155

Closed
@yl565

Description

GridSearchCV with SGDClassifier as estimator throws an error when calling predict_proba.

Steps/Code to Reproduce

import numpy as np
from sklearn import grid_search
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import make_pipeline
from numpy.random import random, random_integers

X = random([300,1000])
y = random_integers(0, 1, [300, ])
param_grid = {
    'loss': ['log'],
    'penalty': ['elasticnet'],
    'alpha': [10 ** x for x in range(-6, 1)],
    'l1_ratio': [0, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 1],
}
clf = SGDClassifier(random_state=0, class_weight='balanced')

clf_grid = grid_search.GridSearchCV(estimator=clf, param_grid=param_grid,
                                    n_jobs=-1, scoring='roc_auc')

clf_grid.fit(X=X, y=y)
clf_grid.predict_proba(X)

Error message:

  File "<ipython-input-1-c4ceb6ddb8c1>", line 1, in <module>
    runfile('/home/yichuanliu/Dropbox/Python/Cognoma/test.py', wdir='/home/yichuanliu/Dropbox/Python/Cognoma')

  File "/home/yichuanliu/Programs/anaconda3/lib/python3.5/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 699, in runfile
    execfile(filename, namespace)

  File "/home/yichuanliu/Programs/anaconda3/lib/python3.5/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 88, in execfile
    exec(compile(open(filename, 'rb').read(), filename, 'exec'), namespace)

  File "/home/yichuanliu/Dropbox/Python/Cognoma/test.py", line 28, in <module>
    clf_grid.predict_proba(X)

  File "/home/yichuanliu/Programs/anaconda3/lib/python3.5/site-packages/sklearn/utils/metaestimators.py", line 35, in __get__
    self.get_attribute(obj)

  File "/home/yichuanliu/Programs/anaconda3/lib/python3.5/site-packages/sklearn/linear_model/stochastic_gradient.py", line 756, in predict_proba
    self._check_proba()

  File "/home/yichuanliu/Programs/anaconda3/lib/python3.5/site-packages/sklearn/linear_model/stochastic_gradient.py", line 721, in _check_proba
    " loss=%r" % self.loss)

AttributeError: probability estimates are not available for loss='hinge'

This is weird because the best_estimator_ clearly has loss 'log':

clf_grid.best_estimator_
Out[4]: 
SGDClassifier(alpha=0.1, average=False, class_weight='balanced', epsilon=0.1,
       eta0=0.0, fit_intercept=True, l1_ratio=0.1, learning_rate='optimal',
       loss='log', n_iter=5, n_jobs=1, penalty='elasticnet', power_t=0.5,
       random_state=0, shuffle=True, verbose=0, warm_start=False)
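
One possible reading of the traceback, sketched below with toy classes rather than scikit-learn code: the error passes through a `__get__` in metaestimators.py and reports loss='hinge', which is the default loss of the unfitted `clf` passed to GridSearchCV, not the loss of best_estimator_. The sketch assumes the search object touches `predict_proba` on that unfitted template before delegating. `ToySGD` and `ToySearch` are hypothetical names introduced only for illustration.

```python
class ToySGD:
    """Hypothetical stand-in for an estimator whose predict_proba is a property."""

    def __init__(self, loss="hinge"):
        self.loss = loss

    @property
    def predict_proba(self):
        # The check runs on attribute access, before any call happens.
        if self.loss != "log":
            raise AttributeError(
                "probability estimates are not available for loss=%r" % self.loss
            )
        return lambda X: [[0.5, 0.5] for _ in X]


class ToySearch:
    """Hypothetical meta-estimator that consults the unfitted template first."""

    def __init__(self, estimator):
        self.estimator = estimator  # template, keeps its default loss
        self.best_estimator_ = None

    def fit(self, loss):
        # Pretend the search picked this loss from the parameter grid.
        self.best_estimator_ = ToySGD(loss=loss)
        return self

    def predict_proba(self, X):
        # Assumed behaviour: the attribute is looked up on the template,
        # which raises even though best_estimator_ would be fine.
        self.estimator.predict_proba
        return self.best_estimator_.predict_proba(X)


search = ToySearch(ToySGD()).fit(loss="log")

try:
    search.predict_proba([[0.0]])
    message = None
except AttributeError as exc:
    message = str(exc)

# Going through the fitted estimator directly avoids the template check.
direct = search.best_estimator_.predict_proba([[0.0]])
```

If this reading holds for the real classes, calling `clf_grid.best_estimator_.predict_proba(X)` directly may serve as a workaround until the delegation is fixed.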

If loss and penalty are set on the estimator itself instead of in param_grid, everything works as expected:

param_grid = {
    'alpha': [10 ** x for x in range(-6, 1)],
    'l1_ratio': [0, 0.05, 0.1, 0.2, 0.5, 0.8, 0.9, 0.95, 1],
}
clf = SGDClassifier(random_state=0, class_weight='balanced',
                    loss='log', penalty='elasticnet')

clf_grid = grid_search.GridSearchCV(estimator=clf, param_grid=param_grid,
                                    n_jobs=-1, scoring='roc_auc')
clf_grid.fit(X=X, y=y)
clf_grid.predict_proba(X)

The problem goes away if the property decorator is removed from predict_proba in class SGDClassifier:
Original:

class SGDClassifier(BaseSGDClassifier, _LearntSelectorMixin):
...
    @property
    def predict_proba(self):
        """Probability estimates.
        ...
        """
        self._check_proba()
        return self._predict_proba

New:

class SGDClassifier(BaseSGDClassifier, _LearntSelectorMixin):
...
    def predict_proba(self, X):
        """Probability estimates.
        ...
        """
        self._check_proba()
        return self._predict_proba(X)
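
The practical difference between the two versions can be seen in isolation. In the minimal sketch below, `WithProperty` and `WithMethod` are hypothetical classes that mimic the two shapes above. With the property, the loss check runs at attribute lookup, so `hasattr` reports the method as missing. With the plain method, lookup always succeeds and the check is deferred until the call.

```python
class _Base:
    """Hypothetical base class carrying the shared loss check."""

    def __init__(self, loss):
        self.loss = loss

    def _check_proba(self):
        if self.loss != "log":
            raise AttributeError(
                "probability estimates are not available for loss=%r" % self.loss
            )

    def _predict_proba(self, X):
        # Dummy probabilities; a real estimator would compute these.
        return [[0.5, 0.5] for _ in X]


class WithProperty(_Base):
    @property
    def predict_proba(self):
        # Check happens when the attribute is looked up.
        self._check_proba()
        return self._predict_proba


class WithMethod(_Base):
    def predict_proba(self, X):
        # Check happens only when the method is called.
        self._check_proba()
        return self._predict_proba(X)


# hasattr swallows AttributeError, so the property hides the method for hinge.
prop_visible = hasattr(WithProperty("hinge"), "predict_proba")
meth_visible = hasattr(WithMethod("hinge"), "predict_proba")

try:
    WithMethod("hinge").predict_proba([[0.0]])
    call_failed = False
except AttributeError:
    call_failed = True
```

The property form lets callers use `hasattr` to detect whether probabilities are available for the current loss, which is presumably why it was written that way. Removing it trades that duck-typing check for an error raised only at call time.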

Versions

Linux-3.19.0-64-generic-x86_64-with-debian-jessie-sid
Python 3.5.2 |Anaconda 2.4.1 (64-bit)| (default, Jul 2 2016, 17:53:06)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
NumPy 1.11.0
SciPy 0.17.0
Scikit-Learn 0.18.dev0
