Skip to content

Commit ab890a3

Browse files
committed
Make console logger table more compact
1 parent 99606e4 commit ab890a3

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

spacy/training/loggers.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,25 @@
1111
from ..language import Language # noqa: F401
1212

1313

14+
def setup_table(
15+
*, cols: List[str], widths: List[int], max_width: int = 13
16+
) -> Tuple[List[str], List[int], List[str]]:
17+
final_cols = []
18+
final_widths = []
19+
for col, width in zip(cols, widths):
20+
if len(col) > max_width:
21+
col = col[: max_width - 3] + "..." # shorten column if too long
22+
final_cols.append(col.upper())
23+
final_widths.append(max(len(col), width))
24+
return final_cols, final_widths, ["r" for _ in final_widths]
25+
26+
1427
@registry.loggers("spacy.ConsoleLogger.v1")
1528
def console_logger(progress_bar: bool = False):
1629
def setup_printer(
1730
nlp: "Language", stdout: IO = sys.stdout, stderr: IO = sys.stderr
1831
) -> Tuple[Callable[[Optional[Dict[str, Any]]], None], Callable[[], None]]:
32+
write = lambda text: stdout.write(f"{text}\n")
1933
msg = Printer(no_print=True)
2034
# ensure that only trainable components are logged
2135
logged_pipes = [
@@ -26,15 +40,14 @@ def setup_printer(
2640
eval_frequency = nlp.config["training"]["eval_frequency"]
2741
score_weights = nlp.config["training"]["score_weights"]
2842
score_cols = [col for col, value in score_weights.items() if value is not None]
29-
score_widths = [max(len(col), 6) for col in score_cols]
3043
loss_cols = [f"Loss {pipe}" for pipe in logged_pipes]
31-
loss_widths = [max(len(col), 8) for col in loss_cols]
32-
table_header = ["E", "#"] + loss_cols + score_cols + ["Score"]
33-
table_header = [col.upper() for col in table_header]
34-
table_widths = [3, 6] + loss_widths + score_widths + [6]
35-
table_aligns = ["r" for _ in table_widths]
36-
stdout.write(msg.row(table_header, widths=table_widths) + "\n")
37-
stdout.write(msg.row(["-" * width for width in table_widths]) + "\n")
44+
spacing = 2
45+
table_header, table_widths, table_aligns = setup_table(
46+
cols=["E", "#"] + loss_cols + score_cols + ["Score"],
47+
widths=[3, 6] + [8 for _ in loss_cols] + [6 for _ in score_cols] + [6],
48+
)
49+
write(msg.row(table_header, widths=table_widths, spacing=spacing))
50+
write(msg.row(["-" * width for width in table_widths], spacing=spacing))
3851
progress = None
3952

4053
def log_step(info: Optional[Dict[str, Any]]) -> None:
@@ -70,7 +83,9 @@ def log_step(info: Optional[Dict[str, Any]]) -> None:
7083
)
7184
if progress is not None:
7285
progress.close()
73-
stdout.write(msg.row(data, widths=table_widths, aligns=table_aligns) + "\n")
86+
write(
87+
msg.row(data, widths=table_widths, aligns=table_aligns, spacing=spacing)
88+
)
7489
if progress_bar:
7590
# Set disable=None, so that it disables on non-TTY
7691
progress = tqdm.tqdm(

0 commit comments

Comments
 (0)