
Conversation

@siftikha (Contributor) commented Nov 1, 2017

Fixes #10051

Added an optional argument for_partial_fit=False to _validate_params, which bypasses the warnings about max_iter and tol.

Added for_partial_fit=True to the _validate_params calls in the partial_fit methods of both SGDClassifier and SGDRegressor.

I appreciate that this could have been done through the existing set_max_iter=False flag, but it seemed clearer to me to have a dedicated one.

This is my first contribution to this project, so I apologize if I've done something horrifically wrong.
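For illustration, a minimal sketch of the shape of the change (names simplified and the class hypothetical; this is not the actual scikit-learn code):

```python
import warnings


class SGDEstimatorSketch:
    """Hypothetical stand-in for the shared SGD base class."""

    def __init__(self, max_iter=None, tol=None):
        self.max_iter = max_iter
        self.tol = tol

    def _validate_params(self, for_partial_fit=False):
        # partial_fit runs one epoch per call, so the max_iter/tol
        # deprecation warning is irrelevant there and is skipped.
        if self.max_iter is None and self.tol is None and not for_partial_fit:
            warnings.warn("max_iter and tol will change their defaults; "
                          "set them explicitly to silence this warning.",
                          FutureWarning)
        # ... remaining validation unchanged ...

    def fit(self, X, y):
        self._validate_params()                      # may warn
        return self

    def partial_fit(self, X, y):
        self._validate_params(for_partial_fit=True)  # never warns
        return self
```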

@jnothman (Member) commented Nov 1, 2017

You have flake8 errors

@siftikha (Contributor, Author) commented Nov 1, 2017

flake8 issues have been fixed

@@ -538,7 +539,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
         -------
         self : returns an instance of self.
         """
-        self._validate_params()
+        self._validate_params(for_partial_fit=True)
Member (review comment):

I think using set_max_iter=False should suffice

Contributor (Author):

Ok, I will make that change.

Member:

This hasn't been addressed and it seems cleaner not to add another parameter if it's possible.

Member:

It was attempted, but was not trivial. We could have a different value for the parameter, but I figure it's hardly worth the bother for a deprecation.


@@ -1211,6 +1211,9 @@ def init(max_iter=None, tol=None, n_iter=None):
     assert_no_warnings(init, None, 1e-3, None)
     assert_no_warnings(init, 100, 1e-3, None)
 
+    # Test that for_partial_fit will not throw warnings for max_iter or tol
+    assert_no_warnings(init, None, None, None, True)

Member:

Why not test whether partial_fit itself raises a warning? Testing the helper is a bit strange, even if it seems that's already what we're doing...
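An end-to-end version of that test might look roughly like this (a sketch against the 0.19-era API, assuming the assert_no_warnings helper from sklearn.utils.testing; not the test that was merged):

```python
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.utils.testing import assert_no_warnings

X = np.array([[0.0, 0.0], [1.0, 1.0]])
y = np.array([0, 1])

# max_iter and tol are left unset, which would trigger the deprecation
# warning in fit(); partial_fit should stay silent.
clf = SGDClassifier()
assert_no_warnings(clf.partial_fit, X, y, np.array([0, 1]))  # classes positional
```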

@siftikha (Contributor, Author) commented Nov 2, 2017

@jnothman the set_max_iter approach caused other issues: making it work with partial_fit would require non-trivial modifications to other parts of the library, so I've gone back to the for_partial_fit approach.

@siftikha force-pushed the partial_max_iter_tol branch from e255cb7 to 9b624ea on November 2, 2017 at 02:34
@jnothman (Member) commented Nov 2, 2017

What other issues?

@siftikha (Contributor, Author) commented Nov 2, 2017

https://travis-ci.org/scikit-learn/scikit-learn/jobs/296053432 is the result of using just the set_max_iter=False flag with everything else unchanged. The use case for that flag seems to be different from what we are aiming for here.

The issue is that set_max_iter=False also keeps self._tol from being set, which causes failures when partial_fit is run. You could conceivably ensure self._tol is set even with set_max_iter=False, but that would likely break the other cases that rely on set_max_iter=False.

@jnothman (Member) commented Nov 2, 2017 via email

@qinhanmin2014 mentioned this pull request Nov 6, 2017
@siftikha (Contributor, Author) commented Nov 6, 2017

If there is a sufficiently compelling reason to stick to one parameter for both of these cases, I'm happy to do that, but I think such a parameter would need to be renamed from set_max_iter to something more general that describes the validation required.

@jnothman (Member) commented Nov 6, 2017 via email

@siftikha (Contributor, Author) commented Nov 9, 2017

Do I need to add [MRG] to the title for this to be merged?

@siftikha changed the title from "fix for erroneous max_iter and tol warnings for SGDClassifier when using partial_fit" to "[MRG] fix for erroneous max_iter and tol warnings for SGDClassifier when using partial_fit" Nov 9, 2017
@siftikha (Contributor, Author):

@jnothman anything else you need to get this merged?

@jnothman (Member) left a comment

LGTM. We usually require approval from two core devs.

@jnothman changed the title from "[MRG] fix for erroneous max_iter and tol warnings for SGDClassifier when using partial_fit" to "[MRG+1] fix for erroneous max_iter and tol warnings for SGDClassifier when using partial_fit" Nov 13, 2017
@siftikha (Contributor, Author):

@amueller Figured I'd ping you, since you opened the initial issue.

@amueller (Member):

lgtm.

@amueller merged commit f485a9e into scikit-learn:master Nov 15, 2017
@jnothman (Member):

Thanks for contributing, @siftikha!

@amueller (Member):

Hm, this undid #10050 :-/

jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
… when using partial_fit (scikit-learn#10053)

* partial fit warnings disabled

* partial fit warnings disabled for regressor

* style improved

* tests added

* pycodestyle passing

* rejiggered format

* fixed style issues