Skip to content

[MRG+1] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able #7594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 10, 2016

Conversation

raghavrv
Copy link
Member

@raghavrv raghavrv commented Oct 6, 2016

Fixes #7562

  • Subclasses the np.ma.MaskedArray and overrides the __getstate__ to make obj dtyped MaskedArrays pickle-able.
  • Uses this fixed utils.fixes.MaskedArray inside gs.cv_results_...

This is based off of numpy/numpy#8122

Please review @jnothman @amueller @GaelVaroquaux @davechallis

@raghavrv raghavrv added this to the 0.18.1 milestone Oct 6, 2016
@raghavrv raghavrv changed the title [MRG] FIX Make sure GridSearchCV and RandomizedSearchCV are picke-able [MRG] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able Oct 6, 2016
@raghavrv
Copy link
Member Author

raghavrv commented Oct 6, 2016

[The flake8 test will fail because of the utils.fixes import]

@lesteve
Copy link
Member

lesteve commented Oct 7, 2016

[The flake8 test will fail because of the utils.fixes import]

If you really care you can do something like this to silence the warning:

from numpy.ma import MaskedArray  # noqa

Or we can just merge like this obviously.

Copy link
Member

@lesteve lesteve left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides the minor comment, LGTM


if np_version < (1, 12, 0):
class MaskedArray(np.ma.MaskedArray):
# Before numpy 1.12, np.ma.MaskedArray object is not pickle-able
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would replace pickle-able by picklable which is an acceptable adjective I think (it is definitely already mentioned a few times in the source code).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@raghavrv
Copy link
Member Author

raghavrv commented Oct 7, 2016

from numpy.ma import MaskedArray # noqa

Ah... Cool. I recall seeing this somewhere!

@raghavrv
Copy link
Member Author

raghavrv commented Oct 7, 2016

Besides the minor comment, LGTM

Thanks for the review @lesteve!!

Copy link
Member

@amueller amueller left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check whether the pickle tests in estimator_checks.py also load?


random_search = RandomizedSearchCV(clf, {'foo_param': [1, 2, 3]},
refit=True, n_iter=3)
random_search.fit(X, y)
pickle.dumps(random_search) # smoke test
pickle.loads(pickle.dumps(random_search))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe call predict afterwards?

@@ -940,12 +940,12 @@ def test_pickle():
clf = MockClassifier()
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, refit=True)
grid_search.fit(X, y)
pickle.dumps(grid_search) # smoke test
pickle.loads(pickle.dumps(grid_search))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe call predict afterwards?

@raghavrv raghavrv force-pushed the pickleable_masked_array branch from 5596417 to 82e4d0a Compare October 10, 2016 05:14
@raghavrv
Copy link
Member Author

@amueller Thanks for the review. Have rebased and addressed your comment... The meta tests don't seem to test pickle :/ I'll confirm tomorrow...

@raghavrv raghavrv changed the title [MRG] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able [MRG+1] FIX Make sure GridSearchCV and RandomizedSearchCV are pickle-able Oct 10, 2016
@amueller
Copy link
Member

LGTM

@amueller amueller merged commit 868a58b into scikit-learn:master Oct 10, 2016
@raghavrv raghavrv deleted the pickleable_masked_array branch October 10, 2016 19:36
amueller added a commit to amueller/scikit-learn that referenced this pull request Oct 14, 2016
…able (scikit-learn#7594)

* FIX Subclass a new MaskedArray which allows pickling even when dype=object

* TST unpickling too

* FIX Use MaskedArray from utils.fixes rather than from numpy

* FIX imports

* Don't assign a variable

* FIX np --> numpy

* Use tostring instead of tobytes for old numpy

* COSMIT pickle-able --> picklable

* use #noqa comment to turn off flake8

* TST/ENH Check if the pickled est's predict matches with the original one's

# Conflicts:
#	sklearn/utils/tests/test_fixes.py
@amueller
Copy link
Member

needs a whatsnew :-/

Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
…able (scikit-learn#7594)

* FIX Subclass a new MaskedArray which allows pickling even when dype=object

* TST unpickling too

* FIX Use MaskedArray from utils.fixes rather than from numpy

* FIX imports

* Don't assign a variable

* FIX np --> numpy

* Use tostring instead of tobytes for old numpy

* COSMIT pickle-able --> picklable

* use #noqa comment to turn off flake8

* TST/ENH Check if the pickled est's predict matches with the original one's
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
…able (scikit-learn#7594)

* FIX Subclass a new MaskedArray which allows pickling even when dype=object

* TST unpickling too

* FIX Use MaskedArray from utils.fixes rather than from numpy

* FIX imports

* Don't assign a variable

* FIX np --> numpy

* Use tostring instead of tobytes for old numpy

* COSMIT pickle-able --> picklable

* use #noqa comment to turn off flake8

* TST/ENH Check if the pickled est's predict matches with the original one's
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error unpickling RandomizedSearchCV objects in 0.18 due to masked arrays
3 participants