Skip to content

MAINT cleanup utils.__init__: move _print_elapsed_time into dedicated submodule #28662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sklearn/ensemble/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from joblib import effective_n_jobs

from ..base import BaseEstimator, MetaEstimatorMixin, clone, is_classifier, is_regressor
from ..utils import Bunch, _print_elapsed_time, check_random_state
from ..utils import Bunch, check_random_state
from ..utils._tags import _safe_tags
from ..utils._user_interface import _print_elapsed_time
from ..utils.metadata_routing import _routing_enabled
from ..utils.metaestimators import _BaseComposition

Expand Down
3 changes: 2 additions & 1 deletion sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@
is_classifier,
)
from .model_selection import cross_val_predict
from .utils import Bunch, _print_elapsed_time, check_random_state
from .utils import Bunch, check_random_state
from .utils._param_validation import HasMethods, StrOptions
from .utils._response import _get_response_values
from .utils._user_interface import _print_elapsed_time
from .utils.metadata_routing import (
MetadataRouter,
MethodMapping,
Expand Down
3 changes: 2 additions & 1 deletion sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .base import TransformerMixin, _fit_context, clone
from .exceptions import NotFittedError
from .preprocessing import FunctionTransformer
from .utils import Bunch, _print_elapsed_time
from .utils import Bunch
from .utils._estimator_html_repr import _VisualBlock
from .utils._metadata_requests import METHODS
from .utils._param_validation import HasMethods, Hidden
Expand All @@ -27,6 +27,7 @@
_safe_set_output,
)
from .utils._tags import _safe_tags
from .utils._user_interface import _print_elapsed_time
from .utils.metadata_routing import (
MetadataRouter,
MethodMapping,
Expand Down
54 changes: 0 additions & 54 deletions sklearn/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import platform
import struct
import timeit
from collections.abc import Sequence
from contextlib import contextmanager

import numpy as np

Expand Down Expand Up @@ -145,55 +143,3 @@ def _to_object_array(sequence):
out = np.empty(len(sequence), dtype=object)
out[:] = sequence
return out


def _message_with_time(source, message, time):
"""Create one line message for logging purposes.

Parameters
----------
source : str
String indicating the source or the reference of the message.

message : str
Short message.

time : int
Time in seconds.
"""
start_message = "[%s] " % source

# adapted from joblib.logger.short_format_time without the Windows -.1s
# adjustment
if time > 60:
time_str = "%4.1fmin" % (time / 60)
else:
time_str = " %5.1fs" % time
end_message = " %s, total=%s" % (message, time_str)
dots_len = 70 - len(start_message) - len(end_message)
return "%s%s%s" % (start_message, dots_len * ".", end_message)


@contextmanager
def _print_elapsed_time(source, message=None):
"""Log elapsed time to stdout when the context is exited.

Parameters
----------
source : str
String indicating the source or the reference of the message.

message : str, default=None
Short message. If None, nothing will be printed.

Returns
-------
context_manager
Prints elapsed time upon exit if verbose.
"""
if message is None:
yield
else:
start = timeit.default_timer()
yield
print(_message_with_time(source, message, timeit.default_timer() - start))
54 changes: 54 additions & 0 deletions sklearn/utils/_user_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import timeit
from contextlib import contextmanager


def _message_with_time(source, message, time):
"""Create one line message for logging purposes.

Parameters
----------
source : str
String indicating the source or the reference of the message.

message : str
Short message.

time : int
Time in seconds.
"""
start_message = "[%s] " % source

# adapted from joblib.logger.short_format_time without the Windows -.1s
# adjustment
if time > 60:
time_str = "%4.1fmin" % (time / 60)
else:
time_str = " %5.1fs" % time
end_message = " %s, total=%s" % (message, time_str)
dots_len = 70 - len(start_message) - len(end_message)
return "%s%s%s" % (start_message, dots_len * ".", end_message)


@contextmanager
def _print_elapsed_time(source, message=None):
"""Log elapsed time to stdout when the context is exited.

Parameters
----------
source : str
String indicating the source or the reference of the message.

message : str, default=None
Short message. If None, nothing will be printed.

Returns
-------
context_manager
Prints elapsed time upon exit if verbose.
"""
if message is None:
yield
else:
start = timeit.default_timer()
yield
print(_message_with_time(source, message, timeit.default_timer() - start))
65 changes: 65 additions & 0 deletions sklearn/utils/tests/test_user_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import string
import timeit

import pytest

from sklearn.utils._user_interface import _message_with_time, _print_elapsed_time


@pytest.mark.parametrize(
["source", "message", "is_long"],
[
("ABC", string.ascii_lowercase, False),
("ABCDEF", string.ascii_lowercase, False),
("ABC", string.ascii_lowercase * 3, True),
("ABC" * 10, string.ascii_lowercase, True),
("ABC", string.ascii_lowercase + "\u1048", False),
],
)
@pytest.mark.parametrize(
["time", "time_str"],
[
(0.2, " 0.2s"),
(20, " 20.0s"),
(2000, "33.3min"),
(20000, "333.3min"),
],
)
def test_message_with_time(source, message, is_long, time, time_str):
out = _message_with_time(source, message, time)
if is_long:
assert len(out) > 70
else:
assert len(out) == 70

assert out.startswith("[" + source + "] ")
out = out[len(source) + 3 :]

assert out.endswith(time_str)
out = out[: -len(time_str)]
assert out.endswith(", total=")
out = out[: -len(", total=")]
assert out.endswith(message)
out = out[: -len(message)]
assert out.endswith(" ")
out = out[:-1]

if is_long:
assert not out
else:
assert list(set(out)) == ["."]


@pytest.mark.parametrize(
["message", "expected"],
[
("hello", _message_with_time("ABC", "hello", 0.1) + "\n"),
("", _message_with_time("ABC", "", 0.1) + "\n"),
(None, ""),
],
)
def test_print_elapsed_time(message, expected, capsys, monkeypatch):
monkeypatch.setattr(timeit, "default_timer", lambda: 0)
with _print_elapsed_time("ABC", message):
monkeypatch.setattr(timeit, "default_timer", lambda: 0.1)
assert capsys.readouterr().out == expected
63 changes: 0 additions & 63 deletions sklearn/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import string
import timeit
import warnings

import numpy as np
import pytest

from sklearn.utils import (
_message_with_time,
_print_elapsed_time,
_to_object_array,
check_random_state,
column_or_1d,
Expand Down Expand Up @@ -113,65 +109,6 @@ def test_column_or_1d():
column_or_1d(y)


@pytest.mark.parametrize(
["source", "message", "is_long"],
[
("ABC", string.ascii_lowercase, False),
("ABCDEF", string.ascii_lowercase, False),
("ABC", string.ascii_lowercase * 3, True),
("ABC" * 10, string.ascii_lowercase, True),
("ABC", string.ascii_lowercase + "\u1048", False),
],
)
@pytest.mark.parametrize(
["time", "time_str"],
[
(0.2, " 0.2s"),
(20, " 20.0s"),
(2000, "33.3min"),
(20000, "333.3min"),
],
)
def test_message_with_time(source, message, is_long, time, time_str):
out = _message_with_time(source, message, time)
if is_long:
assert len(out) > 70
else:
assert len(out) == 70

assert out.startswith("[" + source + "] ")
out = out[len(source) + 3 :]

assert out.endswith(time_str)
out = out[: -len(time_str)]
assert out.endswith(", total=")
out = out[: -len(", total=")]
assert out.endswith(message)
out = out[: -len(message)]
assert out.endswith(" ")
out = out[:-1]

if is_long:
assert not out
else:
assert list(set(out)) == ["."]


@pytest.mark.parametrize(
["message", "expected"],
[
("hello", _message_with_time("ABC", "hello", 0.1) + "\n"),
("", _message_with_time("ABC", "", 0.1) + "\n"),
(None, ""),
],
)
def test_print_elapsed_time(message, expected, capsys, monkeypatch):
monkeypatch.setattr(timeit, "default_timer", lambda: 0)
with _print_elapsed_time("ABC", message):
monkeypatch.setattr(timeit, "default_timer", lambda: 0.1)
assert capsys.readouterr().out == expected


@pytest.mark.parametrize(
"value, result",
[
Expand Down