[MRG+1] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able #7594
[The flake8 test will fail because of the utils.fixes import]
If you really care you can do something like this to silence the warning:
Or we can just merge like this obviously.
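The snippet the comment above refers to did not survive on this page. A later commit note in this thread mentions using a "#noqa" comment to turn off flake8, so a minimal sketch of that general mechanism might look like the following. The imported names are stand-ins chosen so the example runs with the standard library alone; they are not the imports touched by this PR.

```python
# flake8 reports F401 ("imported but unused") when a name is imported
# but never referenced in the module, e.g. an import kept for re-export.
# A trailing "# noqa" comment asks flake8 to skip every check on that line.
from collections import OrderedDict  # noqa

# The narrower form suppresses only the named error code.
from fractions import Fraction  # noqa: F401
```

The bare form is the simplest, while the coded form avoids hiding unrelated warnings on the same line.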
Besides the minor comment, LGTM
if np_version < (1, 12, 0):
    class MaskedArray(np.ma.MaskedArray):
        # Before numpy 1.12, np.ma.MaskedArray object is not pickle-able
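As an illustration of the general technique in the hunk above (subclassing a type and overriding its pickling hooks so instances can be serialized), here is a standard-library-only sketch. It is an analogy, not the code from this PR: `Counter`, `PicklableCounter`, and `roundtrip` are hypothetical names, and a `threading.Lock` stands in for the state that the default pickling path cannot handle.

```python
import pickle
import threading


class Counter:
    """Hypothetical base class; the lock makes default pickling fail."""

    def __init__(self, value=0):
        self.value = value
        self._lock = threading.Lock()  # locks cannot be pickled

    def increment(self):
        with self._lock:
            self.value += 1


class PicklableCounter(Counter):
    """Subclass that fixes pickling by overriding the state hooks."""

    def __getstate__(self):
        # Copy the instance dict and drop the member pickle cannot serialize.
        state = self.__dict__.copy()
        del state["_lock"]
        return state

    def __setstate__(self, state):
        # Restore the plain attributes and recreate the lock from scratch.
        self.__dict__.update(state)
        self._lock = threading.Lock()


def roundtrip(obj):
    """Serialize and immediately deserialize obj."""
    return pickle.loads(pickle.dumps(obj))
```

Callers that switch to the subclass keep the same interface, which is presumably why a drop-in replacement class is attractive for a compatibility fix.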
I would replace pickle-able with picklable, which I think is an acceptable adjective (it definitely already appears a few times in the source code).
Done!
Ah... Cool. I recall seeing this somewhere!
Thanks for the review @lesteve!!
Can you check whether the pickle tests in estimator_checks.py also load?
    random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]},
                                       refit=True, n_iter=3)
    random_search.fit(X, y)
    pickle.dumps(random_search)  # smoke test
    pickle.loads(pickle.dumps(random_search))
maybe call predict afterwards?
@@ -940,12 +940,12 @@ def test_pickle():
    clf = MockClassifier()
    grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
    grid_search.fit(X, y)
    pickle.dumps(grid_search)  # smoke test
    pickle.loads(pickle.dumps(grid_search))
maybe call predict afterwards?
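The suggestion in the two review comments above (calling predict after unpickling) can be sketched without scikit-learn as follows. `MajorityClassifier` and `check_pickle_roundtrip` are hypothetical names introduced for this illustration; the actual tests use the project's own mock estimator and the search objects shown in the hunks.

```python
import pickle


class MajorityClassifier:
    """Hypothetical stand-in estimator, not the one used in the test suite."""

    def fit(self, X, y):
        # "Learn" the most frequent label seen during fit.
        self.majority_ = max(set(y), key=list(y).count)
        return self

    def predict(self, X):
        # Predict the stored majority label for every sample.
        return [self.majority_ for _ in X]


def check_pickle_roundtrip(estimator, X):
    """Return True if the unpickled estimator predicts like the original."""
    restored = pickle.loads(pickle.dumps(estimator))
    return restored.predict(X) == estimator.predict(X)


X = [[0], [1], [2], [3]]
y = [0, 1, 1, 1]
clf = MajorityClassifier().fit(X, y)
```

Comparing predictions catches cases where unpickling succeeds but the restored object has lost fitted state, which a bare `pickle.loads` smoke test would miss.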
5596417 to 82e4d0a
@amueller Thanks for the review. Have rebased and addressed your comment... The meta tests don't seem to test pickle :/ I'll confirm tomorrow...
LGTM
…able (scikit-learn#7594)
* FIX Subclass a new MaskedArray which allows pickling even when dtype=object
* TST unpickling too
* FIX Use MaskedArray from utils.fixes rather than from numpy
* FIX imports
* Don't assign a variable
* FIX np --> numpy
* Use tostring instead of tobytes for old numpy
* COSMIT pickle-able --> picklable
* use #noqa comment to turn off flake8
* TST/ENH Check if the pickled est's predict matches with the original one's
# Conflicts:
# sklearn/utils/tests/test_fixes.py
needs a whatsnew :-/
Fixes #7562
* Subclasses np.ma.MaskedArray and overrides __getstate__ to make object-dtyped MaskedArrays pickle-able.
* Uses utils.fixes.MaskedArray inside gs.cv_results_...
This is based off of numpy/numpy#8122
Please review @jnothman @amueller @GaelVaroquaux @davechallis