Skip to content

Commit aeabdee

Browse files
committed
Modification of the ignore_warning function and _IgnoreWarning class.
1 parent a6bcb0b commit aeabdee

File tree

3 files changed

+147
-56
lines changed

3 files changed

+147
-56
lines changed

doc/whats_new.rst

+11
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ Enhancements
131131
- Add option to show ``indicator features`` in the output of Imputer.
132132
By `Mani Teja`_.
133133

134+
- The :func: `ignore_warnings` now accept a category argument to ignore only
135+
the warnings of a specified type. By `Thierry Guillemot`_.
136+
134137
Bug fixes
135138
.........
136139

@@ -198,6 +201,12 @@ API changes summary
198201
- Access to public attributes ``.X_`` and ``.y_`` has been deprecated in
199202
:class:`isotonic.IsotonicRegression`. By `Jonathan Arfa`_.
200203

204+
- The old :class:`GMM` is deprecated in favor of the new
205+
:class:`GaussianMixture`. The new class compute the Gaussian mixture
206+
faster than before and some of computationnal problems have been solved.
207+
By `Thierry Guillemot`_.
208+
209+
201210

202211
.. _changes_0_17_1:
203212

@@ -4148,3 +4157,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
41484157
.. _JPFrancoia: https://github.com/JPFrancoia
41494158

41504159
.. _Mani Teja: https://github.com/maniteja123
4160+
4161+
.. _Thierry Guillemot: https://github.com/tguillemot

sklearn/utils/testing.py

+42-55
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# Arnaud Joly
99
# Denis Engemann
1010
# Giorgio Patrini
11+
# Thierry Guillemot
1112
# License: BSD 3 clause
1213
import os
1314
import inspect
@@ -93,8 +94,7 @@ def assert_not_in(x, container):
9394
# for Python 2
9495
def assert_raises_regex(expected_exception, expected_regexp,
9596
callable_obj=None, *args, **kwargs):
96-
"""Helper function to check for message patterns in exceptions"""
97-
97+
"""Helper function to check for message patterns in exceptions."""
9898
not_raised = False
9999
try:
100100
callable_obj(*args, **kwargs)
@@ -165,7 +165,6 @@ def assert_warns(warning_class, func, *args, **kw):
165165
result : the return value of `func`
166166
167167
"""
168-
169168
# very important to avoid uncontrolled state propagation
170169
clean_warning_registry()
171170
with warnings.catch_warnings(record=True) as w:
@@ -282,17 +281,17 @@ def assert_no_warnings(func, *args, **kw):
282281
return result
283282

284283

285-
def ignore_warnings(obj=None, category=None):
286-
""" Context manager and decorator to ignore warnings
284+
def ignore_warnings(obj=None, category=Warning):
285+
"""Context manager and decorator to ignore warnings.
287286
288287
Note. Using this (in both variants) will clear all warnings
289288
from all python modules loaded. In case you need to test
290289
cross-module-warning-logging this is not your tool of choice.
291290
292291
Parameters
293292
----------
294-
category : warning class, defaults to None.
295-
The category to filter. If None, all categories will be muted.
293+
category : warning class, defaults to Warning.
294+
The category to filter. If Warning, all categories will be muted.
296295
297296
Examples
298297
--------
@@ -305,47 +304,44 @@ def ignore_warnings(obj=None, category=None):
305304
306305
>>> ignore_warnings(nasty_warn)()
307306
42
308-
309307
"""
310-
311-
def _ignore_warnings(fn):
312-
"""Decorator to catch and hide warnings without visual nesting"""
313-
@wraps(fn)
314-
def wrapper(*args, **kwargs):
315-
# very important to avoid uncontrolled state propagation
316-
clean_warning_registry()
317-
with warnings.catch_warnings():
318-
if category is None:
319-
warnings.simplefilter("ignore")
320-
else:
321-
warnings.simplefilter("ignore", category)
322-
return fn(*args, **kwargs)
323-
324-
return wrapper
325-
326308
if callable(obj):
327-
return _ignore_warnings(obj)
328-
elif category is None:
329-
return _IgnoreWarnings()
309+
return _IgnoreWarnings(category=category)(obj)
330310
else:
331-
return _ignore_warnings
311+
return _IgnoreWarnings(category=category)
332312

333313

334314
class _IgnoreWarnings(object):
315+
"""Improved and simplified Python warnings context manager and decorator.
335316
336-
"""Improved and simplified Python warnings context manager
337-
317+
This class allows to ignore the warnings raise by a function.
338318
Copied from Python 2.7.5 and modified as required.
339-
"""
340319
341-
def __init__(self):
342-
"""
320+
Parameters
321+
----------
322+
category : tuple of warning class, defaut to Warning
323+
The category to filter. By default, all the categories will be muted.
343324
344-
"""
325+
"""
326+
327+
def __init__(self, category):
345328
self._record = True
346329
self._module = sys.modules['warnings']
347330
self._entered = False
348331
self.log = []
332+
self.category = category
333+
334+
def __call__(self, fn):
335+
"""Decorator to catch and hide warnings without visual nesting."""
336+
@wraps(fn)
337+
def wrapper(*args, **kwargs):
338+
# very important to avoid uncontrolled state propagation
339+
clean_warning_registry()
340+
with warnings.catch_warnings():
341+
warnings.simplefilter("ignore", self.category)
342+
return fn(*args, **kwargs)
343+
344+
return wrapper
349345

350346
def __repr__(self):
351347
args = []
@@ -358,22 +354,13 @@ def __repr__(self):
358354

359355
def __enter__(self):
360356
clean_warning_registry() # be safe and not propagate state + chaos
361-
warnings.simplefilter('always')
357+
warnings.simplefilter("ignore", self.category)
362358
if self._entered:
363359
raise RuntimeError("Cannot enter %r twice" % self)
364360
self._entered = True
365361
self._filters = self._module.filters
366362
self._module.filters = self._filters[:]
367363
self._showwarning = self._module.showwarning
368-
if self._record:
369-
self.log = []
370-
371-
def showwarning(*args, **kwargs):
372-
self.log.append(warnings.WarningMessage(*args, **kwargs))
373-
self._module.showwarning = showwarning
374-
return self.log
375-
else:
376-
return None
377364

378365
def __exit__(self, *exc_info):
379366
if not self._entered:
@@ -412,7 +399,7 @@ def _assert_allclose(actual, desired, rtol=1e-7, atol=0,
412399

413400

414401
def assert_raise_message(exceptions, message, function, *args, **kwargs):
415-
"""Helper function to test error messages in exceptions
402+
"""Helper function to test error messages in exceptions.
416403
417404
Parameters
418405
----------
@@ -621,8 +608,8 @@ def is_abstract(c):
621608
all_classes = set(all_classes)
622609

623610
estimators = [c for c in all_classes
624-
if (issubclass(c[1], BaseEstimator)
625-
and c[0] != 'BaseEstimator')]
611+
if (issubclass(c[1], BaseEstimator) and
612+
c[0] != 'BaseEstimator')]
626613
# get rid of abstract base classes
627614
estimators = [c for c in estimators if not is_abstract(c[1])]
628615

@@ -652,7 +639,8 @@ def is_abstract(c):
652639
estimators = filtered_estimators
653640
if type_filter:
654641
raise ValueError("Parameter type_filter must be 'classifier', "
655-
"'regressor', 'transformer', 'cluster' or None, got"
642+
"'regressor', 'transformer', 'cluster' or "
643+
"None, got"
656644
" %s." % repr(type_filter))
657645

658646
# drop duplicates, sort for reproducibility
@@ -667,7 +655,6 @@ def set_random_state(estimator, random_state=0):
667655
Classes for whom random_state is deprecated are ignored. Currently DBSCAN
668656
is one such class.
669657
"""
670-
671658
if isinstance(estimator, DBSCAN):
672659
return
673660

@@ -676,8 +663,7 @@ def set_random_state(estimator, random_state=0):
676663

677664

678665
def if_matplotlib(func):
679-
"""Test decorator that skips test if matplotlib not installed. """
680-
666+
"""Test decorator that skips test if matplotlib not installed."""
681667
@wraps(func)
682668
def run_test(*args, **kwargs):
683669
try:
@@ -728,7 +714,7 @@ def func(*args, **kwargs):
728714

729715

730716
def if_safe_multiprocessing_with_blas(func):
731-
"""Decorator for tests involving both BLAS calls and multiprocessing
717+
"""Decorator for tests involving both BLAS calls and multiprocessing.
732718
733719
Under POSIX (e.g. Linux or OSX), using multiprocessing in conjunction with
734720
some implementation of BLAS (or other libraries that manage an internal
@@ -745,7 +731,6 @@ def if_safe_multiprocessing_with_blas(func):
745731
for multiprocessing to avoid this issue. However it can cause pickling
746732
errors on interactively defined functions. It therefore not enabled by
747733
default.
748-
749734
"""
750735
@wraps(func)
751736
def run_test(*args, **kwargs):
@@ -757,7 +742,7 @@ def run_test(*args, **kwargs):
757742

758743

759744
def clean_warning_registry():
760-
"""Safe way to reset warnings """
745+
"""Safe way to reset warnings."""
761746
warnings.resetwarnings()
762747
reg = "__warningregistry__"
763748
for mod_name, mod in list(sys.modules.items()):
@@ -780,7 +765,9 @@ def check_skip_travis():
780765

781766
def _delete_folder(folder_path, warn=False):
782767
"""Utility function to cleanup a temporary folder if still existing.
783-
Copy from joblib.pool (for independence)"""
768+
769+
Copy from joblib.pool (for independence).
770+
"""
784771
try:
785772
if os.path.exists(folder_path):
786773
# This can fail under windows,

sklearn/utils/tests/test_testing.py

+94-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
assert_no_warnings,
1414
assert_equal,
1515
set_random_state,
16-
assert_raise_message)
16+
assert_raise_message,
17+
ignore_warnings)
1718

1819
from sklearn.tree import DecisionTreeClassifier
1920
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
@@ -96,6 +97,98 @@ def _no_raise():
9697
"test", _no_raise)
9798

9899

100+
def test_ignore_warning():
101+
# This check that ignore_warning decorateur and context manager are working
102+
# as expected
103+
def _warning_function():
104+
warnings.warn("deprecation warning", DeprecationWarning)
105+
106+
def _multiple_warning_function():
107+
warnings.warn("deprecation warning", DeprecationWarning)
108+
warnings.warn("deprecation warning")
109+
110+
# Check the function directly
111+
assert_no_warnings(ignore_warnings(_warning_function))
112+
assert_no_warnings(ignore_warnings(_warning_function,
113+
category=DeprecationWarning))
114+
assert_warns(DeprecationWarning, ignore_warnings(_warning_function,
115+
category=UserWarning))
116+
assert_warns(UserWarning,
117+
ignore_warnings(_multiple_warning_function,
118+
category=DeprecationWarning))
119+
assert_warns(DeprecationWarning,
120+
ignore_warnings(_multiple_warning_function,
121+
category=UserWarning))
122+
assert_no_warnings(ignore_warnings(_warning_function,
123+
category=(DeprecationWarning,
124+
UserWarning)))
125+
126+
# Check the decorator
127+
@ignore_warnings
128+
def decorator_no_warning():
129+
_warning_function()
130+
_multiple_warning_function()
131+
132+
@ignore_warnings(category=(DeprecationWarning, UserWarning))
133+
def decorator_no_warning_multiple():
134+
_multiple_warning_function()
135+
136+
@ignore_warnings(category=DeprecationWarning)
137+
def decorator_no_deprecation_warning():
138+
_warning_function()
139+
140+
@ignore_warnings(category=UserWarning)
141+
def decorator_no_user_warning():
142+
_warning_function()
143+
144+
@ignore_warnings(category=DeprecationWarning)
145+
def decorator_no_deprecation_multiple_warning():
146+
_multiple_warning_function()
147+
148+
@ignore_warnings(category=UserWarning)
149+
def decorator_no_user_multiple_warning():
150+
_multiple_warning_function()
151+
152+
assert_no_warnings(decorator_no_warning)
153+
assert_no_warnings(decorator_no_warning_multiple)
154+
assert_no_warnings(decorator_no_deprecation_warning)
155+
assert_warns(DeprecationWarning, decorator_no_user_warning)
156+
assert_warns(UserWarning, decorator_no_deprecation_multiple_warning)
157+
assert_warns(DeprecationWarning, decorator_no_user_multiple_warning)
158+
159+
# Check the context manager
160+
def context_manager_no_warning():
161+
with ignore_warnings():
162+
_warning_function()
163+
164+
def context_manager_no_warning_multiple():
165+
with ignore_warnings(category=(DeprecationWarning, UserWarning)):
166+
_multiple_warning_function()
167+
168+
def context_manager_no_deprecation_warning():
169+
with ignore_warnings(category=DeprecationWarning):
170+
_warning_function()
171+
172+
def context_manager_no_user_warning():
173+
with ignore_warnings(category=UserWarning):
174+
_warning_function()
175+
176+
def context_manager_no_deprecation_multiple_warning():
177+
with ignore_warnings(category=DeprecationWarning):
178+
_multiple_warning_function()
179+
180+
def context_manager_no_user_multiple_warning():
181+
with ignore_warnings(category=UserWarning):
182+
_multiple_warning_function()
183+
184+
assert_no_warnings(context_manager_no_warning)
185+
assert_no_warnings(context_manager_no_warning_multiple)
186+
assert_no_warnings(context_manager_no_deprecation_warning)
187+
assert_warns(DeprecationWarning, context_manager_no_user_warning)
188+
assert_warns(UserWarning, context_manager_no_deprecation_multiple_warning)
189+
assert_warns(DeprecationWarning, context_manager_no_user_multiple_warning)
190+
191+
99192
# This class is inspired from numpy 1.7 with an alteration to check
100193
# the reset warning filters after calls to assert_warns.
101194
# This assert_warns behavior is specific to scikit-learn because

0 commit comments

Comments
 (0)