diff --git a/sklearn/ensemble/_base.py b/sklearn/ensemble/_base.py index 8410be81c6cbc..5483206de51d5 100644 --- a/sklearn/ensemble/_base.py +++ b/sklearn/ensemble/_base.py @@ -10,8 +10,9 @@ from joblib import effective_n_jobs from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor -from ..utils import Bunch, _print_elapsed_time, check_random_state +from ..utils import Bunch, check_random_state from ..utils._tags import _safe_tags +from ..utils._user_interface import _print_elapsed_time from ..utils.metadata_routing import _routing_enabled from ..utils.metaestimators import _BaseComposition diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index f3b73628a66f5..e0da38357e792 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -31,9 +31,10 @@ is_classifier, ) from .model_selection import cross_val_predict -from .utils import Bunch, _print_elapsed_time, check_random_state +from .utils import Bunch, check_random_state from .utils._param_validation import HasMethods, StrOptions from .utils._response import _get_response_values +from .utils._user_interface import _print_elapsed_time from .utils.metadata_routing import ( MetadataRouter, MethodMapping, diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 83c62be61b21a..4ee0622c699b7 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -18,7 +18,7 @@ from .base import TransformerMixin, _fit_context, clone from .exceptions import NotFittedError from .preprocessing import FunctionTransformer -from .utils import Bunch, _print_elapsed_time +from .utils import Bunch from .utils._estimator_html_repr import _VisualBlock from .utils._metadata_requests import METHODS from .utils._param_validation import HasMethods, Hidden @@ -27,6 +27,7 @@ _safe_set_output, ) from .utils._tags import _safe_tags +from .utils._user_interface import _print_elapsed_time from .utils.metadata_routing import ( MetadataRouter, MethodMapping, diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5131f7e7ed6e6..db5021570451d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -4,9 +4,7 @@ import platform import struct -import timeit from collections.abc import Sequence -from contextlib import contextmanager import numpy as np @@ -145,55 +143,3 @@ def _to_object_array(sequence): out = np.empty(len(sequence), dtype=object) out[:] = sequence return out - - -def _message_with_time(source, message, time): - """Create one line message for logging purposes. - - Parameters - ---------- - source : str - String indicating the source or the reference of the message. - - message : str - Short message. - - time : int - Time in seconds. - """ - start_message = "[%s] " % source - - # adapted from joblib.logger.short_format_time without the Windows -.1s - # adjustment - if time > 60: - time_str = "%4.1fmin" % (time / 60) - else: - time_str = " %5.1fs" % time - end_message = " %s, total=%s" % (message, time_str) - dots_len = 70 - len(start_message) - len(end_message) - return "%s%s%s" % (start_message, dots_len * ".", end_message) - - -@contextmanager -def _print_elapsed_time(source, message=None): - """Log elapsed time to stdout when the context is exited. - - Parameters - ---------- - source : str - String indicating the source or the reference of the message. - - message : str, default=None - Short message. If None, nothing will be printed. - - Returns - ------- - context_manager - Prints elapsed time upon exit if verbose. - """ - if message is None: - yield - else: - start = timeit.default_timer() - yield - print(_message_with_time(source, message, timeit.default_timer() - start)) diff --git a/sklearn/utils/_user_interface.py b/sklearn/utils/_user_interface.py new file mode 100644 index 0000000000000..09e6f2b7bf849 --- /dev/null +++ b/sklearn/utils/_user_interface.py @@ -0,0 +1,54 @@ +import timeit +from contextlib import contextmanager + + +def _message_with_time(source, message, time): + """Create one line message for logging purposes. + + Parameters + ---------- + source : str + String indicating the source or the reference of the message. + + message : str + Short message. + + time : int + Time in seconds. + """ + start_message = "[%s] " % source + + # adapted from joblib.logger.short_format_time without the Windows -.1s + # adjustment + if time > 60: + time_str = "%4.1fmin" % (time / 60) + else: + time_str = " %5.1fs" % time + end_message = " %s, total=%s" % (message, time_str) + dots_len = 70 - len(start_message) - len(end_message) + return "%s%s%s" % (start_message, dots_len * ".", end_message) + + +@contextmanager +def _print_elapsed_time(source, message=None): + """Log elapsed time to stdout when the context is exited. + + Parameters + ---------- + source : str + String indicating the source or the reference of the message. + + message : str, default=None + Short message. If None, nothing will be printed. + + Returns + ------- + context_manager + Prints elapsed time upon exit if verbose. + """ + if message is None: + yield + else: + start = timeit.default_timer() + yield + print(_message_with_time(source, message, timeit.default_timer() - start)) diff --git a/sklearn/utils/tests/test_user_interface.py b/sklearn/utils/tests/test_user_interface.py new file mode 100644 index 0000000000000..9aa9d41ba9aef --- /dev/null +++ b/sklearn/utils/tests/test_user_interface.py @@ -0,0 +1,65 @@ +import string +import timeit + +import pytest + +from sklearn.utils._user_interface import _message_with_time, _print_elapsed_time + + +@pytest.mark.parametrize( + ["source", "message", "is_long"], + [ + ("ABC", string.ascii_lowercase, False), + ("ABCDEF", string.ascii_lowercase, False), + ("ABC", string.ascii_lowercase * 3, True), + ("ABC" * 10, string.ascii_lowercase, True), + ("ABC", string.ascii_lowercase + "\u1048", False), + ], +) +@pytest.mark.parametrize( + ["time", "time_str"], + [ + (0.2, " 0.2s"), + (20, " 20.0s"), + (2000, "33.3min"), + (20000, "333.3min"), + ], +) +def test_message_with_time(source, message, is_long, time, time_str): + out = _message_with_time(source, message, time) + if is_long: + assert len(out) > 70 + else: + assert len(out) == 70 + + assert out.startswith("[" + source + "] ") + out = out[len(source) + 3 :] + + assert out.endswith(time_str) + out = out[: -len(time_str)] + assert out.endswith(", total=") + out = out[: -len(", total=")] + assert out.endswith(message) + out = out[: -len(message)] + assert out.endswith(" ") + out = out[:-1] + + if is_long: + assert not out + else: + assert list(set(out)) == ["."] + + +@pytest.mark.parametrize( + ["message", "expected"], + [ + ("hello", _message_with_time("ABC", "hello", 0.1) + "\n"), + ("", _message_with_time("ABC", "", 0.1) + "\n"), + (None, ""), + ], +) +def test_print_elapsed_time(message, expected, capsys, monkeypatch): + monkeypatch.setattr(timeit, "default_timer", lambda: 0) + with _print_elapsed_time("ABC", message): + monkeypatch.setattr(timeit, "default_timer", lambda: 0.1) + assert capsys.readouterr().out == expected diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index ccc3738e8d733..c2e2d01ee39a5 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -1,13 +1,9 @@ -import string -import timeit import warnings import numpy as np import pytest from sklearn.utils import ( - _message_with_time, - _print_elapsed_time, _to_object_array, check_random_state, column_or_1d, @@ -113,65 +109,6 @@ def test_column_or_1d(): column_or_1d(y) -@pytest.mark.parametrize( - ["source", "message", "is_long"], - [ - ("ABC", string.ascii_lowercase, False), - ("ABCDEF", string.ascii_lowercase, False), - ("ABC", string.ascii_lowercase * 3, True), - ("ABC" * 10, string.ascii_lowercase, True), - ("ABC", string.ascii_lowercase + "\u1048", False), - ], -) -@pytest.mark.parametrize( - ["time", "time_str"], - [ - (0.2, " 0.2s"), - (20, " 20.0s"), - (2000, "33.3min"), - (20000, "333.3min"), - ], -) -def test_message_with_time(source, message, is_long, time, time_str): - out = _message_with_time(source, message, time) - if is_long: - assert len(out) > 70 - else: - assert len(out) == 70 - - assert out.startswith("[" + source + "] ") - out = out[len(source) + 3 :] - - assert out.endswith(time_str) - out = out[: -len(time_str)] - assert out.endswith(", total=") - out = out[: -len(", total=")] - assert out.endswith(message) - out = out[: -len(message)] - assert out.endswith(" ") - out = out[:-1] - - if is_long: - assert not out - else: - assert list(set(out)) == ["."] - - -@pytest.mark.parametrize( - ["message", "expected"], - [ - ("hello", _message_with_time("ABC", "hello", 0.1) + "\n"), - ("", _message_with_time("ABC", "", 0.1) + "\n"), - (None, ""), - ], -) -def test_print_elapsed_time(message, expected, capsys, monkeypatch): - monkeypatch.setattr(timeit, "default_timer", lambda: 0) - with _print_elapsed_time("ABC", message): - monkeypatch.setattr(timeit, "default_timer", lambda: 0.1) - assert capsys.readouterr().out == expected - - @pytest.mark.parametrize( "value, result", [