
[MRG] Added FutureWarning in sgd models for tol parameter #12399


Merged: 4 commits, Oct 24, 2018
6 changes: 6 additions & 0 deletions sklearn/decomposition/tests/test_kernel_pca.py
@@ -174,6 +174,8 @@ def test_kernel_pca_invalid_kernel():


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_gridsearch_pipeline():
# Test if we can do a grid-search to find parameters to separate
# circles with a perceptron model.
@@ -189,6 +191,8 @@ def test_gridsearch_pipeline():


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_gridsearch_pipeline_precomputed():
# Test if we can do a grid-search to find parameters to separate
# circles with a perceptron model using a precomputed kernel.
@@ -204,6 +208,8 @@ def test_gridsearch_pipeline_precomputed():
assert_equal(grid_search.best_score_, 1)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_nested_circles():
# Test the linear separability of the first 2D KPCA transform
X, y = make_circles(n_samples=400, factor=.3, noise=.05,
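The @pytest.mark.filterwarnings lines added above, and repeated throughout the other test modules in this diff, all suppress the new FutureWarning introduced in stochastic_gradient.py further down. A minimal sketch of the two ways the PR applies that filter, per test function and per module; the test body and its warnings.warn call are made-up stand-ins, and only the filter string itself comes from this PR:

    import warnings
    import pytest

    # Module-level variant (used in test_sgd.py): the filter applies to every
    # test collected from this file.
    pytestmark = pytest.mark.filterwarnings(
        'ignore:max_iter and tol parameters have been')


    # Per-test variant (used in the other touched test modules).
    @pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
    def test_example():
        # Stand-in for fitting an SGD model with max_iter set but tol unset;
        # any warning whose message starts with the filtered prefix is ignored.
        warnings.warn(
            "max_iter and tol parameters have been added in SGDClassifier in "
            "0.19. ...", FutureWarning)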
8 changes: 8 additions & 0 deletions sklearn/feature_selection/tests/test_from_model.py
@@ -25,6 +25,8 @@
rng = np.random.RandomState(0)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_invalid_input():
clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True,
random_state=None, tol=None)
@@ -243,6 +245,8 @@ def test_2d_coef():


@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_partial_fit():
est = PassiveAggressiveClassifier(random_state=0, shuffle=False,
max_iter=5, tol=None)
@@ -273,6 +277,8 @@ def test_calling_fit_reinitializes():
assert_equal(transformer.estimator_.C, 100)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_prefit():
# Test all possible combinations of the prefit parameter.

@@ -311,6 +317,8 @@ def test_threshold_string():
assert_array_almost_equal(X_transform, data[:, mask])


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_threshold_without_refitting():
# Test that the threshold can be set without refitting the model.
clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True,
12 changes: 6 additions & 6 deletions sklearn/kernel_approximation.py
@@ -52,14 +52,14 @@ class RBFSampler(BaseEstimator, TransformerMixin):
>>> y = [0, 0, 1, 1]
>>> rbf_feature = RBFSampler(gamma=1, random_state=1)
>>> X_features = rbf_feature.fit_transform(X)
>>> clf = SGDClassifier(max_iter=5)
>>> clf = SGDClassifier(max_iter=5, tol=1e-3)
>>> clf.fit(X_features, y)
... # doctest: +NORMALIZE_WHITESPACE
SGDClassifier(alpha=0.0001, average=False, class_weight=None,
early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=5,
n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',
power_t=0.5, random_state=None, shuffle=True, tol=None,
power_t=0.5, random_state=None, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> clf.score(X_features, y)
1.0
@@ -162,13 +162,13 @@ class SkewedChi2Sampler(BaseEstimator, TransformerMixin):
... n_components=10,
... random_state=0)
>>> X_features = chi2_feature.fit_transform(X, y)
>>> clf = SGDClassifier(max_iter=10)
>>> clf = SGDClassifier(max_iter=10, tol=1e-3)
>>> clf.fit(X_features, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None,
early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=10,
n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',
power_t=0.5, random_state=None, shuffle=True, tol=None,
power_t=0.5, random_state=None, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> clf.score(X_features, y)
1.0
@@ -282,13 +282,13 @@ class AdditiveChi2Sampler(BaseEstimator, TransformerMixin):
>>> X, y = load_digits(return_X_y=True)
>>> chi2sampler = AdditiveChi2Sampler(sample_steps=2)
>>> X_transformed = chi2sampler.fit_transform(X, y)
>>> clf = SGDClassifier(max_iter=5, random_state=0)
>>> clf = SGDClassifier(max_iter=5, random_state=0, tol=1e-3)
>>> clf.fit(X_transformed, y)
SGDClassifier(alpha=0.0001, average=False, class_weight=None,
early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=5,
n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',
power_t=0.5, random_state=0, shuffle=True, tol=None,
power_t=0.5, random_state=0, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> clf.score(X_transformed, y) # doctest: +ELLIPSIS
0.9543...
14 changes: 8 additions & 6 deletions sklearn/linear_model/passive_aggressive.py
@@ -139,17 +139,18 @@ class PassiveAggressiveClassifier(BaseSGDClassifier):
>>> from sklearn.datasets import make_classification

>>> X, y = make_classification(n_features=4, random_state=0)
>>> clf = PassiveAggressiveClassifier(max_iter=1000, random_state=0)
>>> clf = PassiveAggressiveClassifier(max_iter=1000, random_state=0,
... tol=1e-3)
>>> clf.fit(X, y)
PassiveAggressiveClassifier(C=1.0, average=False, class_weight=None,
early_stopping=False, fit_intercept=True, loss='hinge',
max_iter=1000, n_iter=None, n_iter_no_change=5, n_jobs=None,
random_state=0, shuffle=True, tol=None,
random_state=0, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> print(clf.coef_)
[[0.29509834 0.33711843 0.56127352 0.60105546]]
[[-0.6543424 1.54603022 1.35361642 0.22199435]]
>>> print(clf.intercept_)
[2.54153383]
[0.63310933]
>>> print(clf.predict([[0, 0, 0, 0]]))
[1]

@@ -377,12 +378,13 @@ class PassiveAggressiveRegressor(BaseSGDRegressor):
>>> from sklearn.datasets import make_regression

>>> X, y = make_regression(n_features=4, random_state=0)
>>> regr = PassiveAggressiveRegressor(max_iter=100, random_state=0)
>>> regr = PassiveAggressiveRegressor(max_iter=100, random_state=0,
... tol=1e-3)
>>> regr.fit(X, y)
PassiveAggressiveRegressor(C=1.0, average=False, early_stopping=False,
epsilon=0.1, fit_intercept=True, loss='epsilon_insensitive',
max_iter=100, n_iter=None, n_iter_no_change=5,
random_state=0, shuffle=True, tol=None,
random_state=0, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)
>>> print(regr.coef_)
[20.48736655 34.18818427 67.59122734 87.94731329]
22 changes: 18 additions & 4 deletions sklearn/linear_model/stochastic_gradient.py
@@ -166,6 +166,20 @@ def _validate_params(self, set_max_iter=True, for_partial_fit=False):
# Before 0.19, default was n_iter=5
max_iter = 5
else:
if self.tol is None:
# max_iter was set, but tol wasn't. The docs / warning do not
# specify this case. In 0.20 tol stays None, which
# is equivalent to -inf, but it will be changed to 1e-3 in
# 0.21. We warn users that the behaviour (and potentially
# their results) will change.
warnings.warn(
"max_iter and tol parameters have been added in %s in "
"0.19. If max_iter is set but tol is left unset, the "
"default value for tol in 0.19 and 0.20 will be None "
[Review comment (Member)] I think the semantics of None here are unclear. Is None equivalent to 0?

[Review comment (Member)] "is 0"?

"(which is equivalent to -infinity, so it has no effect) "
"but will change in 0.21 to 1e-3. Specify tol to "
"silence this warning." % type(self).__name__,
FutureWarning)
max_iter = self.max_iter if self.max_iter is not None else 1000
self._max_iter = max_iter
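In user code, the branch added above fires whenever max_iter is passed but tol is left at its default of None. A minimal sketch of what that looks like and how to silence it, assuming scikit-learn 0.20 with this patch applied (the toy data is made up for illustration):

    import warnings
    from sklearn.linear_model import SGDClassifier

    X = [[-1., -1.], [-2., -1.], [1., 1.], [2., 1.]]
    y = [1, 1, 2, 2]

    # max_iter set, tol unset -> FutureWarning: tol's effective default changes
    # from None (no stopping criterion on the training loss) to 1e-3 in 0.21.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        SGDClassifier(max_iter=1000).fit(X, y)
    print(any(issubclass(w.category, FutureWarning) for w in caught))   # True

    # Passing tol explicitly pins the behaviour and silences the warning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        SGDClassifier(max_iter=1000, tol=1e-3).fit(X, y)
    print(any(issubclass(w.category, FutureWarning) for w in caught))   # False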

@@ -937,14 +951,14 @@ class SGDClassifier(BaseSGDClassifier):
>>> from sklearn import linear_model
>>> X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
>>> Y = np.array([1, 1, 2, 2])
>>> clf = linear_model.SGDClassifier(max_iter=1000)
>>> clf = linear_model.SGDClassifier(max_iter=1000, tol=1e-3)
>>> clf.fit(X, Y)
... #doctest: +NORMALIZE_WHITESPACE
SGDClassifier(alpha=0.0001, average=False, class_weight=None,
early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,
l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=1000,
n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',
power_t=0.5, random_state=None, shuffle=True, tol=None,
power_t=0.5, random_state=None, shuffle=True, tol=0.001,
validation_fraction=0.1, verbose=0, warm_start=False)

>>> print(clf.predict([[-0.8, -1]]))
@@ -1540,14 +1554,14 @@ class SGDRegressor(BaseSGDRegressor):
>>> np.random.seed(0)
>>> y = np.random.randn(n_samples)
>>> X = np.random.randn(n_samples, n_features)
>>> clf = linear_model.SGDRegressor(max_iter=1000)
>>> clf = linear_model.SGDRegressor(max_iter=1000, tol=1e-3)
>>> clf.fit(X, y)
... #doctest: +NORMALIZE_WHITESPACE
SGDRegressor(alpha=0.0001, average=False, early_stopping=False,
epsilon=0.1, eta0=0.01, fit_intercept=True, l1_ratio=0.15,
learning_rate='invscaling', loss='squared_loss', max_iter=1000,
n_iter=None, n_iter_no_change=5, penalty='l2', power_t=0.25,
random_state=None, shuffle=True, tol=None, validation_fraction=0.1,
random_state=None, shuffle=True, tol=0.001, validation_fraction=0.1,
verbose=0, warm_start=False)

See also
3 changes: 3 additions & 0 deletions sklearn/linear_model/tests/test_huber.py
@@ -3,6 +3,7 @@

import numpy as np
from scipy import optimize, sparse
import pytest

from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_equal
@@ -140,6 +141,8 @@ def test_huber_scaling_invariant():
assert_array_equal(n_outliers_mask_3, n_outliers_mask_1)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_huber_and_sgd_same_results():
# Test they should converge to same coefficients for same parameters

24 changes: 24 additions & 0 deletions sklearn/linear_model/tests/test_passive_aggressive.py
@@ -69,6 +69,8 @@ def project(self, X):
return np.dot(X, self.w) + self.b


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_classifier_accuracy():
for data in (X, X_csr):
for fit_intercept in (True, False):
@@ -86,6 +88,8 @@ def test_classifier_accuracy():
assert_true(hasattr(clf, 'standard_coef_'))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_classifier_partial_fit():
classes = np.unique(y)
for data in (X, X_csr):
@@ -104,6 +108,8 @@ def test_classifier_partial_fit():
assert_true(hasattr(clf, 'standard_coef_'))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_classifier_refit():
# Classifier can be retrained on different labels and features.
clf = PassiveAggressiveClassifier(max_iter=5).fit(X, y)
@@ -113,6 +119,8 @@ def test_classifier_refit():
assert_array_equal(clf.classes_, iris.target_names)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
@pytest.mark.parametrize('loss', ("hinge", "squared_hinge"))
def test_classifier_correctness(loss):
y_bin = y.copy()
@@ -137,6 +145,8 @@ def test_classifier_undefined_methods():
assert_raises(AttributeError, lambda x: getattr(clf, x), meth)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_class_weights():
# Test class weights.
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
@@ -159,12 +169,16 @@ def test_class_weights():
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([-1]))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_partial_fit_weight_class_balanced():
# partial_fit with class_weight='balanced' not supported
clf = PassiveAggressiveClassifier(class_weight="balanced", max_iter=100)
assert_raises(ValueError, clf.partial_fit, X, y, classes=np.unique(y))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_equal_class_weight():
X2 = [[1, 0], [1, 0], [0, 1], [0, 1]]
y2 = [0, 0, 1, 1]
@@ -186,6 +200,8 @@ def test_equal_class_weight():
assert_almost_equal(clf.coef_, clf_balanced.coef_, decimal=2)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_wrong_class_weight_label():
# ValueError due to wrong class_weight label.
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
@@ -196,6 +212,8 @@ def test_wrong_class_weight_label():
assert_raises(ValueError, clf.fit, X2, y2)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_wrong_class_weight_format():
# ValueError due to wrong class_weight argument type.
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0],
@@ -209,6 +227,8 @@ def test_wrong_class_weight_format():
assert_raises(ValueError, clf.fit, X2, y2)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_regressor_mse():
y_bin = y.copy()
y_bin[y != 1] = -1
@@ -229,6 +249,8 @@ def test_regressor_mse():
assert_true(hasattr(reg, 'standard_coef_'))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_regressor_partial_fit():
y_bin = y.copy()
y_bin[y != 1] = -1
@@ -249,6 +271,8 @@ def test_regressor_partial_fit():
assert_true(hasattr(reg, 'standard_coef_'))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
@pytest.mark.parametrize(
'loss',
("epsilon_insensitive", "squared_epsilon_insensitive"))
5 changes: 5 additions & 0 deletions sklearn/linear_model/tests/test_perceptron.py
@@ -1,5 +1,6 @@
import numpy as np
import scipy.sparse as sp
import pytest

from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_greater
@@ -43,6 +44,8 @@ def predict(self, X):
return np.sign(self.project(X))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_perceptron_accuracy():
for data in (X, X_csr):
clf = Perceptron(max_iter=100, tol=None, shuffle=False)
Expand All @@ -51,6 +54,8 @@ def test_perceptron_accuracy():
assert_greater(score, 0.7)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_perceptron_correctness():
y_bin = y.copy()
y_bin[y != 1] = -1
12 changes: 10 additions & 2 deletions sklearn/linear_model/tests/test_sgd.py
@@ -31,6 +31,11 @@
from sklearn.model_selection import RandomizedSearchCV


# 0.23. warning about tol not having its correct default value.
pytestmark = pytest.mark.filterwarnings(
"ignore:max_iter and tol parameters have been")
[Review comment (Member, PR author)] There were so many warnings in this file that I globally ignored them for the whole module.


class SparseSGDClassifier(SGDClassifier):

def fit(self, X, y, *args, **kw):
@@ -1330,8 +1335,11 @@ def init(max_iter=None, tol=None, n_iter=None, for_partial_fit=False):
msg_deprecation = "n_iter parameter is deprecated"
assert_warns_message(DeprecationWarning, msg_deprecation, init, 6, 0, 5)

# When n_iter=None, and at least one of tol and max_iter is specified
assert_no_warnings(init, 100, None, None)
# When n_iter=None and max_iter is specified but tol=None
msg_changed = "If max_iter is set but tol is left unset"
assert_warns_message(FutureWarning, msg_changed, init, 100, None, None)

# When n_iter=None and tol is specified
assert_no_warnings(init, None, 1e-3, None)
assert_no_warnings(init, 100, 1e-3, None)
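The same cases could be written with plain pytest instead of sklearn's assert_warns_message / assert_no_warnings helpers; a rough, hypothetical equivalent in which a direct SGDClassifier construction stands in for the test's init helper:

    import pytest
    from sklearn.linear_model import SGDClassifier

    X = [[0., 0.], [1., 1.]]
    y = [0, 1]

    # max_iter given, tol left unset -> the new FutureWarning is raised on fit.
    with pytest.warns(FutureWarning,
                      match="If max_iter is set but tol is left unset"):
        SGDClassifier(max_iter=100).fit(X, y)

    # Giving tol explicitly should not raise the FutureWarning.
    with pytest.warns(None) as record:
        SGDClassifier(max_iter=100, tol=1e-3).fit(X, y)
    assert not any(issubclass(w.category, FutureWarning) for w in record)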

4 changes: 4 additions & 0 deletions sklearn/model_selection/tests/test_validation.py
@@ -1100,6 +1100,8 @@ def test_learning_curve_incremental_learning_unsupervised():
np.linspace(0.1, 1.0, 10))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_learning_curve_batch_and_incremental_learning_are_equal():
X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
n_redundant=0, n_classes=2,
@@ -1167,6 +1169,8 @@ def test_learning_curve_with_boolean_indices():
np.linspace(0.1, 1.0, 10))


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_learning_curve_with_shuffle():
# Following test case was designed this way to verify the code
# changes made in pull request: #7506.