Skip to content

Commit f485a9e

Browse files
siftikhaamueller
authored andcommitted
[MRG+1] fix for erroneous max_iter and tol warnings for SGDClassifier when using partial_fit (#10053)
* partial fit warnings disabled * partial fit warnings disabled for regressor * style improved * tests added * pycodestyle passing * rejiggered format * fixed style issues
1 parent 653de6c commit f485a9e

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

sklearn/linear_model/stochastic_gradient.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def set_params(self, *args, **kwargs):
8282
def fit(self, X, y):
8383
"""Fit model."""
8484

85-
def _validate_params(self, set_max_iter=True):
85+
def _validate_params(self, set_max_iter=True, for_partial_fit=False):
8686
"""Validate input params. """
8787
if not isinstance(self.shuffle, bool):
8888
raise ValueError("shuffle must be either True or False")
@@ -120,14 +120,15 @@ def _validate_params(self, set_max_iter=True):
120120
self._tol = None
121121

122122
elif self.tol is None and self.max_iter is None:
123-
warnings.warn(
124-
"max_iter and tol parameters have been added in %s in 0.19. If"
125-
" both are left unset, they default to max_iter=5 and tol=None"
126-
". If tol is not None, max_iter defaults to max_iter=1000. "
127-
"From 0.21, default max_iter will be 1000, "
128-
"and default tol will be 1e-3." % type(self).__name__,
129-
FutureWarning)
130-
# Before 0.19, default was n_iter=5
123+
if not for_partial_fit:
124+
warnings.warn(
125+
"max_iter and tol parameters have been "
126+
"added in %s in 0.19. If both are left unset, "
127+
"they default to max_iter=5 and tol=None. "
128+
"If tol is not None, max_iter defaults to max_iter=1000. "
129+
"From 0.21, default max_iter will be 1000, and"
130+
" default tol will be 1e-3." % type(self), FutureWarning)
131+
# Before 0.19, default was n_iter=5
131132
max_iter = 5
132133
else:
133134
max_iter = self.max_iter if self.max_iter is not None else 1000
@@ -539,7 +540,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
539540
-------
540541
self : returns an instance of self.
541542
"""
542-
self._validate_params()
543+
self._validate_params(for_partial_fit=True)
543544
if self.class_weight in ['balanced']:
544545
raise ValueError("class_weight '{0}' is not supported for "
545546
"partial_fit. In order to use 'balanced' weights,"
@@ -984,7 +985,7 @@ def partial_fit(self, X, y, sample_weight=None):
984985
-------
985986
self : returns an instance of self.
986987
"""
987-
self._validate_params()
988+
self._validate_params(for_partial_fit=True)
988989
return self._partial_fit(X, y, self.alpha, C=1.0,
989990
loss=self.loss,
990991
learning_rate=self.learning_rate, max_iter=1,

sklearn/linear_model/tests/test_sgd.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,9 +1194,9 @@ def test_tol_parameter():
11941194
def test_future_and_deprecation_warnings():
11951195
# Test that warnings are raised. Will be removed in 0.21
11961196

1197-
def init(max_iter=None, tol=None, n_iter=None):
1197+
def init(max_iter=None, tol=None, n_iter=None, for_partial_fit=False):
11981198
sgd = SGDClassifier(max_iter=max_iter, tol=tol, n_iter=n_iter)
1199-
sgd._validate_params()
1199+
sgd._validate_params(for_partial_fit=for_partial_fit)
12001200

12011201
# When all default values are used
12021202
msg_future = "max_iter and tol parameters have been added in "
@@ -1211,6 +1211,9 @@ def init(max_iter=None, tol=None, n_iter=None):
12111211
assert_no_warnings(init, None, 1e-3, None)
12121212
assert_no_warnings(init, 100, 1e-3, None)
12131213

1214+
# Test that for_partial_fit will not throw warnings for max_iter or tol
1215+
assert_no_warnings(init, None, None, None, True)
1216+
12141217

12151218
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
12161219
def test_tol_and_max_iter_default_values():

0 commit comments

Comments
 (0)