From 96b3793c4063e7a310f7b4432dc35faa2d829e2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Tue, 17 Jul 2018 18:44:36 +0200 Subject: [PATCH 1/3] Fix caveat in ignore_warnings. When the warning class is passed as a positional argument, the function is not called. In a test settings this can mean that the test is not run. --- sklearn/utils/testing.py | 12 +++++++++++- sklearn/utils/tests/test_testing.py | 17 +++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index c67a314e2fc58..c90afa30369ca 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -275,6 +275,8 @@ def ignore_warnings(obj=None, category=Warning): Parameters ---------- + obj : callable or None + callable where you want to ignore the warnings. category : warning class, defaults to Warning. The category to filter. If Warning, all categories will be muted. @@ -290,7 +292,15 @@ def ignore_warnings(obj=None, category=Warning): >>> ignore_warnings(nasty_warn)() 42 """ - if callable(obj): + if isinstance(obj, type) and issubclass(obj, Warning): + # Avoid common pitfall of passing category as the first positional + # argument which result in the test not being run + raise ValueError( + "'obj' should be a callable where you want to ignore warnings. " + "You passed a warning class instead: 'obj={obj}'. If you want " + "to pass a warning class to ignore_warnings, you should use " + "'category={obj}'".format(obj=obj)) + elif callable(obj): return _IgnoreWarnings(category=category)(obj) else: return _IgnoreWarnings(category=category) diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py index eb9512f177ed3..404161c7ede15 100644 --- a/sklearn/utils/tests/test_testing.py +++ b/sklearn/utils/tests/test_testing.py @@ -8,6 +8,8 @@ from scipy import sparse +import pytest + from sklearn.utils.deprecation import deprecated from sklearn.utils.metaestimators import if_delegate_has_method from sklearn.utils.testing import ( @@ -210,6 +212,21 @@ def context_manager_no_user_multiple_warning(): assert_warns(UserWarning, context_manager_no_deprecation_multiple_warning) assert_warns(DeprecationWarning, context_manager_no_user_multiple_warning) + # Check that passing warning class as first positional argument + warning_class = UserWarning + match = "'obj' should be a callable.+you should use 'category={}'".format( + warning_class) + + with pytest.raises(ValueError, match=match): + silence_warnings_func = ignore_warnings(warning_class)( + _warning_function) + silence_warnings_func() + + with pytest.raises(ValueError, match=match): + @ignore_warnings(warning_class) + def test(): + pass + class TestWarns(unittest.TestCase): def test_warn(self): From bc98f587fd6b667e42fe9843b9ee1054d2dc2ad4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Tue, 17 Jul 2018 17:07:15 +0200 Subject: [PATCH 2/3] TST warning class should be passed as the category parameter. The first argument is the test callable. Doing it like this cause the test function to not be called --- sklearn/linear_model/tests/test_sgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index ee1c473707187..15e29ce4d41df 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -319,7 +319,7 @@ def test_validation_set_not_used_for_training(self): assert_array_equal(clf1.coef_, clf2.coef_) - @ignore_warnings(ConvergenceWarning) + @ignore_warnings(category=ConvergenceWarning) def test_n_iter_no_change(self): # test that n_iter_ increases monotonically with n_iter_no_change for early_stopping in [True, False]: From c36f29dfed5d7c0e9f819f7da5aa2e9f3695f3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Tue, 17 Jul 2018 22:08:20 +0200 Subject: [PATCH 3/3] Use the warning class name in message --- sklearn/utils/testing.py | 8 +++++--- sklearn/utils/tests/test_testing.py | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index c90afa30369ca..bfae5d4662b1c 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -295,11 +295,13 @@ def ignore_warnings(obj=None, category=Warning): if isinstance(obj, type) and issubclass(obj, Warning): # Avoid common pitfall of passing category as the first positional # argument which result in the test not being run + warning_name = obj.__name__ raise ValueError( "'obj' should be a callable where you want to ignore warnings. " - "You passed a warning class instead: 'obj={obj}'. If you want " - "to pass a warning class to ignore_warnings, you should use " - "'category={obj}'".format(obj=obj)) + "You passed a warning class instead: 'obj={warning_name}'. " + "If you want to pass a warning class to ignore_warnings, " + "you should use 'category={warning_name}'".format( + warning_name=warning_name)) elif callable(obj): return _IgnoreWarnings(category=category)(obj) else: diff --git a/sklearn/utils/tests/test_testing.py b/sklearn/utils/tests/test_testing.py index 404161c7ede15..729b5ef81c684 100644 --- a/sklearn/utils/tests/test_testing.py +++ b/sklearn/utils/tests/test_testing.py @@ -214,8 +214,7 @@ def context_manager_no_user_multiple_warning(): # Check that passing warning class as first positional argument warning_class = UserWarning - match = "'obj' should be a callable.+you should use 'category={}'".format( - warning_class) + match = "'obj' should be a callable.+you should use 'category=UserWarning'" with pytest.raises(ValueError, match=match): silence_warnings_func = ignore_warnings(warning_class)(