diff --git a/HISTORY.rst b/HISTORY.rst index 8e0b1a588..0dab17d87 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -3,6 +3,21 @@ Release History --------------- +1.1.2 (2024-05-01) +++++++++++++++++++ + +- Fix #1383 Revert lxml<=4.9.2 pin that breaks Python 3.12 install +- Fix #1385 Support use of Part._rels by python-docx-template +- Add support and testing for Python 3.12 + +1.1.1 (2024-04-29) +++++++++++++++++++ + +- Fix #531, #1146 Index error on table with misaligned borders +- Fix #1335 Tolerate invalid float value in bottom-margin +- Fix #1337 Do not require typing-extensions at runtime + + 1.1.0 (2023-11-03) ++++++++++++++++++ diff --git a/Makefile b/Makefile index 0478b2bce..da0d7a4ac 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,8 @@ build: $(BUILD) clean: - find . -type f -name \*.pyc -exec rm {} \; + # find . -type f -name \*.pyc -exec rm {} \; + fd -e pyc -I -x rm rm -rf dist *.egg-info .coverage .DS_Store cleandocs: diff --git a/docs/index.rst b/docs/index.rst index cdb8b5455..1b1029787 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -74,6 +74,7 @@ User Guide user/install user/quickstart user/documents + user/tables user/text user/sections user/hdrftr diff --git a/docs/user/tables.rst b/docs/user/tables.rst new file mode 100644 index 000000000..40ef20933 --- /dev/null +++ b/docs/user/tables.rst @@ -0,0 +1,202 @@ +.. _tables: + +Working with Tables +=================== + +Word provides sophisticated capabilities to create tables. As usual, this power comes with +additional conceptual complexity. + +This complexity becomes most apparent when *reading* tables, in particular from documents drawn from +the wild where there is limited or no prior knowledge as to what the tables might contain or how +they might be structured. + +These are some of the important concepts you'll need to understand. 
+ + +Concept: Simple (uniform) tables +-------------------------------- + +:: + + +---+---+---+ + | a | b | c | + +---+---+---+ + | d | e | f | + +---+---+---+ + | g | h | i | + +---+---+---+ + +The basic concept of a table is intuitive enough. You have *rows* and *columns*, and at each (row, +column) position is a different *cell*. It can be described as a *grid* or a *matrix*. Let's call +this concept a *uniform table*. A relational database table and a Pandas dataframe are both examples +of a uniform table. + +The following invariants apply to uniform tables: + +* Each row has the same number of cells, one for each column. +* Each column has the same number of cells, one for each row. + + +Complication 1: Merged Cells +---------------------------- + +:: + + +---+---+---+ +---+---+---+ + | a | b | | | b | c | + +---+---+---+ + a +---+---+ + | c | d | e | | | d | e | + +---+---+---+ +---+---+---+ + | f | g | h | | f | g | h | + +---+---+---+ +---+---+---+ + +While very suitable for data processing, a uniform table lacks expressive power desirable for +tables intended for a human reader. + +Perhaps the most important characteristic a uniform table lacks is *merged cells*. It is very common +to want to group multiple cells into one, for example to form a column-group heading or provide the +same value for a sequence of cells rather than repeat it for each cell. These make a rendered table +more *readable* by reducing the cognitive load on the human reader and make certain relationships +explicit that might easily be missed otherwise. + +Unfortunately, accommodating merged cells breaks both the invariants of a uniform table: + +* Each row can have a different number of cells. +* Each column can have a different number of cells. + +This challenges reading table contents programmatically. 
One might naturally want to read the table +into a uniform matrix data structure like a 3 x 3 "2D array" (list of lists perhaps), but this is +not directly possible when the table is not known to be uniform. + + +Concept: The layout grid +------------------------ + +:: + + + - + - + - + + | | | | + + - + - + - + + | | | | + + - + - + - + + | | | | + + - + - + - + + +In Word, each table has a *layout grid*. + +- The layout grid is *uniform*. There is a layout position for every (layout-row, layout-column) + pair. +- The layout grid itself is not visible. However, it is represented and referenced by certain + elements and attributes within the table XML. +- Each table cell is located at a layout-grid position; i.e. the top-left corner of each cell is the + top-left corner of a layout-grid cell. +- Each table cell occupies one or more whole layout-grid cells. A merged cell will occupy multiple + layout-grid cells. No table cell can occupy a partial layout-grid cell. +- Another way of saying this is that every vertical boundary (left and right) of a cell aligns with + a layout-grid vertical boundary, likewise for horizontal boundaries. But not all layout-grid + boundaries need be occupied by a cell boundary of the table. + + +Complication 2: Omitted Cells +----------------------------- + +:: + + +---+---+ +---+---+---+ + | a | b | | a | b | c | + +---+---+---+ +---+---+---+ + | c | d | | d | + +---+---+ +---+---+---+ + | e | | e | f | g | + +---+ +---+---+---+ + +Word is unusual in that it allows cells to be omitted from the beginning or end (but not the middle) +of a row. A typical practical example is a table with both a row of column headings and a column of +row headings, but no top-left cell (position 0, 0), such as this XOR truth table. 
+ +:: + + +---+---+ + | T | F | + +---+---+---+ + | T | F | T | + +---+---+---+ + | F | T | F | + +---+---+---+ + +In `python-docx`, omitted cells in a |_Row| object are represented by the ``.grid_cols_before`` and +``.grid_cols_after`` properties. In the example above, for the first row, ``.grid_cols_before`` +would equal ``1`` and ``.grid_cols_after`` would equal ``0``. + +Note that omitted cells are not just "empty" cells. They represent layout-grid positions that are +unoccupied by a cell and they cannot be represented by a |_Cell| object. This distinction becomes +important when trying to produce a uniform representation (e.g. a 2D array) for an arbitrary Word +table. + + +Concept: `python-docx` approximates uniform tables by default +------------------------------------------------------------- + +To accurately represent an arbitrary table would require a complex graph data structure. Navigating +this data structure would be at least as complex as navigating the `python-docx` object graph for a +table. When extracting content from a collection of arbitrary Word files, such as for indexing the +document, it is common to choose a simpler data structure and *approximate* the table in that +structure. + +Reflecting on how a relational table or dataframe represents tabular information, a straightforward +approximation would simply repeat merged-cell values for each layout-grid cell occupied by the +merged cell:: + + + +---+---+---+ +---+---+---+ + | a | b | -> | a | a | b | + +---+---+---+ +---+---+---+ + | | d | e | -> | c | d | e | + + c +---+---+ +---+---+---+ + | | f | g | -> | c | f | g | + +---+---+---+ +---+---+---+ + +This is what ``_Row.cells`` does by default. Conceptually:: + + >>> [tuple(c.text for c in r.cells) for r in table.rows] + [ + (a, a, b), + (c, d, e), + (c, f, g), + ] + +Note this only produces a uniform "matrix" of cells when there are no omitted cells. 
Dealing with +omitted cells requires a more sophisticated approach when maintaining column integrity is required:: + + # +---+---+ + # | a | b | + # +---+---+---+ + # | c | d | + # +---+---+ + # | e | + # +---+ + + def iter_row_cell_texts(row: _Row) -> Iterator[str]: + for _ in range(row.grid_cols_before): + yield "" + for c in row.cells: + yield c.text + for _ in range(row.grid_cols_after): + yield "" + + >>> [tuple(iter_row_cell_texts(r)) for r in table.rows] + [ + ("", "a", "b"), + ("c", "d", ""), + ("", "e", ""), + ] + + +Complication 3: Tables are Recursive +------------------------------------ + +Further complicating table processing is their recursive nature. In Word, as in HTML, a table cell +can itself include one or more tables. + +These can be detected using ``_Cell.tables`` or ``_Cell.iter_inner_content()``. The latter preserves +the document order of the table with respect to paragraphs also in the cell. diff --git a/features/steps/cell.py b/features/steps/cell.py deleted file mode 100644 index 10896872b..000000000 --- a/features/steps/cell.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Step implementations for table cell-related features.""" - -from behave import given, then, when - -from docx import Document - -from helpers import test_docx - -# given =================================================== - - -@given("a table cell") -def given_a_table_cell(context): - table = Document(test_docx("tbl-2x2-table")).tables[0] - context.cell = table.cell(0, 0) - - -# when ===================================================== - - -@when("I add a 2 x 2 table into the first cell") -def when_I_add_a_2x2_table_into_the_first_cell(context): - context.table_ = context.cell.add_table(2, 2) - - -@when("I assign a string to the cell text attribute") -def when_assign_string_to_cell_text_attribute(context): - cell = context.cell - text = "foobar" - cell.text = text - context.expected_text = text - - -# then ===================================================== - - 
-@then("cell.tables[0] is a 2 x 2 table") -def then_cell_tables_0_is_a_2x2_table(context): - cell = context.cell - table = cell.tables[0] - assert len(table.rows) == 2 - assert len(table.columns) == 2 - - -@then("the cell contains the string I assigned") -def then_cell_contains_string_assigned(context): - cell, expected_text = context.cell, context.expected_text - text = cell.paragraphs[0].runs[0].text - msg = "expected '%s', got '%s'" % (expected_text, text) - assert text == expected_text, msg diff --git a/features/steps/coreprops.py b/features/steps/coreprops.py index 0f6b6a854..90467fb67 100644 --- a/features/steps/coreprops.py +++ b/features/steps/coreprops.py @@ -1,8 +1,9 @@ """Gherkin step implementations for core properties-related features.""" -from datetime import datetime, timedelta +import datetime as dt from behave import given, then, when +from behave.runner import Context from docx import Document from docx.opc.coreprops import CoreProperties @@ -13,12 +14,12 @@ @given("a document having known core properties") -def given_a_document_having_known_core_properties(context): +def given_a_document_having_known_core_properties(context: Context): context.document = Document(test_docx("doc-coreprops")) @given("a document having no core properties part") -def given_a_document_having_no_core_properties_part(context): +def given_a_document_having_no_core_properties_part(context: Context): context.document = Document(test_docx("doc-no-coreprops")) @@ -26,24 +27,24 @@ def given_a_document_having_no_core_properties_part(context): @when("I access the core properties object") -def when_I_access_the_core_properties_object(context): +def when_I_access_the_core_properties_object(context: Context): context.document.core_properties @when("I assign new values to the properties") -def when_I_assign_new_values_to_the_properties(context): +def when_I_assign_new_values_to_the_properties(context: Context): context.propvals = ( ("author", "Creator"), ("category", "Category"), 
("comments", "Description"), ("content_status", "Content Status"), - ("created", datetime(2013, 6, 15, 12, 34, 56)), + ("created", dt.datetime(2013, 6, 15, 12, 34, 56, tzinfo=dt.timezone.utc)), ("identifier", "Identifier"), ("keywords", "key; word; keyword"), ("language", "Language"), ("last_modified_by", "Last Modified By"), - ("last_printed", datetime(2013, 6, 15, 12, 34, 56)), - ("modified", datetime(2013, 6, 15, 12, 34, 56)), + ("last_printed", dt.datetime(2013, 6, 15, 12, 34, 56, tzinfo=dt.timezone.utc)), + ("modified", dt.datetime(2013, 6, 15, 12, 34, 56, tzinfo=dt.timezone.utc)), ("revision", 9), ("subject", "Subject"), ("title", "Title"), @@ -58,39 +59,39 @@ def when_I_assign_new_values_to_the_properties(context): @then("a core properties part with default values is added") -def then_a_core_properties_part_with_default_values_is_added(context): +def then_a_core_properties_part_with_default_values_is_added(context: Context): core_properties = context.document.core_properties assert core_properties.title == "Word Document" assert core_properties.last_modified_by == "python-docx" assert core_properties.revision == 1 # core_properties.modified only stores time with seconds resolution, so # comparison needs to be a little loose (within two seconds) - modified_timedelta = datetime.utcnow() - core_properties.modified - max_expected_timedelta = timedelta(seconds=2) + modified_timedelta = dt.datetime.now(dt.timezone.utc) - core_properties.modified + max_expected_timedelta = dt.timedelta(seconds=2) assert modified_timedelta < max_expected_timedelta @then("I can access the core properties object") -def then_I_can_access_the_core_properties_object(context): +def then_I_can_access_the_core_properties_object(context: Context): document = context.document core_properties = document.core_properties assert isinstance(core_properties, CoreProperties) @then("the core property values match the known values") -def then_the_core_property_values_match_the_known_values(context): 
+def then_the_core_property_values_match_the_known_values(context: Context): known_propvals = ( ("author", "Steve Canny"), ("category", "Category"), ("comments", "Description"), ("content_status", "Content Status"), - ("created", datetime(2014, 12, 13, 22, 2, 0)), + ("created", dt.datetime(2014, 12, 13, 22, 2, 0, tzinfo=dt.timezone.utc)), ("identifier", "Identifier"), ("keywords", "key; word; keyword"), ("language", "Language"), ("last_modified_by", "Steve Canny"), - ("last_printed", datetime(2014, 12, 13, 22, 2, 42)), - ("modified", datetime(2014, 12, 13, 22, 6, 0)), + ("last_printed", dt.datetime(2014, 12, 13, 22, 2, 42, tzinfo=dt.timezone.utc)), + ("modified", dt.datetime(2014, 12, 13, 22, 6, 0, tzinfo=dt.timezone.utc)), ("revision", 2), ("subject", "Subject"), ("title", "Title"), @@ -106,7 +107,7 @@ def then_the_core_property_values_match_the_known_values(context): @then("the core property values match the new values") -def then_the_core_property_values_match_the_new_values(context): +def then_the_core_property_values_match_the_new_values(context: Context): core_properties = context.document.core_properties for name, expected_value in context.propvals: value = getattr(core_properties, name) diff --git a/features/steps/table.py b/features/steps/table.py index 95f2fab75..38d49ee0a 100644 --- a/features/steps/table.py +++ b/features/steps/table.py @@ -1,6 +1,9 @@ +# pyright: reportPrivateUsage=false + """Step implementations for table-related features.""" from behave import given, then, when +from behave.runner import Context from docx import Document from docx.enum.table import ( @@ -10,7 +13,7 @@ WD_TABLE_DIRECTION, ) from docx.shared import Inches -from docx.table import _Column, _Columns, _Row, _Rows +from docx.table import Table, _Cell, _Column, _Columns, _Row, _Rows from helpers import test_docx @@ -18,12 +21,12 @@ @given("a 2 x 2 table") -def given_a_2x2_table(context): +def given_a_2x2_table(context: Context): context.table_ = Document().add_table(rows=2, 
cols=2) @given("a 3x3 table having {span_state}") -def given_a_3x3_table_having_span_state(context, span_state): +def given_a_3x3_table_having_span_state(context: Context, span_state: str): table_idx = { "only uniform cells": 0, "a horizontal span": 1, @@ -34,8 +37,15 @@ def given_a_3x3_table_having_span_state(context, span_state): context.table_ = document.tables[table_idx] +@given("a _Cell object spanning {count} layout-grid cells") +def given_a_Cell_object_spanning_count_layout_grid_cells(context: Context, count: str): + document = Document(test_docx("tbl-cell-props")) + table = document.tables[0] + context.cell = _Cell(table._tbl.tr_lst[int(count)].tc_lst[0], table) + + @given("a _Cell object with {state} vertical alignment as cell") -def given_a_Cell_object_with_vertical_alignment_as_cell(context, state): +def given_a_Cell_object_with_vertical_alignment_as_cell(context: Context, state: str): table_idx = { "inherited": 0, "bottom": 1, @@ -48,26 +58,32 @@ def given_a_Cell_object_with_vertical_alignment_as_cell(context, state): @given("a column collection having two columns") -def given_a_column_collection_having_two_columns(context): +def given_a_column_collection_having_two_columns(context: Context): docx_path = test_docx("blk-containing-table") document = Document(docx_path) context.columns = document.tables[0].columns @given("a row collection having two rows") -def given_a_row_collection_having_two_rows(context): +def given_a_row_collection_having_two_rows(context: Context): docx_path = test_docx("blk-containing-table") document = Document(docx_path) context.rows = document.tables[0].rows @given("a table") -def given_a_table(context): +def given_a_table(context: Context): context.table_ = Document().add_table(rows=2, cols=2) +@given("a table cell") +def given_a_table_cell(context: Context): + table = Document(test_docx("tbl-2x2-table")).tables[0] + context.cell = table.cell(0, 0) + + @given("a table cell having a width of {width}") -def 
given_a_table_cell_having_a_width_of_width(context, width): +def given_a_table_cell_having_a_width_of_width(context: Context, width: str): table_idx = {"no explicit setting": 0, "1 inch": 1, "2 inches": 2}[width] document = Document(test_docx("tbl-props")) table = document.tables[table_idx] @@ -76,7 +92,7 @@ def given_a_table_cell_having_a_width_of_width(context, width): @given("a table column having a width of {width_desc}") -def given_a_table_having_a_width_of_width_desc(context, width_desc): +def given_a_table_having_a_width_of_width_desc(context: Context, width_desc: str): col_idx = { "no explicit setting": 0, "1440": 1, @@ -87,7 +103,7 @@ def given_a_table_having_a_width_of_width_desc(context, width_desc): @given("a table having {alignment} alignment") -def given_a_table_having_alignment_alignment(context, alignment): +def given_a_table_having_alignment_alignment(context: Context, alignment: str): table_idx = { "inherited": 3, "left": 4, @@ -100,7 +116,7 @@ def given_a_table_having_alignment_alignment(context, alignment): @given("a table having an autofit layout of {autofit}") -def given_a_table_having_an_autofit_layout_of_autofit(context, autofit): +def given_a_table_having_an_autofit_layout_of_autofit(context: Context, autofit: str): tbl_idx = { "no explicit setting": 0, "autofit": 1, @@ -111,7 +127,7 @@ def given_a_table_having_an_autofit_layout_of_autofit(context, autofit): @given("a table having {style} style") -def given_a_table_having_style(context, style): +def given_a_table_having_style(context: Context, style: str): table_idx = { "no explicit": 0, "Table Grid": 1, @@ -123,14 +139,14 @@ def given_a_table_having_style(context, style): @given("a table having table direction set {setting}") -def given_a_table_having_table_direction_setting(context, setting): +def given_a_table_having_table_direction_setting(context: Context, setting: str): table_idx = ["to inherit", "right-to-left", "left-to-right"].index(setting) document = 
Document(test_docx("tbl-on-off-props")) context.table_ = document.tables[table_idx] @given("a table having two columns") -def given_a_table_having_two_columns(context): +def given_a_table_having_two_columns(context: Context): docx_path = test_docx("blk-containing-table") document = Document(docx_path) # context.table is used internally by behave, underscore added @@ -139,14 +155,21 @@ def given_a_table_having_two_columns(context): @given("a table having two rows") -def given_a_table_having_two_rows(context): +def given_a_table_having_two_rows(context: Context): docx_path = test_docx("blk-containing-table") document = Document(docx_path) context.table_ = document.tables[0] +@given("a table row ending with {count} empty grid columns") +def given_a_table_row_ending_with_count_empty_grid_columns(context: Context, count: str): + document = Document(test_docx("tbl-props")) + table = document.tables[8] + context.row = table.rows[int(count)] + + @given("a table row having height of {state}") -def given_a_table_row_having_height_of_state(context, state): +def given_a_table_row_having_height_of_state(context: Context, state: str): table_idx = {"no explicit setting": 0, "2 inches": 2, "3 inches": 3}[state] document = Document(test_docx("tbl-props")) table = document.tables[table_idx] @@ -154,48 +177,66 @@ def given_a_table_row_having_height_of_state(context, state): @given("a table row having height rule {state}") -def given_a_table_row_having_height_rule_state(context, state): - table_idx = {"no explicit setting": 0, "automatic": 1, "at least": 2, "exactly": 3}[ - state - ] +def given_a_table_row_having_height_rule_state(context: Context, state: str): + table_idx = {"no explicit setting": 0, "automatic": 1, "at least": 2, "exactly": 3}[state] document = Document(test_docx("tbl-props")) table = document.tables[table_idx] context.row = table.rows[0] +@given("a table row starting with {count} empty grid columns") +def 
given_a_table_row_starting_with_count_empty_grid_columns(context: Context, count: str): + document = Document(test_docx("tbl-props")) + table = document.tables[7] + context.row = table.rows[int(count)] + + # when ===================================================== @when("I add a 1.0 inch column to the table") -def when_I_add_a_1_inch_column_to_table(context): +def when_I_add_a_1_inch_column_to_table(context: Context): context.column = context.table_.add_column(Inches(1.0)) +@when("I add a 2 x 2 table into the first cell") +def when_I_add_a_2x2_table_into_the_first_cell(context: Context): + context.table_ = context.cell.add_table(2, 2) + + @when("I add a row to the table") -def when_add_row_to_table(context): +def when_add_row_to_table(context: Context): table = context.table_ context.row = table.add_row() +@when("I assign a string to the cell text attribute") +def when_assign_string_to_cell_text_attribute(context: Context): + cell = context.cell + text = "foobar" + cell.text = text + context.expected_text = text + + @when("I assign {value} to cell.vertical_alignment") -def when_I_assign_value_to_cell_vertical_alignment(context, value): +def when_I_assign_value_to_cell_vertical_alignment(context: Context, value: str): context.cell.vertical_alignment = eval(value) @when("I assign {value} to row.height") -def when_I_assign_value_to_row_height(context, value): +def when_I_assign_value_to_row_height(context: Context, value: str): new_value = None if value == "None" else int(value) context.row.height = new_value @when("I assign {value} to row.height_rule") -def when_I_assign_value_to_row_height_rule(context, value): +def when_I_assign_value_to_row_height_rule(context: Context, value: str): new_value = None if value == "None" else getattr(WD_ROW_HEIGHT_RULE, value) context.row.height_rule = new_value @when("I assign {value_str} to table.alignment") -def when_I_assign_value_to_table_alignment(context, value_str): +def when_I_assign_value_to_table_alignment(context: 
Context, value_str: str): value = { "None": None, "WD_TABLE_ALIGNMENT.LEFT": WD_TABLE_ALIGNMENT.LEFT, @@ -207,7 +248,7 @@ def when_I_assign_value_to_table_alignment(context, value_str): @when("I assign {value} to table.style") -def when_apply_value_to_table_style(context, value): +def when_apply_value_to_table_style(context: Context, value: str): table, styles = context.table_, context.document.styles if value == "None": new_value = None @@ -219,14 +260,14 @@ def when_apply_value_to_table_style(context, value): @when("I assign {value} to table.table_direction") -def when_assign_value_to_table_table_direction(context, value): +def when_assign_value_to_table_table_direction(context: Context, value: str): new_value = None if value == "None" else getattr(WD_TABLE_DIRECTION, value) context.table_.table_direction = new_value @when("I merge from cell {origin} to cell {other}") -def when_I_merge_from_cell_origin_to_cell_other(context, origin, other): - def cell(table, idx): +def when_I_merge_from_cell_origin_to_cell_other(context: Context, origin: str, other: str): + def cell(table: Table, idx: int): row, col = idx // 3, idx % 3 return table.cell(row, col) @@ -237,19 +278,19 @@ def cell(table, idx): @when("I set the cell width to {width}") -def when_I_set_the_cell_width_to_width(context, width): +def when_I_set_the_cell_width_to_width(context: Context, width: str): new_value = {"1 inch": Inches(1)}[width] context.cell.width = new_value @when("I set the column width to {width_emu}") -def when_I_set_the_column_width_to_width_emu(context, width_emu): +def when_I_set_the_column_width_to_width_emu(context: Context, width_emu: str): new_value = None if width_emu == "None" else int(width_emu) context.column.width = new_value @when("I set the table autofit to {setting}") -def when_I_set_the_table_autofit_to_setting(context, setting): +def when_I_set_the_table_autofit_to_setting(context: Context, setting: str): new_value = {"autofit": True, "fixed": False}[setting] table = 
context.table_ table.autofit = new_value @@ -258,21 +299,34 @@ def when_I_set_the_table_autofit_to_setting(context, setting): # then ===================================================== +@then("cell.grid_span is {count}") +def then_cell_grid_span_is_count(context: Context, count: str): + expected = int(count) + actual = context.cell.grid_span + assert actual == expected, f"expected {expected}, got {actual}" + + +@then("cell.tables[0] is a 2 x 2 table") +def then_cell_tables_0_is_a_2x2_table(context: Context): + cell = context.cell + table = cell.tables[0] + assert len(table.rows) == 2 + assert len(table.columns) == 2 + + @then("cell.vertical_alignment is {value}") -def then_cell_vertical_alignment_is_value(context, value): +def then_cell_vertical_alignment_is_value(context: Context, value: str): expected_value = { "None": None, "WD_ALIGN_VERTICAL.BOTTOM": WD_ALIGN_VERTICAL.BOTTOM, "WD_ALIGN_VERTICAL.CENTER": WD_ALIGN_VERTICAL.CENTER, }[value] actual_value = context.cell.vertical_alignment - assert actual_value is expected_value, ( - "cell.vertical_alignment is %s" % actual_value - ) + assert actual_value is expected_value, "cell.vertical_alignment is %s" % actual_value @then("I can access a collection column by index") -def then_can_access_collection_column_by_index(context): +def then_can_access_collection_column_by_index(context: Context): columns = context.columns for idx in range(2): column = columns[idx] @@ -280,7 +334,7 @@ def then_can_access_collection_column_by_index(context): @then("I can access a collection row by index") -def then_can_access_collection_row_by_index(context): +def then_can_access_collection_row_by_index(context: Context): rows = context.rows for idx in range(2): row = rows[idx] @@ -288,21 +342,21 @@ def then_can_access_collection_row_by_index(context): @then("I can access the column collection of the table") -def then_can_access_column_collection_of_table(context): +def then_can_access_column_collection_of_table(context: Context): table 
= context.table_ columns = table.columns assert isinstance(columns, _Columns) @then("I can access the row collection of the table") -def then_can_access_row_collection_of_table(context): +def then_can_access_row_collection_of_table(context: Context): table = context.table_ rows = table.rows assert isinstance(rows, _Rows) @then("I can iterate over the column collection") -def then_can_iterate_over_column_collection(context): +def then_can_iterate_over_column_collection(context: Context): columns = context.columns actual_count = 0 for column in columns: @@ -312,7 +366,7 @@ def then_can_iterate_over_column_collection(context): @then("I can iterate over the row collection") -def then_can_iterate_over_row_collection(context): +def then_can_iterate_over_row_collection(context: Context): rows = context.rows actual_count = 0 for row in rows: @@ -321,8 +375,22 @@ def then_can_iterate_over_row_collection(context): assert actual_count == 2 +@then("row.grid_cols_after is {value}") +def then_row_grid_cols_after_is_value(context: Context, value: str): + expected = int(value) + actual = context.row.grid_cols_after + assert actual == expected, "expected %s, got %s" % (expected, actual) + + +@then("row.grid_cols_before is {value}") +def then_row_grid_cols_before_is_value(context: Context, value: str): + expected = int(value) + actual = context.row.grid_cols_before + assert actual == expected, "expected %s, got %s" % (expected, actual) + + @then("row.height is {value}") -def then_row_height_is_value(context, value): +def then_row_height_is_value(context: Context, value: str): expected_height = None if value == "None" else int(value) actual_height = context.row.height assert actual_height == expected_height, "expected %s, got %s" % ( @@ -332,7 +400,7 @@ def then_row_height_is_value(context, value): @then("row.height_rule is {value}") -def then_row_height_rule_is_value(context, value): +def then_row_height_rule_is_value(context: Context, value: str): expected_rule = None if value == 
"None" else getattr(WD_ROW_HEIGHT_RULE, value) actual_rule = context.row.height_rule assert actual_rule == expected_rule, "expected %s, got %s" % ( @@ -342,7 +410,7 @@ def then_row_height_rule_is_value(context, value): @then("table.alignment is {value_str}") -def then_table_alignment_is_value(context, value_str): +def then_table_alignment_is_value(context: Context, value_str: str): value = { "None": None, "WD_TABLE_ALIGNMENT.LEFT": WD_TABLE_ALIGNMENT.LEFT, @@ -354,7 +422,7 @@ def then_table_alignment_is_value(context, value_str): @then("table.cell({row}, {col}).text is {expected_text}") -def then_table_cell_row_col_text_is_text(context, row, col, expected_text): +def then_table_cell_row_col_text_is_text(context: Context, row: str, col: str, expected_text: str): table = context.table_ row_idx, col_idx = int(row), int(col) cell_text = table.cell(row_idx, col_idx).text @@ -362,68 +430,76 @@ def then_table_cell_row_col_text_is_text(context, row, col, expected_text): @then("table.style is styles['{style_name}']") -def then_table_style_is_styles_style_name(context, style_name): +def then_table_style_is_styles_style_name(context: Context, style_name: str): table, styles = context.table_, context.document.styles expected_style = styles[style_name] assert table.style == expected_style, "got '%s'" % table.style @then("table.table_direction is {value}") -def then_table_table_direction_is_value(context, value): +def then_table_table_direction_is_value(context: Context, value: str): expected_value = None if value == "None" else getattr(WD_TABLE_DIRECTION, value) actual_value = context.table_.table_direction assert actual_value == expected_value, "got '%s'" % actual_value +@then("the cell contains the string I assigned") +def then_cell_contains_string_assigned(context: Context): + cell, expected_text = context.cell, context.expected_text + text = cell.paragraphs[0].runs[0].text + msg = "expected '%s', got '%s'" % (expected_text, text) + assert text == expected_text, msg + + 
@then("the column cells text is {expected_text}") -def then_the_column_cells_text_is_expected_text(context, expected_text): +def then_the_column_cells_text_is_expected_text(context: Context, expected_text: str): table = context.table_ cells_text = " ".join(c.text for col in table.columns for c in col.cells) assert cells_text == expected_text, "got %s" % cells_text @then("the length of the column collection is 2") -def then_len_of_column_collection_is_2(context): +def then_len_of_column_collection_is_2(context: Context): columns = context.table_.columns assert len(columns) == 2 @then("the length of the row collection is 2") -def then_len_of_row_collection_is_2(context): +def then_len_of_row_collection_is_2(context: Context): rows = context.table_.rows assert len(rows) == 2 @then("the new column has 2 cells") -def then_new_column_has_2_cells(context): +def then_new_column_has_2_cells(context: Context): assert len(context.column.cells) == 2 @then("the new column is 1.0 inches wide") -def then_new_column_is_1_inches_wide(context): +def then_new_column_is_1_inches_wide(context: Context): assert context.column.width == Inches(1) @then("the new row has 2 cells") -def then_new_row_has_2_cells(context): +def then_new_row_has_2_cells(context: Context): assert len(context.row.cells) == 2 @then("the reported autofit setting is {autofit}") -def then_the_reported_autofit_setting_is_autofit(context, autofit): +def then_the_reported_autofit_setting_is_autofit(context: Context, autofit: str): expected_value = {"autofit": True, "fixed": False}[autofit] table = context.table_ assert table.autofit is expected_value @then("the reported column width is {width_emu}") -def then_the_reported_column_width_is_width_emu(context, width_emu): +def then_the_reported_column_width_is_width_emu(context: Context, width_emu: str): expected_value = None if width_emu == "None" else int(width_emu) assert context.column.width == expected_value, "got %s" % context.column.width @then("the reported width of 
the cell is {width}") -def then_the_reported_width_of_the_cell_is_width(context, width): +def then_the_reported_width_of_the_cell_is_width(context: Context, width: str): expected_width = {"None": None, "1 inch": Inches(1)}[width] actual_width = context.cell.width assert actual_width == expected_width, "expected %s, got %s" % ( @@ -433,7 +509,7 @@ def then_the_reported_width_of_the_cell_is_width(context, width): @then("the row cells text is {encoded_text}") -def then_the_row_cells_text_is_expected_text(context, encoded_text): +def then_the_row_cells_text_is_expected_text(context: Context, encoded_text: str): expected_text = encoded_text.replace("\\", "\n") table = context.table_ cells_text = " ".join(c.text for row in table.rows for c in row.cells) @@ -441,32 +517,33 @@ def then_the_row_cells_text_is_expected_text(context, encoded_text): @then("the table has {count} columns") -def then_table_has_count_columns(context, count): +def then_table_has_count_columns(context: Context, count: str): column_count = int(count) columns = context.table_.columns assert len(columns) == column_count @then("the table has {count} rows") -def then_table_has_count_rows(context, count): +def then_table_has_count_rows(context: Context, count: str): row_count = int(count) rows = context.table_.rows assert len(rows) == row_count @then("the width of cell {n_str} is {inches_str} inches") -def then_the_width_of_cell_n_is_x_inches(context, n_str, inches_str): - def _cell(table, idx): +def then_the_width_of_cell_n_is_x_inches(context: Context, n_str: str, inches_str: str): + def _cell(table: Table, idx: int): row, col = idx // 3, idx % 3 return table.cell(row, col) idx, inches = int(n_str) - 1, float(inches_str) cell = _cell(context.table_, idx) + assert cell.width is not None assert cell.width == Inches(inches), "got %s" % cell.width.inches @then("the width of each cell is {inches} inches") -def then_the_width_of_each_cell_is_inches(context, inches): +def 
then_the_width_of_each_cell_is_inches(context: Context, inches: str): table = context.table_ expected_width = Inches(float(inches)) for cell in table._cells: @@ -474,7 +551,7 @@ def then_the_width_of_each_cell_is_inches(context, inches): @then("the width of each column is {inches} inches") -def then_the_width_of_each_column_is_inches(context, inches): +def then_the_width_of_each_column_is_inches(context: Context, inches: str): table = context.table_ expected_width = Inches(float(inches)) for column in table.columns: diff --git a/features/steps/test_files/tbl-cell-props.docx b/features/steps/test_files/tbl-cell-props.docx new file mode 100644 index 000000000..627fb66fc Binary files /dev/null and b/features/steps/test_files/tbl-cell-props.docx differ diff --git a/features/steps/test_files/tbl-props.docx b/features/steps/test_files/tbl-props.docx index 9d2db676e..e5fdd728f 100644 Binary files a/features/steps/test_files/tbl-props.docx and b/features/steps/test_files/tbl-props.docx differ diff --git a/features/cel-add-table.feature b/features/tbl-cell-add-table.feature similarity index 100% rename from features/cel-add-table.feature rename to features/tbl-cell-add-table.feature diff --git a/features/tbl-cell-props.feature b/features/tbl-cell-props.feature index 609d2f442..456ed39a4 100644 --- a/features/tbl-cell-props.feature +++ b/features/tbl-cell-props.feature @@ -4,6 +4,17 @@ Feature: Get and set table cell properties I need a way to get and set the properties of a table cell + Scenario Outline: Get _Cell.grid_span + Given a _Cell object spanning layout-grid cells + Then cell.grid_span is + + Examples: Cell.grid_span value cases + | count | + | 1 | + | 2 | + | 4 | + + Scenario Outline: Get _Cell.vertical_alignment Given a _Cell object with vertical alignment as cell Then cell.vertical_alignment is diff --git a/features/cel-text.feature b/features/tbl-cell-text.feature similarity index 100% rename from features/cel-text.feature rename to 
features/tbl-cell-text.feature diff --git a/features/tbl-row-props.feature b/features/tbl-row-props.feature index 377f2853e..1b006f204 100644 --- a/features/tbl-row-props.feature +++ b/features/tbl-row-props.feature @@ -4,6 +4,28 @@ Feature: Get and set table row properties I need a way to get and set the properties of a table row + Scenario Outline: Get Row.grid_cols_after + Given a table row ending with empty grid columns + Then row.grid_cols_after is + + Examples: Row.grid_cols_after value cases + | count | + | 0 | + | 1 | + | 2 | + + + Scenario Outline: Get Row.grid_cols_before + Given a table row starting with empty grid columns + Then row.grid_cols_before is + + Examples: Row.grid_cols_before value cases + | count | + | 0 | + | 1 | + | 3 | + + Scenario Outline: Get Row.height_rule Given a table row having height rule Then row.height_rule is diff --git a/pyproject.toml b/pyproject.toml index d35c790c7..91bac83d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "lxml>=3.1.0", - "typing_extensions", + "typing_extensions>=4.9.0", ] description = "Create, read, and update Microsoft Word .docx files." 
dynamic = ["version"] @@ -39,8 +39,20 @@ Homepage = "https://github.com/python-openxml/python-docx" Repository = "https://github.com/python-openxml/python-docx" [tool.black] +line-length = 100 target-version = ["py37", "py38", "py39", "py310", "py311"] +[tool.pyright] +include = ["src/docx", "tests"] +pythonPlatform = "All" +pythonVersion = "3.8" +reportImportCycles = true +reportUnnecessaryCast = true +reportUnnecessaryTypeIgnoreComment = true +stubPath = "./typings" +typeCheckingMode = "strict" +verboseOutput = true + [tool.pytest.ini_options] filterwarnings = [ # -- exit on any warning not explicitly ignored here -- @@ -69,6 +81,10 @@ python_functions = ["it_", "its_", "they_", "and_", "but_"] [tool.ruff] exclude = [] +line-length = 100 +target-version = "py38" + +[tool.ruff.lint] ignore = [ "COM812", # -- over-aggressively insists on trailing commas where not desired -- "PT001", # -- wants @pytest.fixture() instead of @pytest.fixture -- @@ -88,9 +104,8 @@ select = [ "UP032", # -- Use f-string instead of `.format()` call -- "UP034", # -- Avoid extraneous parentheses -- ] -target-version = "py37" -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["docx"] known-local-folder = ["helpers"] diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index 161e49d2b..000000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "exclude": [ - "**/__pycache__", - "**/.*" - ], - "ignore": [ - ], - "include": [ - "src/docx/", - "tests" - ], - "pythonPlatform": "All", - "pythonVersion": "3.7", - "reportImportCycles": true, - "reportUnnecessaryCast": true, - "reportUnnecessaryTypeIgnoreComment": true, - "stubPath": "./typings", - "typeCheckingMode": "strict", - "useLibraryCodeForTypes": true, - "verboseOutput": true -} diff --git a/requirements-dev.txt b/requirements-dev.txt index 45e5f78c3..14d8740e3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,7 @@ -r requirements-test.txt build +ruff setuptools>=61.0.0 
tox twine +types-lxml diff --git a/requirements-docs.txt b/requirements-docs.txt index 11f9d2cd2..90edd8e31 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -1,4 +1,5 @@ Sphinx==1.8.6 Jinja2==2.11.3 MarkupSafe==0.23 +alabaster<0.7.14 -e . diff --git a/requirements-test.txt b/requirements-test.txt index 85d9f6ba3..b542c1af7 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -2,4 +2,6 @@ behave>=1.2.3 pyparsing>=2.0.1 pytest>=2.5 +pytest-coverage +pytest-xdist ruff diff --git a/src/docx/__init__.py b/src/docx/__init__.py index b214045d1..205221027 100644 --- a/src/docx/__init__.py +++ b/src/docx/__init__.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from docx.opc.part import Part -__version__ = "1.1.0" +__version__ = "1.1.2" __all__ = ["Document"] diff --git a/src/docx/api.py b/src/docx/api.py index a17c1dad4..aea876458 100644 --- a/src/docx/api.py +++ b/src/docx/api.py @@ -6,13 +6,17 @@ from __future__ import annotations import os -from typing import IO +from typing import IO, TYPE_CHECKING, cast from docx.opc.constants import CONTENT_TYPE as CT from docx.package import Package +if TYPE_CHECKING: + from docx.document import Document as DocumentObject + from docx.parts.document import DocumentPart -def Document(docx: str | IO[bytes] | None = None): + +def Document(docx: str | IO[bytes] | None = None) -> DocumentObject: """Return a |Document| object loaded from `docx`, where `docx` can be either a path to a ``.docx`` file (a string) or a file-like object. @@ -20,7 +24,7 @@ def Document(docx: str | IO[bytes] | None = None): loaded. 
""" docx = _default_docx_path() if docx is None else docx - document_part = Package.open(docx).main_document_part + document_part = cast("DocumentPart", Package.open(docx).main_document_part) if document_part.content_type != CT.WML_DOCUMENT_MAIN: tmpl = "file '%s' is not a Word file, content type is '%s'" raise ValueError(tmpl % (docx, document_part.content_type)) diff --git a/src/docx/blkcntnr.py b/src/docx/blkcntnr.py index 1327e6d08..a9969f6f6 100644 --- a/src/docx/blkcntnr.py +++ b/src/docx/blkcntnr.py @@ -18,7 +18,7 @@ from docx.text.paragraph import Paragraph if TYPE_CHECKING: - from docx import types as t + import docx.types as t from docx.oxml.document import CT_Body from docx.oxml.section import CT_HdrFtr from docx.oxml.table import CT_Tc @@ -41,9 +41,7 @@ def __init__(self, element: BlockItemElement, parent: t.ProvidesStoryPart): super(BlockItemContainer, self).__init__(parent) self._element = element - def add_paragraph( - self, text: str = "", style: str | ParagraphStyle | None = None - ) -> Paragraph: + def add_paragraph(self, text: str = "", style: str | ParagraphStyle | None = None) -> Paragraph: """Return paragraph newly added to the end of the content in this container. 
The paragraph has `text` in a single run if present, and is given paragraph @@ -77,11 +75,7 @@ def iter_inner_content(self) -> Iterator[Paragraph | Table]: from docx.table import Table for element in self._element.inner_content_elements: - yield ( - Paragraph(element, self) - if isinstance(element, CT_P) - else Table(element, self) - ) + yield (Paragraph(element, self) if isinstance(element, CT_P) else Table(element, self)) @property def paragraphs(self): diff --git a/src/docx/document.py b/src/docx/document.py index 4deb8aa8e..8944a0e50 100644 --- a/src/docx/document.py +++ b/src/docx/document.py @@ -14,7 +14,7 @@ from docx.shared import ElementProxy, Emu if TYPE_CHECKING: - from docx import types as t + import docx.types as t from docx.oxml.document import CT_Body, CT_Document from docx.parts.document import DocumentPart from docx.settings import Settings @@ -56,9 +56,7 @@ def add_page_break(self): paragraph.add_run().add_break(WD_BREAK.PAGE) return paragraph - def add_paragraph( - self, text: str = "", style: str | ParagraphStyle | None = None - ) -> Paragraph: + def add_paragraph(self, text: str = "", style: str | ParagraphStyle | None = None) -> Paragraph: """Return paragraph newly added to the end of the document. The paragraph is populated with `text` and having paragraph style `style`. 
diff --git a/src/docx/drawing/__init__.py b/src/docx/drawing/__init__.py index 03c9c5ab8..f40205747 100644 --- a/src/docx/drawing/__init__.py +++ b/src/docx/drawing/__init__.py @@ -2,10 +2,14 @@ from __future__ import annotations -from docx import types as t +from typing import TYPE_CHECKING + from docx.oxml.drawing import CT_Drawing from docx.shared import Parented +if TYPE_CHECKING: + import docx.types as t + class Drawing(Parented): """Container for a DrawingML object.""" diff --git a/src/docx/enum/__init__.py b/src/docx/enum/__init__.py index bfab52d36..e69de29bb 100644 --- a/src/docx/enum/__init__.py +++ b/src/docx/enum/__init__.py @@ -1,11 +0,0 @@ -"""Enumerations used in python-docx.""" - - -class Enumeration: - @classmethod - def from_xml(cls, xml_val): - return cls._xml_to_idx[xml_val] - - @classmethod - def to_xml(cls, enum_val): - return cls._idx_to_xml[enum_val] diff --git a/src/docx/enum/base.py b/src/docx/enum/base.py index e37e74299..bc96ab6a2 100644 --- a/src/docx/enum/base.py +++ b/src/docx/enum/base.py @@ -4,9 +4,10 @@ import enum import textwrap -from typing import Any, Dict, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar -from typing_extensions import Self +if TYPE_CHECKING: + from typing_extensions import Self _T = TypeVar("_T", bound="BaseXmlEnum") @@ -69,7 +70,7 @@ def to_xml(cls: Type[_T], value: int | _T | None) -> str | None: """XML value of this enum member, generally an XML attribute value.""" # -- presence of multi-arg `__new__()` method fools type-checker, but getting a # -- member by its value using EnumCls(val) works as usual. 
- return cls(value).xml_value # pyright: ignore[reportGeneralTypeIssues] + return cls(value).xml_value class DocsPageFormatter: @@ -129,9 +130,7 @@ def _member_defs(self): """A single string containing the aggregated member definitions section of the documentation page.""" members = self._clsdict["__members__"] - member_defs = [ - self._member_def(member) for member in members if member.name is not None - ] + member_defs = [self._member_def(member) for member in members if member.name is not None] return "\n".join(member_defs) @property diff --git a/src/docx/image/image.py b/src/docx/image/image.py index 945432872..0022b5b45 100644 --- a/src/docx/image/image.py +++ b/src/docx/image/image.py @@ -11,8 +11,6 @@ import os from typing import IO, Tuple -from typing_extensions import Self - from docx.image.exceptions import UnrecognizedImageError from docx.shared import Emu, Inches, Length, lazyproperty @@ -28,14 +26,14 @@ def __init__(self, blob: bytes, filename: str, image_header: BaseImageHeader): self._image_header = image_header @classmethod - def from_blob(cls, blob: bytes) -> Self: + def from_blob(cls, blob: bytes) -> Image: """Return a new |Image| subclass instance parsed from the image binary contained in `blob`.""" stream = io.BytesIO(blob) return cls._from_stream(stream, blob) @classmethod - def from_file(cls, image_descriptor): + def from_file(cls, image_descriptor: str | IO[bytes]): """Return a new |Image| subclass instance loaded from the image file identified by `image_descriptor`, a path or file-like object.""" if isinstance(image_descriptor, str): @@ -57,7 +55,7 @@ def blob(self): return self._blob @property - def content_type(self): + def content_type(self) -> str: """MIME content type for this image, e.g. 
``'image/jpeg'`` for a JPEG image.""" return self._image_header.content_type @@ -116,7 +114,7 @@ def height(self) -> Inches: return Inches(self.px_height / self.vert_dpi) def scaled_dimensions( - self, width: int | None = None, height: int | None = None + self, width: int | Length | None = None, height: int | Length | None = None ) -> Tuple[Length, Length]: """(cx, cy) pair representing scaled dimensions of this image. @@ -167,12 +165,11 @@ def _from_stream( return cls(blob, filename, image_header) -def _ImageHeaderFactory(stream): - """Return a |BaseImageHeader| subclass instance that knows how to parse the headers - of the image in `stream`.""" +def _ImageHeaderFactory(stream: IO[bytes]): + """A |BaseImageHeader| subclass instance that can parse headers of image in `stream`.""" from docx.image import SIGNATURES - def read_32(stream): + def read_32(stream: IO[bytes]): stream.seek(0) return stream.read(32) @@ -188,32 +185,27 @@ def read_32(stream): class BaseImageHeader: """Base class for image header subclasses like |Jpeg| and |Tiff|.""" - def __init__(self, px_width, px_height, horz_dpi, vert_dpi): + def __init__(self, px_width: int, px_height: int, horz_dpi: int, vert_dpi: int): self._px_width = px_width self._px_height = px_height self._horz_dpi = horz_dpi self._vert_dpi = vert_dpi @property - def content_type(self): + def content_type(self) -> str: """Abstract property definition, must be implemented by all subclasses.""" - msg = ( - "content_type property must be implemented by all subclasses of " - "BaseImageHeader" - ) + msg = "content_type property must be implemented by all subclasses of " "BaseImageHeader" raise NotImplementedError(msg) @property - def default_ext(self): + def default_ext(self) -> str: """Default filename extension for images of this type. An abstract property definition, must be implemented by all subclasses. 
""" - msg = ( - "default_ext property must be implemented by all subclasses of " - "BaseImageHeader" + raise NotImplementedError( + "default_ext property must be implemented by all subclasses of " "BaseImageHeader" ) - raise NotImplementedError(msg) @property def px_width(self): diff --git a/src/docx/image/tiff.py b/src/docx/image/tiff.py index b84d9f10f..1194929af 100644 --- a/src/docx/image/tiff.py +++ b/src/docx/image/tiff.py @@ -98,11 +98,7 @@ def _dpi(self, resolution_tag): return 72 # resolution unit defaults to inches (2) - resolution_unit = ( - ifd_entries[TIFF_TAG.RESOLUTION_UNIT] - if TIFF_TAG.RESOLUTION_UNIT in ifd_entries - else 2 - ) + resolution_unit = ifd_entries.get(TIFF_TAG.RESOLUTION_UNIT, 2) if resolution_unit == 1: # aspect ratio only return 72 diff --git a/src/docx/opc/coreprops.py b/src/docx/opc/coreprops.py index 2fd9a75c8..c564550d4 100644 --- a/src/docx/opc/coreprops.py +++ b/src/docx/opc/coreprops.py @@ -3,12 +3,21 @@ These are broadly-standardized attributes like author, last-modified, etc. 
""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from docx.oxml.coreprops import CT_CoreProperties + +if TYPE_CHECKING: + from docx.oxml.coreprops import CT_CoreProperties + class CoreProperties: """Corresponds to part named ``/docProps/core.xml``, containing the core document properties for this document package.""" - def __init__(self, element): + def __init__(self, element: CT_CoreProperties): self._element = element @property @@ -16,7 +25,7 @@ def author(self): return self._element.author_text @author.setter - def author(self, value): + def author(self, value: str): self._element.author_text = value @property @@ -24,7 +33,7 @@ def category(self): return self._element.category_text @category.setter - def category(self, value): + def category(self, value: str): self._element.category_text = value @property @@ -32,7 +41,7 @@ def comments(self): return self._element.comments_text @comments.setter - def comments(self, value): + def comments(self, value: str): self._element.comments_text = value @property @@ -40,7 +49,7 @@ def content_status(self): return self._element.contentStatus_text @content_status.setter - def content_status(self, value): + def content_status(self, value: str): self._element.contentStatus_text = value @property @@ -56,7 +65,7 @@ def identifier(self): return self._element.identifier_text @identifier.setter - def identifier(self, value): + def identifier(self, value: str): self._element.identifier_text = value @property @@ -64,7 +73,7 @@ def keywords(self): return self._element.keywords_text @keywords.setter - def keywords(self, value): + def keywords(self, value: str): self._element.keywords_text = value @property @@ -72,7 +81,7 @@ def language(self): return self._element.language_text @language.setter - def language(self, value): + def language(self, value: str): self._element.language_text = value @property @@ -80,7 +89,7 @@ def last_modified_by(self): return self._element.lastModifiedBy_text 
@last_modified_by.setter - def last_modified_by(self, value): + def last_modified_by(self, value: str): self._element.lastModifiedBy_text = value @property @@ -112,7 +121,7 @@ def subject(self): return self._element.subject_text @subject.setter - def subject(self, value): + def subject(self, value: str): self._element.subject_text = value @property @@ -120,7 +129,7 @@ def title(self): return self._element.title_text @title.setter - def title(self, value): + def title(self, value: str): self._element.title_text = value @property @@ -128,5 +137,5 @@ def version(self): return self._element.version_text @version.setter - def version(self, value): + def version(self, value: str): self._element.version_text = value diff --git a/src/docx/opc/oxml.py b/src/docx/opc/oxml.py index 570dcf413..7da72f50d 100644 --- a/src/docx/opc/oxml.py +++ b/src/docx/opc/oxml.py @@ -1,3 +1,5 @@ +# pyright: reportPrivateUsage=false + """Temporary stand-in for main oxml module. This module came across with the PackageReader transplant. Probably much will get @@ -5,6 +7,10 @@ deleted or only hold the package related custom element classes. """ +from __future__ import annotations + +from typing import cast + from lxml import etree from docx.opc.constants import NAMESPACE as NS @@ -27,7 +33,7 @@ # =========================================================================== -def parse_xml(text: str) -> etree._Element: # pyright: ignore[reportPrivateUsage] +def parse_xml(text: str) -> etree._Element: """`etree.fromstring()` replacement that uses oxml parser.""" return etree.fromstring(text, oxml_parser) @@ -44,7 +50,7 @@ def qn(tag): return "{%s}%s" % (uri, tagroot) -def serialize_part_xml(part_elm): +def serialize_part_xml(part_elm: etree._Element): """Serialize `part_elm` etree element to XML suitable for storage as an XML part. 
That is to say, no insignificant whitespace added for readability, and an @@ -136,7 +142,7 @@ class CT_Relationship(BaseOxmlElement): target part.""" @staticmethod - def new(rId, reltype, target, target_mode=RTM.INTERNAL): + def new(rId: str, reltype: str, target: str, target_mode: str = RTM.INTERNAL): """Return a new ```` element.""" xml = '' % nsmap["pr"] relationship = parse_xml(xml) @@ -176,7 +182,7 @@ def target_mode(self): class CT_Relationships(BaseOxmlElement): """```` element, the root element in a .rels file.""" - def add_rel(self, rId, reltype, target, is_external=False): + def add_rel(self, rId: str, reltype: str, target: str, is_external: bool = False): """Add a child ```` element with attributes set according to parameter values.""" target_mode = RTM.EXTERNAL if is_external else RTM.INTERNAL @@ -184,11 +190,10 @@ def add_rel(self, rId, reltype, target, is_external=False): self.append(relationship) @staticmethod - def new(): + def new() -> CT_Relationships: """Return a new ```` element.""" xml = '' % nsmap["pr"] - relationships = parse_xml(xml) - return relationships + return cast(CT_Relationships, parse_xml(xml)) @property def Relationship_lst(self): diff --git a/src/docx/opc/package.py b/src/docx/opc/package.py index b5bdc0e7c..3b1eef256 100644 --- a/src/docx/opc/package.py +++ b/src/docx/opc/package.py @@ -1,5 +1,9 @@ """Objects that implement reading and writing OPC packages.""" +from __future__ import annotations + +from typing import IO, TYPE_CHECKING, Iterator, cast + from docx.opc.constants import RELATIONSHIP_TYPE as RT from docx.opc.packuri import PACKAGE_URI, PackURI from docx.opc.part import PartFactory @@ -7,7 +11,12 @@ from docx.opc.pkgreader import PackageReader from docx.opc.pkgwriter import PackageWriter from docx.opc.rel import Relationships -from docx.opc.shared import lazyproperty +from docx.shared import lazyproperty + +if TYPE_CHECKING: + from docx.opc.coreprops import CoreProperties + from docx.opc.part import Part + from 
docx.opc.rel import _Relationship # pyright: ignore[reportPrivateUsage] class OpcPackage: @@ -30,16 +39,18 @@ def after_unmarshal(self): pass @property - def core_properties(self): + def core_properties(self) -> CoreProperties: """|CoreProperties| object providing read/write access to the Dublin Core properties for this document.""" return self._core_properties_part.core_properties - def iter_rels(self): + def iter_rels(self) -> Iterator[_Relationship]: """Generate exactly one reference to each relationship in the package by performing a depth-first traversal of the rels graph.""" - def walk_rels(source, visited=None): + def walk_rels( + source: OpcPackage | Part, visited: list[Part] | None = None + ) -> Iterator[_Relationship]: visited = [] if visited is None else visited for rel in source.rels.values(): yield rel @@ -56,7 +67,7 @@ def walk_rels(source, visited=None): for rel in walk_rels(self): yield rel - def iter_parts(self): + def iter_parts(self) -> Iterator[Part]: """Generate exactly one reference to each of the parts in the package by performing a depth-first traversal of the rels graph.""" @@ -76,7 +87,7 @@ def walk_parts(source, visited=[]): for part in walk_parts(self): yield part - def load_rel(self, reltype, target, rId, is_external=False): + def load_rel(self, reltype: str, target: Part | str, rId: str, is_external: bool = False): """Return newly added |_Relationship| instance of `reltype` between this part and `target` with key `rId`. @@ -96,7 +107,7 @@ def main_document_part(self): """ return self.part_related_by(RT.OFFICE_DOCUMENT) - def next_partname(self, template): + def next_partname(self, template: str) -> PackURI: """Return a |PackURI| instance representing partname matching `template`. 
The returned part-name has the next available numeric suffix to distinguish it @@ -111,14 +122,14 @@ def next_partname(self, template): return PackURI(candidate_partname) @classmethod - def open(cls, pkg_file): + def open(cls, pkg_file: str | IO[bytes]) -> OpcPackage: """Return an |OpcPackage| instance loaded with the contents of `pkg_file`.""" pkg_reader = PackageReader.from_file(pkg_file) package = cls() Unmarshaller.unmarshal(pkg_reader, package, PartFactory) return package - def part_related_by(self, reltype): + def part_related_by(self, reltype: str) -> Part: """Return part to which this package has a relationship of `reltype`. Raises |KeyError| if no such relationship is found and |ValueError| if more than @@ -127,13 +138,16 @@ def part_related_by(self, reltype): return self.rels.part_with_reltype(reltype) @property - def parts(self): + def parts(self) -> list[Part]: """Return a list containing a reference to each of the parts in this package.""" return list(self.iter_parts()) - def relate_to(self, part, reltype): - """Return rId key of relationship to `part`, from the existing relationship if - there is one, otherwise a newly created one.""" + def relate_to(self, part: Part, reltype: str): + """Return rId key of new or existing relationship to `part`. + + If a relationship of `reltype` to `part` already exists, its rId is returned. Otherwise a + new relationship is created and that rId is returned. + """ rel = self.rels.get_or_add(reltype, part) return rel.rId @@ -143,21 +157,23 @@ def rels(self): relationships for this package.""" return Relationships(PACKAGE_URI.baseURI) - def save(self, pkg_file): - """Save this package to `pkg_file`, where `file` can be either a path to a file - (a string) or a file-like object.""" + def save(self, pkg_file: str | IO[bytes]): + """Save this package to `pkg_file`. + + `pkg_file` can be either a file-path or a file-like object. 
+ """ for part in self.parts: part.before_marshal() PackageWriter.write(pkg_file, self.rels, self.parts) @property - def _core_properties_part(self): + def _core_properties_part(self) -> CorePropertiesPart: """|CorePropertiesPart| object related to this package. Creates a default core properties part if one is not present (not common). """ try: - return self.part_related_by(RT.CORE_PROPERTIES) + return cast(CorePropertiesPart, self.part_related_by(RT.CORE_PROPERTIES)) except KeyError: core_properties_part = CorePropertiesPart.default(self) self.relate_to(core_properties_part, RT.CORE_PROPERTIES) @@ -190,9 +206,7 @@ def _unmarshal_parts(pkg_reader, package, part_factory): """ parts = {} for partname, content_type, reltype, blob in pkg_reader.iter_sparts(): - parts[partname] = part_factory( - partname, content_type, reltype, blob, package - ) + parts[partname] = part_factory(partname, content_type, reltype, blob, package) return parts @staticmethod @@ -202,7 +216,5 @@ def _unmarshal_relationships(pkg_reader, package, parts): in `parts`.""" for source_uri, srel in pkg_reader.iter_srels(): source = package if source_uri == "/" else parts[source_uri] - target = ( - srel.target_ref if srel.is_external else parts[srel.target_partname] - ) + target = srel.target_ref if srel.is_external else parts[srel.target_partname] source.load_rel(srel.reltype, target, srel.rId, srel.is_external) diff --git a/src/docx/opc/packuri.py b/src/docx/opc/packuri.py index fe330d89b..fdbb67ed8 100644 --- a/src/docx/opc/packuri.py +++ b/src/docx/opc/packuri.py @@ -3,6 +3,8 @@ Also some useful known pack URI strings such as PACKAGE_URI. 
""" +from __future__ import annotations + import posixpath import re @@ -16,22 +18,21 @@ class PackURI(str): _filename_re = re.compile("([a-zA-Z]+)([1-9][0-9]*)?") - def __new__(cls, pack_uri_str): + def __new__(cls, pack_uri_str: str): if pack_uri_str[0] != "/": tmpl = "PackURI must begin with slash, got '%s'" raise ValueError(tmpl % pack_uri_str) return str.__new__(cls, pack_uri_str) @staticmethod - def from_rel_ref(baseURI, relative_ref): - """Return a |PackURI| instance containing the absolute pack URI formed by - translating `relative_ref` onto `baseURI`.""" + def from_rel_ref(baseURI: str, relative_ref: str) -> PackURI: + """The absolute PackURI formed by translating `relative_ref` onto `baseURI`.""" joined_uri = posixpath.join(baseURI, relative_ref) abs_uri = posixpath.abspath(joined_uri) return PackURI(abs_uri) @property - def baseURI(self): + def baseURI(self) -> str: """The base URI of this pack URI, the directory portion, roughly speaking. E.g. ``'/ppt/slides'`` for ``'/ppt/slides/slide1.xml'``. For the package pseudo- @@ -40,9 +41,8 @@ def baseURI(self): return posixpath.split(self)[0] @property - def ext(self): - """The extension portion of this pack URI, e.g. ``'xml'`` for - ``'/word/document.xml'``. + def ext(self) -> str: + """The extension portion of this pack URI, e.g. ``'xml'`` for ``'/word/document.xml'``. Note the period is not included. """ @@ -84,7 +84,7 @@ def membername(self): """ return self[1:] - def relative_ref(self, baseURI): + def relative_ref(self, baseURI: str): """Return string containing relative reference to package item from `baseURI`. E.g. 
PackURI('/ppt/slideLayouts/slideLayout1.xml') would return diff --git a/src/docx/opc/part.py b/src/docx/opc/part.py index a4ad3e7b2..cbb4ab556 100644 --- a/src/docx/opc/part.py +++ b/src/docx/opc/part.py @@ -1,16 +1,20 @@ +# pyright: reportImportCycles=false + """Open Packaging Convention (OPC) objects related to package parts.""" from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Dict, Type +from typing import TYPE_CHECKING, Callable, Type, cast from docx.opc.oxml import serialize_part_xml from docx.opc.packuri import PackURI from docx.opc.rel import Relationships -from docx.opc.shared import cls_method_fn, lazyproperty +from docx.opc.shared import cls_method_fn from docx.oxml.parser import parse_xml +from docx.shared import lazyproperty if TYPE_CHECKING: + from docx.oxml.xmlchemy import BaseOxmlElement from docx.package import Package @@ -23,7 +27,7 @@ class Part: def __init__( self, - partname: str, + partname: PackURI, content_type: str, blob: bytes | None = None, package: Package | None = None, @@ -55,13 +59,13 @@ def before_marshal(self): pass @property - def blob(self): + def blob(self) -> bytes: """Contents of this package part as a sequence of bytes. May be text or binary. Intended to be overridden by subclasses. Default behavior is to return load blob. """ - return self._blob + return self._blob or b"" @property def content_type(self): @@ -78,12 +82,13 @@ def drop_rel(self, rId: str): del self.rels[rId] @classmethod - def load(cls, partname: str, content_type: str, blob: bytes, package: Package): + def load(cls, partname: PackURI, content_type: str, blob: bytes, package: Package): return cls(partname, content_type, blob, package) - def load_rel(self, reltype, target, rId, is_external=False): - """Return newly added |_Relationship| instance of `reltype` between this part - and `target` with key `rId`. 
+ def load_rel(self, reltype: str, target: Part | str, rId: str, is_external: bool = False): + """Return newly added |_Relationship| instance of `reltype`. + + The new relationship relates the `target` part to this part with key `rId`. Target mode is set to ``RTM.EXTERNAL`` if `is_external` is |True|. Intended for use during load from a serialized package, where the rId is well-known. Other @@ -103,7 +108,7 @@ def partname(self): return self._partname @partname.setter - def partname(self, partname): + def partname(self, partname: str): if not isinstance(partname, PackURI): tmpl = "partname must be instance of PackURI, got '%s'" raise TypeError(tmpl % type(partname).__name__) @@ -118,16 +123,16 @@ def part_related_by(self, reltype: str) -> Part: """ return self.rels.part_with_reltype(reltype) - def relate_to(self, target: Part, reltype: str, is_external: bool = False) -> str: + def relate_to(self, target: Part | str, reltype: str, is_external: bool = False) -> str: """Return rId key of relationship of `reltype` to `target`. The returned `rId` is from an existing relationship if there is one, otherwise a new relationship is created. 
""" if is_external: - return self.rels.get_or_add_ext_rel(reltype, target) + return self.rels.get_or_add_ext_rel(reltype, cast(str, target)) else: - rel = self.rels.get_or_add(reltype, target) + rel = self.rels.get_or_add(reltype, cast(Part, target)) return rel.rId @property @@ -140,18 +145,21 @@ def related_parts(self): @lazyproperty def rels(self): """|Relationships| instance holding the relationships for this part.""" - return Relationships(self._partname.baseURI) + # -- prevent breakage in `python-docx-template` by retaining legacy `._rels` attribute -- + self._rels = Relationships(self._partname.baseURI) + return self._rels - def target_ref(self, rId): + def target_ref(self, rId: str) -> str: """Return URL contained in target ref of relationship identified by `rId`.""" rel = self.rels[rId] return rel.target_ref - def _rel_ref_count(self, rId): - """Return the count of references in this part's XML to the relationship - identified by `rId`.""" - rIds = self._element.xpath("//@r:id") - return len([_rId for _rId in rIds if _rId == rId]) + def _rel_ref_count(self, rId: str) -> int: + """Return the count of references in this part to the relationship identified by `rId`. + + Only an XML part can contain references, so this is 0 for `Part`. + """ + return 0 class PartFactory: @@ -168,12 +176,12 @@ class PartFactory: """ part_class_selector: Callable[[str, str], Type[Part] | None] | None - part_type_for: Dict[str, Type[Part]] = {} + part_type_for: dict[str, Type[Part]] = {} default_part_type = Part def __new__( cls, - partname: str, + partname: PackURI, content_type: str, reltype: str, blob: bytes, @@ -203,7 +211,9 @@ class XmlPart(Part): reserializing the XML payload and managing relationships to other parts. 
""" - def __init__(self, partname, content_type, element, package): + def __init__( + self, partname: PackURI, content_type: str, element: BaseOxmlElement, package: Package + ): super(XmlPart, self).__init__(partname, content_type, package=package) self._element = element @@ -217,7 +227,7 @@ def element(self): return self._element @classmethod - def load(cls, partname, content_type, blob, package): + def load(cls, partname: PackURI, content_type: str, blob: bytes, package: Package): element = parse_xml(blob) return cls(partname, content_type, element, package) @@ -229,3 +239,9 @@ def part(self): That chain of delegation ends here for child objects. """ return self + + def _rel_ref_count(self, rId: str) -> int: + """Return the count of references in this part's XML to the relationship + identified by `rId`.""" + rIds = cast("list[str]", self._element.xpath("//@r:id")) + return len([_rId for _rId in rIds if _rId == rId]) diff --git a/src/docx/opc/parts/coreprops.py b/src/docx/opc/parts/coreprops.py index 6e26e1d05..fda011218 100644 --- a/src/docx/opc/parts/coreprops.py +++ b/src/docx/opc/parts/coreprops.py @@ -1,6 +1,9 @@ """Core properties part, corresponds to ``/docProps/core.xml`` part in package.""" -from datetime import datetime +from __future__ import annotations + +import datetime as dt +from typing import TYPE_CHECKING from docx.opc.constants import CONTENT_TYPE as CT from docx.opc.coreprops import CoreProperties @@ -8,13 +11,19 @@ from docx.opc.part import XmlPart from docx.oxml.coreprops import CT_CoreProperties +if TYPE_CHECKING: + from docx.opc.package import OpcPackage + class CorePropertiesPart(XmlPart): - """Corresponds to part named ``/docProps/core.xml``, containing the core document - properties for this document package.""" + """Corresponds to part named ``/docProps/core.xml``. + + The "core" is short for "Dublin Core" and contains document metadata relatively common across + documents of all types, not just DOCX. 
+ """ @classmethod - def default(cls, package): + def default(cls, package: OpcPackage): """Return a new |CorePropertiesPart| object initialized with default values for its base properties.""" core_properties_part = cls._new(package) @@ -22,7 +31,7 @@ def default(cls, package): core_properties.title = "Word Document" core_properties.last_modified_by = "python-docx" core_properties.revision = 1 - core_properties.modified = datetime.utcnow() + core_properties.modified = dt.datetime.now(dt.timezone.utc) return core_properties_part @property @@ -32,7 +41,7 @@ def core_properties(self): return CoreProperties(self.element) @classmethod - def _new(cls, package): + def _new(cls, package: OpcPackage) -> CorePropertiesPart: partname = PackURI("/docProps/core.xml") content_type = CT.OPC_CORE_PROPERTIES coreProperties = CT_CoreProperties.new() diff --git a/src/docx/opc/pkgwriter.py b/src/docx/opc/pkgwriter.py index 75af6ac75..e63516979 100644 --- a/src/docx/opc/pkgwriter.py +++ b/src/docx/opc/pkgwriter.py @@ -4,6 +4,10 @@ OpcPackage.save(). 
""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable + from docx.opc.constants import CONTENT_TYPE as CT from docx.opc.oxml import CT_Types, serialize_part_xml from docx.opc.packuri import CONTENT_TYPES_URI, PACKAGE_URI @@ -11,6 +15,9 @@ from docx.opc.shared import CaseInsensitiveDict from docx.opc.spec import default_content_types +if TYPE_CHECKING: + from docx.opc.part import Part + class PackageWriter: """Writes a zip-format OPC package to `pkg_file`, where `pkg_file` can be either a @@ -38,13 +45,13 @@ def _write_content_types_stream(phys_writer, parts): phys_writer.write(CONTENT_TYPES_URI, cti.blob) @staticmethod - def _write_parts(phys_writer, parts): + def _write_parts(phys_writer: PhysPkgWriter, parts: Iterable[Part]): """Write the blob of each part in `parts` to the package, along with a rels item for its relationships if and only if it has any.""" for part in parts: phys_writer.write(part.partname, part.blob) - if len(part._rels): - phys_writer.write(part.partname.rels_uri, part._rels.xml) + if len(part.rels): + phys_writer.write(part.partname.rels_uri, part.rels.xml) @staticmethod def _write_pkg_rels(phys_writer, pkg_rels): diff --git a/src/docx/opc/rel.py b/src/docx/opc/rel.py index efac5e06b..47e8860d8 100644 --- a/src/docx/opc/rel.py +++ b/src/docx/opc/rel.py @@ -2,10 +2,13 @@ from __future__ import annotations -from typing import Any, Dict +from typing import TYPE_CHECKING, Any, Dict, cast from docx.opc.oxml import CT_Relationships +if TYPE_CHECKING: + from docx.opc.part import Part + class Relationships(Dict[str, "_Relationship"]): """Collection object for |_Relationship| instances, having list semantics.""" @@ -13,10 +16,10 @@ class Relationships(Dict[str, "_Relationship"]): def __init__(self, baseURI: str): super(Relationships, self).__init__() self._baseURI = baseURI - self._target_parts_by_rId: Dict[str, Any] = {} + self._target_parts_by_rId: dict[str, Any] = {} def add_relationship( - self, reltype: str, target: 
str | Any, rId: str, is_external: bool = False + self, reltype: str, target: Part | str, rId: str, is_external: bool = False ) -> "_Relationship": """Return a newly added |_Relationship| instance.""" rel = _Relationship(rId, reltype, target, self._baseURI, is_external) @@ -25,7 +28,7 @@ def add_relationship( self._target_parts_by_rId[rId] = target return rel - def get_or_add(self, reltype, target_part): + def get_or_add(self, reltype: str, target_part: Part) -> _Relationship: """Return relationship of `reltype` to `target_part`, newly added if not already present in collection.""" rel = self._get_matching(reltype, target_part) @@ -34,7 +37,7 @@ def get_or_add(self, reltype, target_part): rel = self.add_relationship(reltype, target_part, rId) return rel - def get_or_add_ext_rel(self, reltype, target_ref): + def get_or_add_ext_rel(self, reltype: str, target_ref: str) -> str: """Return rId of external relationship of `reltype` to `target_ref`, newly added if not already present in collection.""" rel = self._get_matching(reltype, target_ref, is_external=True) @@ -43,7 +46,7 @@ def get_or_add_ext_rel(self, reltype, target_ref): rel = self.add_relationship(reltype, target_ref, rId, is_external=True) return rel.rId - def part_with_reltype(self, reltype): + def part_with_reltype(self, reltype: str) -> Part: """Return target part of rel with matching `reltype`, raising |KeyError| if not found and |ValueError| if more than one matching relationship is found.""" rel = self._get_rel_of_type(reltype) @@ -56,7 +59,7 @@ def related_parts(self): return self._target_parts_by_rId @property - def xml(self): + def xml(self) -> str: """Serialize this relationship collection into XML suitable for storage as a .rels file in an OPC package.""" rels_elm = CT_Relationships.new() @@ -64,11 +67,13 @@ def xml(self): rels_elm.add_rel(rel.rId, rel.reltype, rel.target_ref, rel.is_external) return rels_elm.xml - def _get_matching(self, reltype, target, is_external=False): + def _get_matching( + 
self, reltype: str, target: Part | str, is_external: bool = False + ) -> _Relationship | None: """Return relationship of matching `reltype`, `target`, and `is_external` from collection, or None if not found.""" - def matches(rel, reltype, target, is_external): + def matches(rel: _Relationship, reltype: str, target: Part | str, is_external: bool): if rel.reltype != reltype: return False if rel.is_external != is_external: @@ -83,7 +88,7 @@ def matches(rel, reltype, target, is_external): return rel return None - def _get_rel_of_type(self, reltype): + def _get_rel_of_type(self, reltype: str): """Return single relationship of type `reltype` from the collection. Raises |KeyError| if no matching relationship is found. Raises |ValueError| if @@ -99,7 +104,7 @@ def _get_rel_of_type(self, reltype): return matching[0] @property - def _next_rId(self): + def _next_rId(self) -> str: # pyright: ignore[reportReturnType] """Next available rId in collection, starting from 'rId1' and making use of any gaps in numbering, e.g. 
'rId2' for rIds ['rId1', 'rId3'].""" for n in range(1, len(self) + 2): @@ -111,7 +116,9 @@ def _next_rId(self): class _Relationship: """Value object for relationship to part.""" - def __init__(self, rId: str, reltype, target, baseURI, external=False): + def __init__( + self, rId: str, reltype: str, target: Part | str, baseURI: str, external: bool = False + ): super(_Relationship, self).__init__() self._rId = rId self._reltype = reltype @@ -120,29 +127,29 @@ def __init__(self, rId: str, reltype, target, baseURI, external=False): self._is_external = bool(external) @property - def is_external(self): + def is_external(self) -> bool: return self._is_external @property - def reltype(self): + def reltype(self) -> str: return self._reltype @property - def rId(self): + def rId(self) -> str: return self._rId @property - def target_part(self): + def target_part(self) -> Part: if self._is_external: raise ValueError( - "target_part property on _Relationship is undef" - "ined when target mode is External" + "target_part property on _Relationship is undef" "ined when target mode is External" ) - return self._target + return cast("Part", self._target) @property def target_ref(self) -> str: if self._is_external: - return self._target + return cast(str, self._target) else: - return self._target.partname.relative_ref(self._baseURI) + target = cast("Part", self._target) + return target.partname.relative_ref(self._baseURI) diff --git a/src/docx/opc/shared.py b/src/docx/opc/shared.py index 1862f66db..9d4c0a6d3 100644 --- a/src/docx/opc/shared.py +++ b/src/docx/opc/shared.py @@ -1,7 +1,13 @@ """Objects shared by opc modules.""" +from __future__ import annotations -class CaseInsensitiveDict(dict): +from typing import Any, Dict, TypeVar + +_T = TypeVar("_T") + + +class CaseInsensitiveDict(Dict[str, Any]): """Mapping type that behaves like dict except that it matches without respect to the case of the key. 
@@ -23,23 +29,3 @@ def __setitem__(self, key, value): def cls_method_fn(cls: type, method_name: str): """Return method of `cls` having `method_name`.""" return getattr(cls, method_name) - - -def lazyproperty(f): - """@lazyprop decorator. - - Decorated method will be called only on first access to calculate a cached property - value. After that, the cached value is returned. - """ - cache_attr_name = "_%s" % f.__name__ # like '_foobar' for prop 'foobar' - docstring = f.__doc__ - - def get_prop_value(obj): - try: - return getattr(obj, cache_attr_name) - except AttributeError: - value = f(obj) - setattr(obj, cache_attr_name, value) - return value - - return property(get_prop_value, doc=docstring) diff --git a/src/docx/oxml/__init__.py b/src/docx/oxml/__init__.py index 621ef279a..bf32932f9 100644 --- a/src/docx/oxml/__init__.py +++ b/src/docx/oxml/__init__.py @@ -149,6 +149,7 @@ CT_TblGridCol, CT_TblLayoutType, CT_TblPr, + CT_TblPrEx, CT_TblWidth, CT_Tc, CT_TcPr, @@ -158,12 +159,15 @@ ) register_element_cls("w:bidiVisual", CT_OnOff) +register_element_cls("w:gridAfter", CT_DecimalNumber) +register_element_cls("w:gridBefore", CT_DecimalNumber) register_element_cls("w:gridCol", CT_TblGridCol) register_element_cls("w:gridSpan", CT_DecimalNumber) register_element_cls("w:tbl", CT_Tbl) register_element_cls("w:tblGrid", CT_TblGrid) register_element_cls("w:tblLayout", CT_TblLayoutType) register_element_cls("w:tblPr", CT_TblPr) +register_element_cls("w:tblPrEx", CT_TblPrEx) register_element_cls("w:tblStyle", CT_String) register_element_cls("w:tc", CT_Tc) register_element_cls("w:tcPr", CT_TcPr) diff --git a/src/docx/oxml/coreprops.py b/src/docx/oxml/coreprops.py index 2cafcd960..8ba9ff42e 100644 --- a/src/docx/oxml/coreprops.py +++ b/src/docx/oxml/coreprops.py @@ -1,13 +1,18 @@ """Custom element classes for core properties-related XML elements.""" +from __future__ import annotations + +import datetime as dt import re -from datetime import datetime, timedelta -from typing import 
Any +from typing import TYPE_CHECKING, Any, Callable from docx.oxml.ns import nsdecls, qn from docx.oxml.parser import parse_xml from docx.oxml.xmlchemy import BaseOxmlElement, ZeroOrOne +if TYPE_CHECKING: + from lxml.etree import _Element as etree_Element # pyright: ignore[reportPrivateUsage] + class CT_CoreProperties(BaseOxmlElement): """`<cp:coreProperties>` element, the root element of the Core Properties part. @@ -17,6 +22,8 @@ class CT_CoreProperties(BaseOxmlElement): present in the XML. String elements are limited in length to 255 unicode characters. """ + get_or_add_revision: Callable[[], etree_Element] + category = ZeroOrOne("cp:category", successors=()) contentStatus = ZeroOrOne("cp:contentStatus", successors=()) created = ZeroOrOne("dcterms:created", successors=()) @@ -28,7 +35,9 @@ class CT_CoreProperties(BaseOxmlElement): lastModifiedBy = ZeroOrOne("cp:lastModifiedBy", successors=()) lastPrinted = ZeroOrOne("cp:lastPrinted", successors=()) modified = ZeroOrOne("dcterms:modified", successors=()) - revision = ZeroOrOne("cp:revision", successors=()) + revision: etree_Element | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "cp:revision", successors=() + ) subject = ZeroOrOne("dc:subject", successors=()) title = ZeroOrOne("dc:title", successors=()) version = ZeroOrOne("cp:version", successors=()) @@ -80,7 +89,7 @@ def created_datetime(self): return self._datetime_of_element("created") @created_datetime.setter - def created_datetime(self, value): + def created_datetime(self, value: dt.datetime): self._set_element_datetime("created", value) @property @@ -88,7 +97,7 @@ def identifier_text(self): return self._text_of_element("identifier") @identifier_text.setter - def identifier_text(self, value): + def identifier_text(self, value: str): self._set_element_text("identifier", value) @property @@ -96,7 +105,7 @@ def keywords_text(self): return self._text_of_element("keywords") @keywords_text.setter - def keywords_text(self, value): + def keywords_text(self, value: 
str): self._set_element_text("keywords", value) @property @@ -104,7 +113,7 @@ def language_text(self): return self._text_of_element("language") @language_text.setter - def language_text(self, value): + def language_text(self, value: str): self._set_element_text("language", value) @property @@ -112,7 +121,7 @@ def lastModifiedBy_text(self): return self._text_of_element("lastModifiedBy") @lastModifiedBy_text.setter - def lastModifiedBy_text(self, value): + def lastModifiedBy_text(self, value: str): self._set_element_text("lastModifiedBy", value) @property @@ -120,15 +129,15 @@ def lastPrinted_datetime(self): return self._datetime_of_element("lastPrinted") @lastPrinted_datetime.setter - def lastPrinted_datetime(self, value): + def lastPrinted_datetime(self, value: dt.datetime): self._set_element_datetime("lastPrinted", value) @property - def modified_datetime(self): + def modified_datetime(self) -> dt.datetime | None: return self._datetime_of_element("modified") @modified_datetime.setter - def modified_datetime(self, value): + def modified_datetime(self, value: dt.datetime): self._set_element_datetime("modified", value) @property @@ -137,7 +146,7 @@ def revision_number(self): revision = self.revision if revision is None: return 0 - revision_str = revision.text + revision_str = str(revision.text) try: revision = int(revision_str) except ValueError: @@ -149,9 +158,9 @@ def revision_number(self): return revision @revision_number.setter - def revision_number(self, value): + def revision_number(self, value: int): """Set revision property to string value of integer `value`.""" - if not isinstance(value, int) or value < 1: + if not isinstance(value, int) or value < 1: # pyright: ignore[reportUnnecessaryIsInstance] tmpl = "revision property requires positive int, got '%s'" raise ValueError(tmpl % value) revision = self.get_or_add_revision() @@ -162,7 +171,7 @@ def subject_text(self): return self._text_of_element("subject") @subject_text.setter - def subject_text(self, value): 
+ def subject_text(self, value: str): self._set_element_text("subject", value) @property @@ -170,7 +179,7 @@ def title_text(self): return self._text_of_element("title") @title_text.setter - def title_text(self, value): + def title_text(self, value: str): self._set_element_text("title", value) @property @@ -178,10 +187,10 @@ def version_text(self): return self._text_of_element("version") @version_text.setter - def version_text(self, value): + def version_text(self, value: str): self._set_element_text("version", value) - def _datetime_of_element(self, property_name): + def _datetime_of_element(self, property_name: str) -> dt.datetime | None: element = getattr(self, property_name) if element is None: return None @@ -192,7 +201,7 @@ def _datetime_of_element(self, property_name): # invalid datetime strings are ignored return None - def _get_or_add(self, prop_name): + def _get_or_add(self, prop_name: str) -> BaseOxmlElement: """Return element returned by "get_or_add_" method for `prop_name`.""" get_or_add_method_name = "get_or_add_%s" % prop_name get_or_add_method = getattr(self, get_or_add_method_name) @@ -200,8 +209,8 @@ def _get_or_add(self, prop_name): return element @classmethod - def _offset_dt(cls, dt, offset_str): - """A |datetime| instance offset from `dt` by timezone offset in `offset_str`. + def _offset_dt(cls, dt_: dt.datetime, offset_str: str) -> dt.datetime: + """A |datetime| instance offset from `dt_` by timezone offset in `offset_str`. `offset_str` is like `"-07:00"`. 
""" @@ -212,13 +221,13 @@ def _offset_dt(cls, dt, offset_str): sign_factor = -1 if sign == "+" else 1 hours = int(hours_str) * sign_factor minutes = int(minutes_str) * sign_factor - td = timedelta(hours=hours, minutes=minutes) - return dt + td + td = dt.timedelta(hours=hours, minutes=minutes) + return dt_ + td _offset_pattern = re.compile(r"([+-])(\d\d):(\d\d)") @classmethod - def _parse_W3CDTF_to_datetime(cls, w3cdtf_str): + def _parse_W3CDTF_to_datetime(cls, w3cdtf_str: str) -> dt.datetime: # valid W3CDTF date cases: # yyyy e.g. "2003" # yyyy-mm e.g. "2003-12" @@ -235,22 +244,22 @@ def _parse_W3CDTF_to_datetime(cls, w3cdtf_str): # "-07:30", so we have to do it ourselves parseable_part = w3cdtf_str[:19] offset_str = w3cdtf_str[19:] - dt = None + dt_ = None for tmpl in templates: try: - dt = datetime.strptime(parseable_part, tmpl) + dt_ = dt.datetime.strptime(parseable_part, tmpl) except ValueError: continue - if dt is None: + if dt_ is None: tmpl = "could not parse W3CDTF datetime string '%s'" raise ValueError(tmpl % w3cdtf_str) if len(offset_str) == 6: - return cls._offset_dt(dt, offset_str) - return dt + dt_ = cls._offset_dt(dt_, offset_str) + return dt_.replace(tzinfo=dt.timezone.utc) - def _set_element_datetime(self, prop_name, value): + def _set_element_datetime(self, prop_name: str, value: dt.datetime): """Set date/time value of child element having `prop_name` to `value`.""" - if not isinstance(value, datetime): + if not isinstance(value, dt.datetime): # pyright: ignore[reportUnnecessaryIsInstance] tmpl = "property requires <python-datetime> object, got %s" raise ValueError(tmpl % type(value)) element = self._get_or_add(prop_name) diff --git a/src/docx/oxml/document.py b/src/docx/oxml/document.py index cc27f5aa9..36819ef75 100644 --- a/src/docx/oxml/document.py +++ b/src/docx/oxml/document.py @@ -15,7 +15,7 @@ class CT_Document(BaseOxmlElement): """``<w:document>`` element, the root element of a document.xml file.""" - body = ZeroOrOne("w:body") + body: CT_Body = ZeroOrOne("w:body") # 
pyright: ignore[reportAssignmentType] @property def sectPr_lst(self) -> List[CT_SectPr]: @@ -44,7 +44,7 @@ class CT_Body(BaseOxmlElement): p = ZeroOrMore("w:p", successors=("w:sectPr",)) tbl = ZeroOrMore("w:tbl", successors=("w:sectPr",)) - sectPr: CT_SectPr | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + sectPr: CT_SectPr | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] "w:sectPr", successors=() ) diff --git a/src/docx/oxml/ns.py b/src/docx/oxml/ns.py index 3238864e9..5bed1e6a0 100644 --- a/src/docx/oxml/ns.py +++ b/src/docx/oxml/ns.py @@ -1,8 +1,8 @@ """Namespace-related objects.""" -from typing import Any, Dict +from __future__ import annotations -from typing_extensions import Self +from typing import Any, Dict nsmap = { "a": "http://schemas.openxmlformats.org/drawingml/2006/main", @@ -41,7 +41,7 @@ def clark_name(self) -> str: return "{%s}%s" % (self._ns_uri, self._local_part) @classmethod - def from_clark_name(cls, clark_name: str) -> Self: + def from_clark_name(cls, clark_name: str) -> NamespacePrefixedTag: nsuri, local_name = clark_name[1:].split("}") nstag = "%s:%s" % (pfxmap[nsuri], local_name) return cls(nstag) diff --git a/src/docx/oxml/parser.py b/src/docx/oxml/parser.py index 7e6a0fb49..e16ba30ba 100644 --- a/src/docx/oxml/parser.py +++ b/src/docx/oxml/parser.py @@ -1,3 +1,5 @@ +# pyright: reportImportCycles=false + """XML parser for python-docx.""" from __future__ import annotations @@ -18,7 +20,7 @@ oxml_parser.set_element_class_lookup(element_class_lookup) -def parse_xml(xml: str) -> "BaseOxmlElement": +def parse_xml(xml: str | bytes) -> "BaseOxmlElement": """Root lxml element obtained by parsing XML character string `xml`. 
The custom parser is used, so custom element classes are produced for elements in @@ -43,7 +45,7 @@ def OxmlElement( nsptag_str: str, attrs: Dict[str, str] | None = None, nsdecls: Dict[str, str] | None = None, -) -> BaseOxmlElement: +) -> BaseOxmlElement | etree._Element: # pyright: ignore[reportPrivateUsage] """Return a 'loose' lxml element having the tag specified by `nsptag_str`. The tag in `nsptag_str` must contain the standard namespace prefix, e.g. `a:tbl`. diff --git a/src/docx/oxml/section.py b/src/docx/oxml/section.py index a4090898a..71072e2df 100644 --- a/src/docx/oxml/section.py +++ b/src/docx/oxml/section.py @@ -51,38 +51,34 @@ def inner_content_elements(self) -> List[CT_P | CT_Tbl]: class CT_HdrFtrRef(BaseOxmlElement): """`w:headerReference` and `w:footerReference` elements.""" - type_: WD_HEADER_FOOTER = ( - RequiredAttribute( # pyright: ignore[reportGeneralTypeIssues] - "w:type", WD_HEADER_FOOTER - ) - ) - rId: str = RequiredAttribute( # pyright: ignore[reportGeneralTypeIssues] - "r:id", XsdString + type_: WD_HEADER_FOOTER = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "w:type", WD_HEADER_FOOTER ) + rId: str = RequiredAttribute("r:id", XsdString) # pyright: ignore[reportAssignmentType] class CT_PageMar(BaseOxmlElement): """``<w:pgMar>`` element, defining page margins.""" - top: Length | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + top: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:top", ST_SignedTwipsMeasure ) - right: Length | None = OptionalAttribute( # pyright: ignore + right: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:right", ST_TwipsMeasure ) - bottom: Length | None = OptionalAttribute( # pyright: ignore + bottom: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:bottom", ST_SignedTwipsMeasure ) - left: Length | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + left: Length | None = 
OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:left", ST_TwipsMeasure ) - header: Length | None = OptionalAttribute( # pyright: ignore + header: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:header", ST_TwipsMeasure ) - footer: Length | None = OptionalAttribute( # pyright: ignore + footer: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:footer", ST_TwipsMeasure ) - gutter: Length | None = OptionalAttribute( # pyright: ignore + gutter: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:gutter", ST_TwipsMeasure ) @@ -90,16 +86,14 @@ class CT_PageMar(BaseOxmlElement): class CT_PageSz(BaseOxmlElement): """``<w:pgSz>`` element, defining page dimensions and orientation.""" - w: Length | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + w: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:w", ST_TwipsMeasure ) - h: Length | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + h: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:h", ST_TwipsMeasure ) - orient: WD_ORIENTATION = ( - OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] - "w:orient", WD_ORIENTATION, default=WD_ORIENTATION.PORTRAIT - ) + orient: WD_ORIENTATION = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:orient", WD_ORIENTATION, default=WD_ORIENTATION.PORTRAIT ) @@ -139,16 +133,16 @@ class CT_SectPr(BaseOxmlElement): ) headerReference = ZeroOrMore("w:headerReference", successors=_tag_seq) footerReference = ZeroOrMore("w:footerReference", successors=_tag_seq) - type: CT_SectType | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + type: CT_SectType | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] "w:type", successors=_tag_seq[3:] ) - pgSz: CT_PageSz | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + pgSz: CT_PageSz | None = ZeroOrOne( # pyright: 
ignore[reportAssignmentType] "w:pgSz", successors=_tag_seq[4:] ) - pgMar: CT_PageMar | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + pgMar: CT_PageMar | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] "w:pgMar", successors=_tag_seq[5:] ) - titlePg: CT_OnOff | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + titlePg: CT_OnOff | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] "w:titlePg", successors=_tag_seq[14:] ) del _tag_seq @@ -187,9 +181,7 @@ def bottom_margin(self) -> Length | None: @bottom_margin.setter def bottom_margin(self, value: int | Length | None): pgMar = self.get_or_add_pgMar() - pgMar.bottom = ( - value if value is None or isinstance(value, Length) else Length(value) - ) + pgMar.bottom = value if value is None or isinstance(value, Length) else Length(value) def clone(self) -> CT_SectPr: """Return an exact duplicate of this ``<w:sectPr>`` element tree suitable for @@ -217,9 +209,7 @@ def footer(self) -> Length | None: @footer.setter def footer(self, value: int | Length | None): pgMar = self.get_or_add_pgMar() - pgMar.footer = ( - value if value is None or isinstance(value, Length) else Length(value) - ) + pgMar.footer = value if value is None or isinstance(value, Length) else Length(value) def get_footerReference(self, type_: WD_HEADER_FOOTER) -> CT_HdrFtrRef | None: """Return footerReference element of `type_` or None if not present.""" @@ -251,9 +241,7 @@ def gutter(self) -> Length | None: @gutter.setter def gutter(self, value: int | Length | None): pgMar = self.get_or_add_pgMar() - pgMar.gutter = ( - value if value is None or isinstance(value, Length) else Length(value) - ) + pgMar.gutter = value if value is None or isinstance(value, Length) else Length(value) @property def header(self) -> Length | None: @@ -270,9 +258,7 @@ def header(self) -> Length | None: @header.setter def header(self, value: int | Length | None): pgMar = self.get_or_add_pgMar() - pgMar.header = ( - value if value is None or 
isinstance(value, Length) else Length(value) - ) + pgMar.header = value if value is None or isinstance(value, Length) else Length(value) def iter_inner_content(self) -> Iterator[CT_P | CT_Tbl]: """Generate all `w:p` and `w:tbl` elements in this section. @@ -295,9 +281,7 @@ def left_margin(self) -> Length | None: @left_margin.setter def left_margin(self, value: int | Length | None): pgMar = self.get_or_add_pgMar() - pgMar.left = ( - value if value is None or isinstance(value, Length) else Length(value) - ) + pgMar.left = value if value is None or isinstance(value, Length) else Length(value) @property def orientation(self) -> WD_ORIENTATION: @@ -442,8 +426,8 @@ def top_margin(self, value: Length | None): class CT_SectType(BaseOxmlElement): """``<w:sectType>`` element, defining the section start type.""" - val: WD_SECTION_START | None = ( # pyright: ignore[reportGeneralTypeIssues] - OptionalAttribute("w:val", WD_SECTION_START) + val: WD_SECTION_START | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:val", WD_SECTION_START ) diff --git a/src/docx/oxml/settings.py b/src/docx/oxml/settings.py index fd39fbd99..d5bb41a6d 100644 --- a/src/docx/oxml/settings.py +++ b/src/docx/oxml/settings.py @@ -1,11 +1,21 @@ """Custom element classes related to document settings.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + from docx.oxml.xmlchemy import BaseOxmlElement, ZeroOrOne +if TYPE_CHECKING: + from docx.oxml.shared import CT_OnOff + class CT_Settings(BaseOxmlElement): """`w:settings` element, root element for the settings part.""" + get_or_add_evenAndOddHeaders: Callable[[], CT_OnOff] + _remove_evenAndOddHeaders: Callable[[], None] + _tag_seq = ( "w:writeProtection", "w:view", @@ -106,11 +116,13 @@ class CT_Settings(BaseOxmlElement): "w:decimalSymbol", "w:listSeparator", ) - evenAndOddHeaders = ZeroOrOne("w:evenAndOddHeaders", successors=_tag_seq[48:]) + evenAndOddHeaders: CT_OnOff | None = ZeroOrOne( # pyright: 
ignore[reportAssignmentType] + "w:evenAndOddHeaders", successors=_tag_seq[48:] + ) del _tag_seq @property - def evenAndOddHeaders_val(self): + def evenAndOddHeaders_val(self) -> bool: """Value of `w:evenAndOddHeaders/@w:val` or |None| if not present.""" evenAndOddHeaders = self.evenAndOddHeaders if evenAndOddHeaders is None: @@ -118,8 +130,9 @@ def evenAndOddHeaders_val(self): return evenAndOddHeaders.val @evenAndOddHeaders_val.setter - def evenAndOddHeaders_val(self, value): - if value in [None, False]: + def evenAndOddHeaders_val(self, value: bool | None): + if value is None or value is False: self._remove_evenAndOddHeaders() - else: - self.get_or_add_evenAndOddHeaders().val = value + return + + self.get_or_add_evenAndOddHeaders().val = value diff --git a/src/docx/oxml/shape.py b/src/docx/oxml/shape.py index 05c96697a..289d35579 100644 --- a/src/docx/oxml/shape.py +++ b/src/docx/oxml/shape.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from docx.oxml.ns import nsdecls from docx.oxml.parser import parse_xml @@ -34,48 +34,58 @@ class CT_Blip(BaseOxmlElement): """```` element, specifies image source and adjustments such as alpha and tint.""" - embed = OptionalAttribute("r:embed", ST_RelationshipId) - link = OptionalAttribute("r:link", ST_RelationshipId) + embed: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "r:embed", ST_RelationshipId + ) + link: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "r:link", ST_RelationshipId + ) class CT_BlipFillProperties(BaseOxmlElement): """```` element, specifies picture properties.""" - blip = ZeroOrOne("a:blip", successors=("a:srcRect", "a:tile", "a:stretch")) + blip: CT_Blip = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "a:blip", successors=("a:srcRect", "a:tile", "a:stretch") + ) class CT_GraphicalObject(BaseOxmlElement): """```` element, container for a DrawingML object.""" - 
graphicData = OneAndOnlyOne("a:graphicData") + graphicData: CT_GraphicalObjectData = OneAndOnlyOne( # pyright: ignore[reportAssignmentType] + "a:graphicData" + ) class CT_GraphicalObjectData(BaseOxmlElement): """```` element, container for the XML of a DrawingML object.""" - pic = ZeroOrOne("pic:pic") - uri = RequiredAttribute("uri", XsdToken) + pic: CT_Picture = ZeroOrOne("pic:pic") # pyright: ignore[reportAssignmentType] + uri: str = RequiredAttribute("uri", XsdToken) # pyright: ignore[reportAssignmentType] class CT_Inline(BaseOxmlElement): """`` element, container for an inline shape.""" - extent = OneAndOnlyOne("wp:extent") - docPr = OneAndOnlyOne("wp:docPr") - graphic = OneAndOnlyOne("a:graphic") + extent: CT_PositiveSize2D = OneAndOnlyOne("wp:extent") # pyright: ignore[reportAssignmentType] + docPr: CT_NonVisualDrawingProps = OneAndOnlyOne( # pyright: ignore[reportAssignmentType] + "wp:docPr" + ) + graphic: CT_GraphicalObject = OneAndOnlyOne( # pyright: ignore[reportAssignmentType] + "a:graphic" + ) @classmethod def new(cls, cx: Length, cy: Length, shape_id: int, pic: CT_Picture) -> CT_Inline: """Return a new ```` element populated with the values passed as parameters.""" - inline = parse_xml(cls._inline_xml()) + inline = cast(CT_Inline, parse_xml(cls._inline_xml())) inline.extent.cx = cx inline.extent.cy = cy inline.docPr.id = shape_id inline.docPr.name = "Picture %d" % shape_id - inline.graphic.graphicData.uri = ( - "http://schemas.openxmlformats.org/drawingml/2006/picture" - ) + inline.graphic.graphicData.uri = "http://schemas.openxmlformats.org/drawingml/2006/picture" inline.graphic.graphicData._insert_pic(pic) return inline @@ -126,9 +136,13 @@ class CT_NonVisualPictureProperties(BaseOxmlElement): class CT_Picture(BaseOxmlElement): """```` element, a DrawingML picture.""" - nvPicPr = OneAndOnlyOne("pic:nvPicPr") - blipFill = OneAndOnlyOne("pic:blipFill") - spPr = OneAndOnlyOne("pic:spPr") + nvPicPr: CT_PictureNonVisual = OneAndOnlyOne( # pyright: 
ignore[reportAssignmentType] + "pic:nvPicPr" + ) + blipFill: CT_BlipFillProperties = OneAndOnlyOne( # pyright: ignore[reportAssignmentType] + "pic:blipFill" + ) + spPr: CT_ShapeProperties = OneAndOnlyOne("pic:spPr") # pyright: ignore[reportAssignmentType] @classmethod def new(cls, pic_id, filename, rId, cx, cy): @@ -190,8 +204,12 @@ class CT_PositiveSize2D(BaseOxmlElement): Specifies the size of a DrawingML drawing. """ - cx = RequiredAttribute("cx", ST_PositiveCoordinate) - cy = RequiredAttribute("cy", ST_PositiveCoordinate) + cx: Length = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "cx", ST_PositiveCoordinate + ) + cy: Length = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "cy", ST_PositiveCoordinate + ) class CT_PresetGeometry2D(BaseOxmlElement): diff --git a/src/docx/oxml/shared.py b/src/docx/oxml/shared.py index 1774560ac..8c2ebc9a9 100644 --- a/src/docx/oxml/shared.py +++ b/src/docx/oxml/shared.py @@ -15,10 +15,10 @@ class CT_DecimalNumber(BaseOxmlElement): containing a text representation of a decimal number (e.g. 42) in its ``val`` attribute.""" - val = RequiredAttribute("w:val", ST_DecimalNumber) + val: int = RequiredAttribute("w:val", ST_DecimalNumber) # pyright: ignore[reportAssignmentType] @classmethod - def new(cls, nsptagname, val): + def new(cls, nsptagname: str, val: int): """Return a new ``CT_DecimalNumber`` element having tagname `nsptagname` and ``val`` attribute set to `val`.""" return OxmlElement(nsptagname, attrs={qn("w:val"): str(val)}) @@ -31,7 +31,7 @@ class CT_OnOff(BaseOxmlElement): "off". Defaults to `True`, so `` for example means "bold is turned on". """ - val: bool = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + val: bool = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:val", ST_OnOff, default=True ) @@ -42,9 +42,7 @@ class CT_String(BaseOxmlElement): In those cases, it containing a style name in its `val` attribute. 
""" - val: str = RequiredAttribute( # pyright: ignore[reportGeneralTypeIssues] - "w:val", ST_String - ) + val: str = RequiredAttribute("w:val", ST_String) # pyright: ignore[reportAssignmentType] @classmethod def new(cls, nsptagname: str, val: str): diff --git a/src/docx/oxml/simpletypes.py b/src/docx/oxml/simpletypes.py index debb5dc3c..dd10ab910 100644 --- a/src/docx/oxml/simpletypes.py +++ b/src/docx/oxml/simpletypes.py @@ -36,12 +36,10 @@ def convert_from_xml(cls, str_value: str) -> Any: return int(str_value) @classmethod - def convert_to_xml(cls, value: Any) -> str: - ... + def convert_to_xml(cls, value: Any) -> str: ... @classmethod - def validate(cls, value: Any) -> None: - ... + def validate(cls, value: Any) -> None: ... @classmethod def validate_int(cls, value: object): @@ -49,9 +47,7 @@ def validate_int(cls, value: object): raise TypeError("value must be , got %s" % type(value)) @classmethod - def validate_int_in_range( - cls, value: int, min_inclusive: int, max_inclusive: int - ) -> None: + def validate_int_in_range(cls, value: int, min_inclusive: int, max_inclusive: int) -> None: cls.validate_int(value) if value < min_inclusive or value > max_inclusive: raise ValueError( @@ -129,8 +125,7 @@ def convert_to_xml(cls, value: bool) -> str: def validate(cls, value: Any) -> None: if value not in (True, False): raise TypeError( - "only True or False (and possibly None) may be assigned, got" - " '%s'" % value + "only True or False (and possibly None) may be assigned, got" " '%s'" % value ) @@ -248,8 +243,7 @@ def validate(cls, value: Any) -> None: # must be an RGBColor object --- if not isinstance(value, RGBColor): raise ValueError( - "rgb color value must be RGBColor object, got %s %s" - % (type(value), value) + "rgb color value must be RGBColor object, got %s %s" % (type(value), value) ) @@ -316,7 +310,7 @@ class ST_SignedTwipsMeasure(XsdInt): def convert_from_xml(cls, str_value: str) -> Length: if "i" in str_value or "m" in str_value or "p" in str_value: 
return ST_UniversalMeasure.convert_from_xml(str_value) - return Twips(int(str_value)) + return Twips(int(round(float(str_value)))) @classmethod def convert_to_xml(cls, value: int | Length) -> str: diff --git a/src/docx/oxml/styles.py b/src/docx/oxml/styles.py index e0a3eaeaf..fb0e5d0dd 100644 --- a/src/docx/oxml/styles.py +++ b/src/docx/oxml/styles.py @@ -128,12 +128,10 @@ class CT_Style(BaseOxmlElement): rPr = ZeroOrOne("w:rPr", successors=_tag_seq[18:]) del _tag_seq - type: WD_STYLE_TYPE | None = ( - OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] - "w:type", WD_STYLE_TYPE - ) + type: WD_STYLE_TYPE | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:type", WD_STYLE_TYPE ) - styleId: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + styleId: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:styleId", ST_String ) default = OptionalAttribute("w:default", ST_OnOff) diff --git a/src/docx/oxml/table.py b/src/docx/oxml/table.py index 48a6d8c2f..e38d58562 100644 --- a/src/docx/oxml/table.py +++ b/src/docx/oxml/table.py @@ -2,12 +2,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, List +from typing import TYPE_CHECKING, Callable, cast -from docx.enum.table import WD_CELL_VERTICAL_ALIGNMENT, WD_ROW_HEIGHT_RULE +from docx.enum.table import WD_CELL_VERTICAL_ALIGNMENT, WD_ROW_HEIGHT_RULE, WD_TABLE_DIRECTION from docx.exceptions import InvalidSpanError from docx.oxml.ns import nsdecls, qn from docx.oxml.parser import parse_xml +from docx.oxml.shared import CT_DecimalNumber from docx.oxml.simpletypes import ( ST_Merge, ST_TblLayoutType, @@ -15,6 +16,7 @@ ST_TwipsMeasure, XsdInt, ) +from docx.oxml.text.paragraph import CT_P from docx.oxml.xmlchemy import ( BaseOxmlElement, OneAndOnlyOne, @@ -24,57 +26,93 @@ ZeroOrMore, ZeroOrOne, ) -from docx.shared import Emu, Twips +from docx.shared import Emu, Length, Twips if TYPE_CHECKING: - from 
docx.oxml.text.paragraph import CT_P - from docx.shared import Length + from docx.enum.table import WD_TABLE_ALIGNMENT + from docx.enum.text import WD_ALIGN_PARAGRAPH + from docx.oxml.shared import CT_OnOff, CT_String + from docx.oxml.text.parfmt import CT_Jc class CT_Height(BaseOxmlElement): - """Used for ```` to specify a row height and row height rule.""" + """Used for `w:trHeight` to specify a row height and row height rule.""" - val = OptionalAttribute("w:val", ST_TwipsMeasure) - hRule = OptionalAttribute("w:hRule", WD_ROW_HEIGHT_RULE) + val: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:val", ST_TwipsMeasure + ) + hRule: WD_ROW_HEIGHT_RULE | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:hRule", WD_ROW_HEIGHT_RULE + ) class CT_Row(BaseOxmlElement): """```` element.""" - tblPrEx = ZeroOrOne("w:tblPrEx") # custom inserter below - trPr = ZeroOrOne("w:trPr") # custom inserter below + add_tc: Callable[[], CT_Tc] + get_or_add_trPr: Callable[[], CT_TrPr] + _add_trPr: Callable[[], CT_TrPr] + + tc_lst: list[CT_Tc] + # -- custom inserter below -- + tblPrEx: CT_TblPrEx | None = ZeroOrOne("w:tblPrEx") # pyright: ignore[reportAssignmentType] + # -- custom inserter below -- + trPr: CT_TrPr | None = ZeroOrOne("w:trPr") # pyright: ignore[reportAssignmentType] tc = ZeroOrMore("w:tc") - def tc_at_grid_col(self, idx): - """The ```` element appearing at grid column `idx`. + @property + def grid_after(self) -> int: + """The number of unpopulated layout-grid cells at the end of this row.""" + trPr = self.trPr + if trPr is None: + return 0 + return trPr.grid_after + + @property + def grid_before(self) -> int: + """The number of unpopulated layout-grid cells at the start of this row.""" + trPr = self.trPr + if trPr is None: + return 0 + return trPr.grid_before + + def tc_at_grid_offset(self, grid_offset: int) -> CT_Tc: + """The `tc` element in this tr at exact `grid offset`. 
- Raises |ValueError| if no ``w:tc`` element begins at that grid column. + Raises ValueError when this `w:tr` contains no `w:tc` with exact starting `grid_offset`. """ - grid_col = 0 + # -- account for omitted cells at the start of the row -- + remaining_offset = grid_offset - self.grid_before + for tc in self.tc_lst: - if grid_col == idx: + # -- We've gone past grid_offset without finding a tc, no sense searching further. -- + if remaining_offset < 0: + break + # -- We've arrived at grid_offset, this is the `w:tc` we're looking for. -- + if remaining_offset == 0: return tc - grid_col += tc.grid_span - if grid_col > idx: - raise ValueError("no cell on grid column %d" % idx) - raise ValueError("index out of bounds") + # -- We're not there yet, skip forward the number of layout-grid cells this cell + # -- occupies. + remaining_offset -= tc.grid_span + + raise ValueError(f"no `tc` element at grid_offset={grid_offset}") @property - def tr_idx(self): - """The index of this ```` element within its parent ```` - element.""" - return self.getparent().tr_lst.index(self) + def tr_idx(self) -> int: + """Index of this `w:tr` element within its parent `w:tbl` element.""" + tbl = cast(CT_Tbl, self.getparent()) + return tbl.tr_lst.index(self) @property - def trHeight_hRule(self): - """Return the value of `w:trPr/w:trHeight@w:hRule`, or |None| if not present.""" + def trHeight_hRule(self) -> WD_ROW_HEIGHT_RULE | None: + """The value of `./w:trPr/w:trHeight/@w:hRule`, or |None| if not present.""" trPr = self.trPr if trPr is None: return None return trPr.trHeight_hRule @trHeight_hRule.setter - def trHeight_hRule(self, value): + def trHeight_hRule(self, value: WD_ROW_HEIGHT_RULE | None): trPr = self.get_or_add_trPr() trPr.trHeight_hRule = value @@ -87,14 +125,14 @@ def trHeight_val(self): return trPr.trHeight_val @trHeight_val.setter - def trHeight_val(self, value): + def trHeight_val(self, value: Length | None): trPr = self.get_or_add_trPr() trPr.trHeight_val = value - def 
_insert_tblPrEx(self, tblPrEx): + def _insert_tblPrEx(self, tblPrEx: CT_TblPrEx): self.insert(0, tblPrEx) - def _insert_trPr(self, trPr): + def _insert_trPr(self, trPr: CT_TrPr): tblPrEx = self.tblPrEx if tblPrEx is not None: tblPrEx.addnext(trPr) @@ -108,13 +146,16 @@ def _new_tc(self): class CT_Tbl(BaseOxmlElement): """```` element.""" - tblPr = OneAndOnlyOne("w:tblPr") - tblGrid = OneAndOnlyOne("w:tblGrid") + add_tr: Callable[[], CT_Row] + tr_lst: list[CT_Row] + + tblPr: CT_TblPr = OneAndOnlyOne("w:tblPr") # pyright: ignore[reportAssignmentType] + tblGrid: CT_TblGrid = OneAndOnlyOne("w:tblGrid") # pyright: ignore[reportAssignmentType] tr = ZeroOrMore("w:tr") @property - def bidiVisual_val(self): - """Value of `w:tblPr/w:bidiVisual/@w:val` or |None| if not present. + def bidiVisual_val(self) -> bool | None: + """Value of `./w:tblPr/w:bidiVisual/@w:val` or |None| if not present. Controls whether table cells are displayed right-to-left or left-to-right. """ @@ -124,12 +165,12 @@ def bidiVisual_val(self): return bidiVisual.val @bidiVisual_val.setter - def bidiVisual_val(self, value): + def bidiVisual_val(self, value: WD_TABLE_DIRECTION | None): tblPr = self.tblPr if value is None: - tblPr._remove_bidiVisual() + tblPr._remove_bidiVisual() # pyright: ignore[reportPrivateUsage] else: - tblPr.get_or_add_bidiVisual().val = value + tblPr.get_or_add_bidiVisual().val = bool(value) @property def col_count(self): @@ -153,111 +194,118 @@ def new_tbl(cls, rows: int, cols: int, width: Length) -> CT_Tbl: `width` is distributed evenly between the columns. 
""" - return parse_xml(cls._tbl_xml(rows, cols, width)) + return cast(CT_Tbl, parse_xml(cls._tbl_xml(rows, cols, width))) @property - def tblStyle_val(self): - """Value of `w:tblPr/w:tblStyle/@w:val` (a table style id) or |None| if not - present.""" + def tblStyle_val(self) -> str | None: + """`w:tblPr/w:tblStyle/@w:val` (a table style id) or |None| if not present.""" tblStyle = self.tblPr.tblStyle if tblStyle is None: return None return tblStyle.val @tblStyle_val.setter - def tblStyle_val(self, styleId): + def tblStyle_val(self, styleId: str | None) -> None: """Set the value of `w:tblPr/w:tblStyle/@w:val` (a table style id) to `styleId`. If `styleId` is None, remove the `w:tblStyle` element. """ tblPr = self.tblPr - tblPr._remove_tblStyle() + tblPr._remove_tblStyle() # pyright: ignore[reportPrivateUsage] if styleId is None: return - tblPr._add_tblStyle().val = styleId + tblPr._add_tblStyle().val = styleId # pyright: ignore[reportPrivateUsage] @classmethod def _tbl_xml(cls, rows: int, cols: int, width: Length) -> str: - col_width = Emu(width / cols) if cols > 0 else Emu(0) + col_width = Emu(width // cols) if cols > 0 else Emu(0) return ( - "\n" - " \n" - ' \n' - ' \n' - " \n" - "%s" # tblGrid - "%s" # trs - "\n" - ) % ( - nsdecls("w"), - cls._tblGrid_xml(cols, col_width), - cls._trs_xml(rows, cols, col_width), + f"\n" + f" \n" + f' \n' + f' \n' + f" \n" + f"{cls._tblGrid_xml(cols, col_width)}" + f"{cls._trs_xml(rows, cols, col_width)}" + f"\n" ) @classmethod - def _tblGrid_xml(cls, col_count, col_width): + def _tblGrid_xml(cls, col_count: int, col_width: Length) -> str: xml = " \n" - for i in range(col_count): + for _ in range(col_count): xml += ' \n' % col_width.twips xml += " \n" return xml @classmethod - def _trs_xml(cls, row_count, col_count, col_width): - xml = "" - for i in range(row_count): - xml += (" \n" "%s" " \n") % cls._tcs_xml( - col_count, col_width - ) - return xml + def _trs_xml(cls, row_count: int, col_count: int, col_width: Length) -> str: + 
return f" \n{cls._tcs_xml(col_count, col_width)} \n" * row_count @classmethod - def _tcs_xml(cls, col_count, col_width): - xml = "" - for i in range(col_count): - xml += ( - " \n" - " \n" - ' \n' - " \n" - " \n" - " \n" - ) % col_width.twips - return xml + def _tcs_xml(cls, col_count: int, col_width: Length) -> str: + return ( + f" \n" + f" \n" + f' \n' + f" \n" + f" \n" + f" \n" + ) * col_count class CT_TblGrid(BaseOxmlElement): - """```` element, child of ````, holds ```` elements - that define column count, width, etc.""" + """`w:tblGrid` element. + + Child of `w:tbl`, holds `w:gridCol> elements that define column count, width, etc. + """ + + add_gridCol: Callable[[], CT_TblGridCol] + gridCol_lst: list[CT_TblGridCol] gridCol = ZeroOrMore("w:gridCol", successors=("w:tblGridChange",)) class CT_TblGridCol(BaseOxmlElement): - """```` element, child of ````, defines a table column.""" + """`w:gridCol` element, child of `w:tblGrid`, defines a table column.""" - w = OptionalAttribute("w:w", ST_TwipsMeasure) + w: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:w", ST_TwipsMeasure + ) @property - def gridCol_idx(self): - """The index of this ```` element within its parent ```` - element.""" - return self.getparent().gridCol_lst.index(self) + def gridCol_idx(self) -> int: + """Index of this `w:gridCol` element within its parent `w:tblGrid` element.""" + tblGrid = cast(CT_TblGrid, self.getparent()) + return tblGrid.gridCol_lst.index(self) class CT_TblLayoutType(BaseOxmlElement): - """```` element, specifying whether column widths are fixed or can be - automatically adjusted based on content.""" + """`w:tblLayout` element. - type = OptionalAttribute("w:type", ST_TblLayoutType) + Specifies whether column widths are fixed or can be automatically adjusted based on + content. 
+ """ + + type: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:type", ST_TblLayoutType + ) class CT_TblPr(BaseOxmlElement): """```` element, child of ````, holds child elements that define table properties such as style and borders.""" + get_or_add_bidiVisual: Callable[[], CT_OnOff] + get_or_add_jc: Callable[[], CT_Jc] + get_or_add_tblLayout: Callable[[], CT_TblLayoutType] + _add_tblStyle: Callable[[], CT_String] + _remove_bidiVisual: Callable[[], None] + _remove_jc: Callable[[], None] + _remove_tblStyle: Callable[[], None] + _tag_seq = ( "w:tblStyle", "w:tblpPr", @@ -278,31 +326,35 @@ class CT_TblPr(BaseOxmlElement): "w:tblDescription", "w:tblPrChange", ) - tblStyle = ZeroOrOne("w:tblStyle", successors=_tag_seq[1:]) - bidiVisual = ZeroOrOne("w:bidiVisual", successors=_tag_seq[4:]) - jc = ZeroOrOne("w:jc", successors=_tag_seq[8:]) - tblLayout = ZeroOrOne("w:tblLayout", successors=_tag_seq[13:]) + tblStyle: CT_String | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:tblStyle", successors=_tag_seq[1:] + ) + bidiVisual: CT_OnOff | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:bidiVisual", successors=_tag_seq[4:] + ) + jc: CT_Jc | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:jc", successors=_tag_seq[8:] + ) + tblLayout: CT_TblLayoutType | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:tblLayout", successors=_tag_seq[13:] + ) del _tag_seq @property - def alignment(self): - """Member of :ref:`WdRowAlignment` enumeration or |None|, based on the contents - of the `w:val` attribute of `./w:jc`. - - |None| if no `w:jc` element is present. 
- """ + def alignment(self) -> WD_TABLE_ALIGNMENT | None: + """Horizontal alignment of table, |None| if `./w:jc` is not present.""" jc = self.jc if jc is None: return None - return jc.val + return cast("WD_TABLE_ALIGNMENT | None", jc.val) @alignment.setter - def alignment(self, value): + def alignment(self, value: WD_TABLE_ALIGNMENT | None): self._remove_jc() if value is None: return jc = self.get_or_add_jc() - jc.val = value + jc.val = cast("WD_ALIGN_PARAGRAPH", value) @property def autofit(self) -> bool: @@ -328,33 +380,40 @@ def style(self): return tblStyle.val @style.setter - def style(self, value): + def style(self, value: str | None): self._remove_tblStyle() if value is None: return - self._add_tblStyle(val=value) + self._add_tblStyle().val = value + + +class CT_TblPrEx(BaseOxmlElement): + """`w:tblPrEx` element, exceptions to table-properties. + + Applied at a lower level, like a `w:tr` to modify the appearance. Possibly used when + two tables are merged. For more see: + http://officeopenxml.com/WPtablePropertyExceptions.php + """ class CT_TblWidth(BaseOxmlElement): - """Used for ```` and ```` elements and many others, to specify a - table-related width.""" + """Used for `w:tblW` and `w:tcW` and others, specifies a table-related width.""" # the type for `w` attr is actually ST_MeasurementOrPercent, but using # XsdInt for now because only dxa (twips) values are being used. 
It's not # entirely clear what the semantics are for other values like -01.4mm - w = RequiredAttribute("w:w", XsdInt) + w: int = RequiredAttribute("w:w", XsdInt) # pyright: ignore[reportAssignmentType] type = RequiredAttribute("w:type", ST_TblWidth) @property - def width(self): - """Return the EMU length value represented by the combined ``w:w`` and - ``w:type`` attributes.""" + def width(self) -> Length | None: + """EMU length indicated by the combined `w:w` and `w:type` attrs.""" if self.type != "dxa": return None return Twips(self.w) @width.setter - def width(self, value): + def width(self, value: Length): self.type = "dxa" self.w = Emu(value).twips @@ -363,17 +422,19 @@ class CT_Tc(BaseOxmlElement): """`w:tc` table cell element.""" add_p: Callable[[], CT_P] - p_lst: List[CT_P] - tbl_lst: List[CT_Tbl] - + get_or_add_tcPr: Callable[[], CT_TcPr] + p_lst: list[CT_P] + tbl_lst: list[CT_Tbl] _insert_tbl: Callable[[CT_Tbl], CT_Tbl] + _new_p: Callable[[], CT_P] - tcPr = ZeroOrOne("w:tcPr") # bunches of successors, overriding insert + # -- tcPr has many successors, `._insert_tcPr()` is overridden below -- + tcPr: CT_TcPr | None = ZeroOrOne("w:tcPr") # pyright: ignore[reportAssignmentType] p = OneOrMore("w:p") tbl = OneOrMore("w:tbl") @property - def bottom(self): + def bottom(self) -> int: """The row index that marks the bottom extent of the vertical span of this cell. This is one greater than the index of the bottom-most row of the span, similar @@ -386,37 +447,44 @@ def bottom(self): return self._tr_idx + 1 def clear_content(self): - """Remove all content child elements, preserving the ```` element if - present. + """Remove all content elements, preserving `w:tcPr` element if present. - Note that this leaves the ```` element in an invalid state because it - doesn't contain at least one block-level element. It's up to the caller to add a - ````child element as the last content element. 
+ Note that this leaves the `w:tc` element in an invalid state because it doesn't + contain at least one block-level element. It's up to the caller to add a + `w:p`child element as the last content element. """ - new_children = [] - tcPr = self.tcPr - if tcPr is not None: - new_children.append(tcPr) - self[:] = new_children + # -- remove all cell inner-content except a `w:tcPr` when present. -- + for e in self.xpath("./*[not(self::w:tcPr)]"): + self.remove(e) @property - def grid_span(self): + def grid_offset(self) -> int: + """Starting offset of `tc` in the layout-grid columns of its table. + + A cell in the leftmost grid-column has offset 0. + """ + grid_before = self._tr.grid_before + preceding_tc_grid_spans = sum( + tc.grid_span for tc in self.xpath("./preceding-sibling::w:tc") + ) + return grid_before + preceding_tc_grid_spans + + @property + def grid_span(self) -> int: """The integer number of columns this cell spans. Determined by ./w:tcPr/w:gridSpan/@val, it defaults to 1. """ tcPr = self.tcPr - if tcPr is None: - return 1 - return tcPr.grid_span + return 1 if tcPr is None else tcPr.grid_span @grid_span.setter - def grid_span(self, value): + def grid_span(self, value: int): tcPr = self.get_or_add_tcPr() tcPr.grid_span = value @property - def inner_content_elements(self) -> List[CT_P | CT_Tbl]: + def inner_content_elements(self) -> list[CT_P | CT_Tbl]: """Generate all `w:p` and `w:tbl` elements in this document-body. Elements appear in document order. Elements shaded by nesting in a `w:ins` or @@ -433,138 +501,141 @@ def iter_block_items(self): yield child @property - def left(self): + def left(self) -> int: """The grid column index at which this ```` element appears.""" - return self._grid_col + return self.grid_offset + + def merge(self, other_tc: CT_Tc) -> CT_Tc: + """Return top-left `w:tc` element of a new span. 
- def merge(self, other_tc): - """Return the top-left ```` element of a new span formed by merging the - rectangular region defined by using this tc element and `other_tc` as diagonal - corners.""" + Span is formed by merging the rectangular region defined by using this tc + element and `other_tc` as diagonal corners. + """ top, left, height, width = self._span_dimensions(other_tc) - top_tc = self._tbl.tr_lst[top].tc_at_grid_col(left) + top_tc = self._tbl.tr_lst[top].tc_at_grid_offset(left) top_tc._grow_to(width, height) return top_tc @classmethod - def new(cls): - """Return a new ```` element, containing an empty paragraph as the - required EG_BlockLevelElt.""" - return parse_xml("\n" " \n" "" % nsdecls("w")) + def new(cls) -> CT_Tc: + """A new `w:tc` element, containing an empty paragraph as the required EG_BlockLevelElt.""" + return cast(CT_Tc, parse_xml("\n" " \n" "" % nsdecls("w"))) @property - def right(self): + def right(self) -> int: """The grid column index that marks the right-side extent of the horizontal span of this cell. This is one greater than the index of the right-most column of the span, similar to how a slice of the cell's columns would be specified. 
""" - return self._grid_col + self.grid_span + return self.grid_offset + self.grid_span @property - def top(self): + def top(self) -> int: """The top-most row index in the vertical span of this cell.""" if self.vMerge is None or self.vMerge == ST_Merge.RESTART: return self._tr_idx return self._tc_above.top @property - def vMerge(self): - """The value of the ./w:tcPr/w:vMerge/@val attribute, or |None| if the w:vMerge - element is not present.""" + def vMerge(self) -> str | None: + """Value of ./w:tcPr/w:vMerge/@val, |None| if w:vMerge is not present.""" tcPr = self.tcPr if tcPr is None: return None return tcPr.vMerge_val @vMerge.setter - def vMerge(self, value): + def vMerge(self, value: str | None): tcPr = self.get_or_add_tcPr() tcPr.vMerge_val = value @property - def width(self): - """Return the EMU length value represented in the ``./w:tcPr/w:tcW`` child - element or |None| if not present.""" + def width(self) -> Length | None: + """EMU length represented in `./w:tcPr/w:tcW` or |None| if not present.""" tcPr = self.tcPr if tcPr is None: return None return tcPr.width @width.setter - def width(self, value): + def width(self, value: Length): tcPr = self.get_or_add_tcPr() tcPr.width = value - def _add_width_of(self, other_tc): + def _add_width_of(self, other_tc: CT_Tc): """Add the width of `other_tc` to this cell. Does nothing if either this tc or `other_tc` does not have a specified width. """ if self.width and other_tc.width: - self.width += other_tc.width + self.width = Length(self.width + other_tc.width) - @property - def _grid_col(self): - """The grid column at which this cell begins.""" - tr = self._tr - idx = tr.tc_lst.index(self) - preceding_tcs = tr.tc_lst[:idx] - return sum(tc.grid_span for tc in preceding_tcs) + def _grow_to(self, width: int, height: int, top_tc: CT_Tc | None = None): + """Grow this cell to `width` grid columns and `height` rows. 
- def _grow_to(self, width, height, top_tc=None): - """Grow this cell to `width` grid columns and `height` rows by expanding - horizontal spans and creating continuation cells to form vertical spans.""" + This is accomplished by expanding horizontal spans and creating continuation + cells to form vertical spans. + """ - def vMerge_val(top_tc): - if top_tc is not self: - return ST_Merge.CONTINUE - if height == 1: - return None - return ST_Merge.RESTART + def vMerge_val(top_tc: CT_Tc): + return ( + ST_Merge.CONTINUE + if top_tc is not self + else None if height == 1 else ST_Merge.RESTART + ) top_tc = self if top_tc is None else top_tc self._span_to_width(width, top_tc, vMerge_val(top_tc)) if height > 1: - self._tc_below._grow_to(width, height - 1, top_tc) + tc_below = self._tc_below + assert tc_below is not None + tc_below._grow_to(width, height - 1, top_tc) - def _insert_tcPr(self, tcPr): - """``tcPr`` has a bunch of successors, but it comes first if it appears, so just - overriding and using insert(0, ...) rather than spelling out successors.""" + def _insert_tcPr(self, tcPr: CT_TcPr) -> CT_TcPr: + """Override default `._insert_tcPr()`.""" + # -- `tcPr`` has a large number of successors, but always comes first if it appears, + # -- so just using insert(0, ...) rather than spelling out successors. 
self.insert(0, tcPr) return tcPr @property - def _is_empty(self): - """True if this cell contains only a single empty ```` element.""" + def _is_empty(self) -> bool: + """True if this cell contains only a single empty `w:p` element.""" block_items = list(self.iter_block_items()) if len(block_items) > 1: return False - p = block_items[0] # cell must include at least one element - if len(p.r_lst) == 0: + # -- cell must include at least one block item but can be a `w:tbl`, `w:sdt`, + # -- `w:customXml` or a `w:p` + only_item = block_items[0] + if isinstance(only_item, CT_P) and len(only_item.r_lst) == 0: return True return False - def _move_content_to(self, other_tc): - """Append the content of this cell to `other_tc`, leaving this cell with a - single empty ```` element.""" + def _move_content_to(self, other_tc: CT_Tc): + """Append the content of this cell to `other_tc`. + + Leaves this cell with a single empty ```` element. + """ if other_tc is self: return if self._is_empty: return other_tc._remove_trailing_empty_p() - # appending moves each element from self to other_tc + # -- appending moves each element from self to other_tc -- for block_element in self.iter_block_items(): other_tc.append(block_element) - # add back the required minimum single empty element + # -- add back the required minimum single empty element -- self.append(self._new_p()) - def _new_tbl(self): - return CT_Tbl.new() + def _new_tbl(self) -> None: + raise NotImplementedError( + "use CT_Tbl.new_tbl() to add a new table, specifying rows and columns" + ) @property - def _next_tc(self): + def _next_tc(self) -> CT_Tc | None: """The `w:tc` element immediately following this one in this row, or |None| if this is the last `w:tc` element in the row.""" following_tcs = self.xpath("./following-sibling::w:tc") @@ -572,32 +643,33 @@ def _next_tc(self): def _remove(self): """Remove this `w:tc` element from the XML tree.""" - self.getparent().remove(self) + parent_element = self.getparent() + assert 
parent_element is not None + parent_element.remove(self) def _remove_trailing_empty_p(self): - """Remove the last content element from this cell if it is an empty ```` - element.""" + """Remove last content element from this cell if it's an empty `w:p` element.""" block_items = list(self.iter_block_items()) last_content_elm = block_items[-1] - if last_content_elm.tag != qn("w:p"): + if not isinstance(last_content_elm, CT_P): return p = last_content_elm if len(p.r_lst) > 0: return self.remove(p) - def _span_dimensions(self, other_tc): + def _span_dimensions(self, other_tc: CT_Tc) -> tuple[int, int, int, int]: """Return a (top, left, height, width) 4-tuple specifying the extents of the merged cell formed by using this tc and `other_tc` as opposite corner extents.""" - def raise_on_inverted_L(a, b): + def raise_on_inverted_L(a: CT_Tc, b: CT_Tc): if a.top == b.top and a.bottom != b.bottom: raise InvalidSpanError("requested span not rectangular") if a.left == b.left and a.right != b.right: raise InvalidSpanError("requested span not rectangular") - def raise_on_tee_shaped(a, b): + def raise_on_tee_shaped(a: CT_Tc, b: CT_Tc): top_most, other = (a, b) if a.top < b.top else (b, a) if top_most.top < other.top and top_most.bottom > other.bottom: raise InvalidSpanError("requested span not rectangular") @@ -616,9 +688,10 @@ def raise_on_tee_shaped(a, b): return top, left, bottom - top, right - left - def _span_to_width(self, grid_width, top_tc, vMerge): - """Incorporate and then remove `w:tc` elements to the right of this one until - this cell spans `grid_width`. + def _span_to_width(self, grid_width: int, top_tc: CT_Tc, vMerge: str | None): + """Incorporate `w:tc` elements to the right until this cell spans `grid_width`. + + Incorporated `w:tc` elements are removed (replaced by gridSpan value). 
Raises |ValueError| if `grid_width` cannot be exactly achieved, such as when a merged cell would drive the span width greater than `grid_width` or if not @@ -632,7 +705,7 @@ def _span_to_width(self, grid_width, top_tc, vMerge): self._swallow_next_tc(grid_width, top_tc) self.vMerge = vMerge - def _swallow_next_tc(self, grid_width, top_tc): + def _swallow_next_tc(self, grid_width: int, top_tc: CT_Tc): """Extend the horizontal span of this `w:tc` element to incorporate the following `w:tc` element in the row and then delete that following `w:tc` element. @@ -643,7 +716,7 @@ def _swallow_next_tc(self, grid_width, top_tc): than `grid_width` or if there is no next `` element in the row. """ - def raise_on_invalid_swallow(next_tc): + def raise_on_invalid_swallow(next_tc: CT_Tc | None): if next_tc is None: raise InvalidSpanError("not enough grid columns") if self.grid_span + next_tc.grid_span > grid_width: @@ -651,48 +724,48 @@ def raise_on_invalid_swallow(next_tc): next_tc = self._next_tc raise_on_invalid_swallow(next_tc) + assert next_tc is not None next_tc._move_content_to(top_tc) self._add_width_of(next_tc) self.grid_span += next_tc.grid_span next_tc._remove() @property - def _tbl(self): + def _tbl(self) -> CT_Tbl: """The tbl element this tc element appears in.""" - return self.xpath("./ancestor::w:tbl[position()=1]")[0] + return cast(CT_Tbl, self.xpath("./ancestor::w:tbl[position()=1]")[0]) @property - def _tc_above(self): + def _tc_above(self) -> CT_Tc: """The `w:tc` element immediately above this one in its grid column.""" - return self._tr_above.tc_at_grid_col(self._grid_col) + return self._tr_above.tc_at_grid_offset(self.grid_offset) @property - def _tc_below(self): + def _tc_below(self) -> CT_Tc | None: """The tc element immediately below this one in its grid column.""" tr_below = self._tr_below if tr_below is None: return None - return tr_below.tc_at_grid_col(self._grid_col) + return tr_below.tc_at_grid_offset(self.grid_offset) @property - def _tr(self): + def 
_tr(self) -> CT_Row: """The tr element this tc element appears in.""" - return self.xpath("./ancestor::w:tr[position()=1]")[0] + return cast(CT_Row, self.xpath("./ancestor::w:tr[position()=1]")[0]) @property - def _tr_above(self): + def _tr_above(self) -> CT_Row: """The tr element prior in sequence to the tr this cell appears in. Raises |ValueError| if called on a cell in the top-most row. """ - tr_lst = self._tbl.tr_lst - tr_idx = tr_lst.index(self._tr) - if tr_idx == 0: - raise ValueError("no tr above topmost tr") - return tr_lst[tr_idx - 1] + tr_aboves = self.xpath("./ancestor::w:tr[position()=1]/preceding-sibling::w:tr[1]") + if not tr_aboves: + raise ValueError("no tr above topmost tr in w:tbl") + return tr_aboves[0] @property - def _tr_below(self): + def _tr_below(self) -> CT_Row | None: """The tr element next in sequence after the tr this cell appears in, or |None| if this cell appears in the last row.""" tr_lst = self._tbl.tr_lst @@ -703,7 +776,7 @@ def _tr_below(self): return None @property - def _tr_idx(self): + def _tr_idx(self) -> int: """The row index of the tr element this tc element appears in.""" return self._tbl.tr_lst.index(self._tr) @@ -711,6 +784,14 @@ def _tr_idx(self): class CT_TcPr(BaseOxmlElement): """```` element, defining table cell properties.""" + get_or_add_gridSpan: Callable[[], CT_DecimalNumber] + get_or_add_tcW: Callable[[], CT_TblWidth] + get_or_add_vAlign: Callable[[], CT_VerticalJc] + _add_vMerge: Callable[[], CT_VMerge] + _remove_gridSpan: Callable[[], None] + _remove_vAlign: Callable[[], None] + _remove_vMerge: Callable[[], None] + _tag_seq = ( "w:cnfStyle", "w:tcW", @@ -731,25 +812,31 @@ class CT_TcPr(BaseOxmlElement): "w:cellMerge", "w:tcPrChange", ) - tcW = ZeroOrOne("w:tcW", successors=_tag_seq[2:]) - gridSpan = ZeroOrOne("w:gridSpan", successors=_tag_seq[3:]) - vMerge = ZeroOrOne("w:vMerge", successors=_tag_seq[5:]) - vAlign = ZeroOrOne("w:vAlign", successors=_tag_seq[12:]) + tcW: CT_TblWidth | None = ZeroOrOne( # pyright: 
ignore[reportAssignmentType] + "w:tcW", successors=_tag_seq[2:] + ) + gridSpan: CT_DecimalNumber | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:gridSpan", successors=_tag_seq[3:] + ) + vMerge: CT_VMerge | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:vMerge", successors=_tag_seq[5:] + ) + vAlign: CT_VerticalJc | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:vAlign", successors=_tag_seq[12:] + ) del _tag_seq @property - def grid_span(self): + def grid_span(self) -> int: """The integer number of columns this cell spans. Determined by ./w:gridSpan/@val, it defaults to 1. """ gridSpan = self.gridSpan - if gridSpan is None: - return 1 - return gridSpan.val + return 1 if gridSpan is None else gridSpan.val @grid_span.setter - def grid_span(self, value): + def grid_span(self, value: int): self._remove_gridSpan() if value > 1: self.get_or_add_gridSpan().val = value @@ -767,7 +854,7 @@ def vAlign_val(self): return vAlign.val @vAlign_val.setter - def vAlign_val(self, value): + def vAlign_val(self, value: WD_CELL_VERTICAL_ALIGNMENT | None): if value is None: self._remove_vAlign() return @@ -783,22 +870,21 @@ def vMerge_val(self): return vMerge.val @vMerge_val.setter - def vMerge_val(self, value): + def vMerge_val(self, value: str | None): self._remove_vMerge() if value is not None: self._add_vMerge().val = value @property - def width(self): - """Return the EMU length value represented in the ```` child element or - |None| if not present or its type is not 'dxa'.""" + def width(self) -> Length | None: + """EMU length in `./w:tcW` or |None| if not present or its type is not 'dxa'.""" tcW = self.tcW if tcW is None: return None return tcW.width @width.setter - def width(self, value): + def width(self, value: Length): tcW = self.get_or_add_tcW() tcW.width = value @@ -806,6 +892,8 @@ def width(self, value): class CT_TrPr(BaseOxmlElement): """```` element, defining table row properties.""" + get_or_add_trHeight: Callable[[], 
CT_Height] + _tag_seq = ( "w:cnfStyle", "w:divId", @@ -823,19 +911,37 @@ class CT_TrPr(BaseOxmlElement): "w:del", "w:trPrChange", ) - trHeight = ZeroOrOne("w:trHeight", successors=_tag_seq[8:]) + gridAfter: CT_DecimalNumber | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:gridAfter", successors=_tag_seq[4:] + ) + gridBefore: CT_DecimalNumber | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:gridBefore", successors=_tag_seq[3:] + ) + trHeight: CT_Height | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:trHeight", successors=_tag_seq[8:] + ) del _tag_seq @property - def trHeight_hRule(self): + def grid_after(self) -> int: + """The number of unpopulated layout-grid cells at the end of this row.""" + gridAfter = self.gridAfter + return 0 if gridAfter is None else gridAfter.val + + @property + def grid_before(self) -> int: + """The number of unpopulated layout-grid cells at the start of this row.""" + gridBefore = self.gridBefore + return 0 if gridBefore is None else gridBefore.val + + @property + def trHeight_hRule(self) -> WD_ROW_HEIGHT_RULE | None: """Return the value of `w:trHeight@w:hRule`, or |None| if not present.""" trHeight = self.trHeight - if trHeight is None: - return None - return trHeight.hRule + return None if trHeight is None else trHeight.hRule @trHeight_hRule.setter - def trHeight_hRule(self, value): + def trHeight_hRule(self, value: WD_ROW_HEIGHT_RULE | None): if value is None and self.trHeight is None: return trHeight = self.get_or_add_trHeight() @@ -845,12 +951,10 @@ def trHeight_hRule(self, value): def trHeight_val(self): """Return the value of `w:trHeight@w:val`, or |None| if not present.""" trHeight = self.trHeight - if trHeight is None: - return None - return trHeight.val + return None if trHeight is None else trHeight.val @trHeight_val.setter - def trHeight_val(self, value): + def trHeight_val(self, value: Length | None): if value is None and self.trHeight is None: return trHeight = 
self.get_or_add_trHeight() @@ -860,10 +964,14 @@ def trHeight_val(self, value): class CT_VerticalJc(BaseOxmlElement): """`w:vAlign` element, specifying vertical alignment of cell.""" - val = RequiredAttribute("w:val", WD_CELL_VERTICAL_ALIGNMENT) + val: WD_CELL_VERTICAL_ALIGNMENT = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "w:val", WD_CELL_VERTICAL_ALIGNMENT + ) class CT_VMerge(BaseOxmlElement): """```` element, specifying vertical merging behavior of a cell.""" - val = OptionalAttribute("w:val", ST_Merge, default=ST_Merge.CONTINUE) + val: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:val", ST_Merge, default=ST_Merge.CONTINUE + ) diff --git a/src/docx/oxml/text/font.py b/src/docx/oxml/text/font.py index 0e183cf65..140086aab 100644 --- a/src/docx/oxml/text/font.py +++ b/src/docx/oxml/text/font.py @@ -39,10 +39,10 @@ class CT_Fonts(BaseOxmlElement): Specifies typeface name for the various language types. """ - ascii: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + ascii: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:ascii", ST_String ) - hAnsi: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + hAnsi: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:hAnsi", ST_String ) @@ -148,18 +148,14 @@ class CT_RPr(BaseOxmlElement): sz: CT_HpsMeasure | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] "w:sz", successors=_tag_seq[24:] ) - highlight: CT_Highlight | None = ( - ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] - "w:highlight", successors=_tag_seq[26:] - ) + highlight: CT_Highlight | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + "w:highlight", successors=_tag_seq[26:] ) u: CT_Underline | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] "w:u", successors=_tag_seq[27:] ) - vertAlign: CT_VerticalAlignRun | None = ( - ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] 
- "w:vertAlign", successors=_tag_seq[32:] - ) + vertAlign: CT_VerticalAlignRun | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + "w:vertAlign", successors=_tag_seq[32:] ) rtl = ZeroOrOne("w:rtl", successors=_tag_seq[33:]) cs = ZeroOrOne("w:cs", successors=_tag_seq[34:]) @@ -268,10 +264,7 @@ def subscript(self, value: bool | None) -> None: elif bool(value) is True: self.get_or_add_vertAlign().val = ST_VerticalAlignRun.SUBSCRIPT # -- assert bool(value) is False -- - elif ( - self.vertAlign is not None - and self.vertAlign.val == ST_VerticalAlignRun.SUBSCRIPT - ): + elif self.vertAlign is not None and self.vertAlign.val == ST_VerticalAlignRun.SUBSCRIPT: self._remove_vertAlign() @property @@ -295,10 +288,7 @@ def superscript(self, value: bool | None): elif bool(value) is True: self.get_or_add_vertAlign().val = ST_VerticalAlignRun.SUPERSCRIPT # -- assert bool(value) is False -- - elif ( - self.vertAlign is not None - and self.vertAlign.val == ST_VerticalAlignRun.SUPERSCRIPT - ): + elif self.vertAlign is not None and self.vertAlign.val == ST_VerticalAlignRun.SUPERSCRIPT: self._remove_vertAlign() @property @@ -353,10 +343,8 @@ def _set_bool_val(self, name: str, value: bool | None): class CT_Underline(BaseOxmlElement): """`` element, specifying the underlining style for a run.""" - val: WD_UNDERLINE | None = ( - OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] - "w:val", WD_UNDERLINE - ) + val: WD_UNDERLINE | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:val", WD_UNDERLINE ) diff --git a/src/docx/oxml/text/hyperlink.py b/src/docx/oxml/text/hyperlink.py index 77d409f6a..38a33ff15 100644 --- a/src/docx/oxml/text/hyperlink.py +++ b/src/docx/oxml/text/hyperlink.py @@ -21,13 +21,13 @@ class CT_Hyperlink(BaseOxmlElement): r_lst: List[CT_R] - rId: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] - "r:id", XsdString - ) - anchor: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] 
+ rId: str | None = OptionalAttribute("r:id", XsdString) # pyright: ignore[reportAssignmentType] + anchor: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:anchor", ST_String ) - history = OptionalAttribute("w:history", ST_OnOff, default=True) + history: bool = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:history", ST_OnOff, default=True + ) r = ZeroOrMore("w:r") @@ -36,8 +36,8 @@ def lastRenderedPageBreaks(self) -> List[CT_LastRenderedPageBreak]: """All `w:lastRenderedPageBreak` descendants of this hyperlink.""" return self.xpath("./w:r/w:lastRenderedPageBreak") - @property # pyright: ignore[reportIncompatibleVariableOverride] - def text(self) -> str: + @property + def text(self) -> str: # pyright: ignore[reportIncompatibleMethodOverride] """The textual content of this hyperlink. `CT_Hyperlink` stores the hyperlink-text as one or more `w:r` children. diff --git a/src/docx/oxml/text/paragraph.py b/src/docx/oxml/text/paragraph.py index f771dd74f..63e96f312 100644 --- a/src/docx/oxml/text/paragraph.py +++ b/src/docx/oxml/text/paragraph.py @@ -26,7 +26,7 @@ class CT_P(BaseOxmlElement): hyperlink_lst: List[CT_Hyperlink] r_lst: List[CT_R] - pPr: CT_PPr | None = ZeroOrOne("w:pPr") # pyright: ignore[reportGeneralTypeIssues] + pPr: CT_PPr | None = ZeroOrOne("w:pPr") # pyright: ignore[reportAssignmentType] hyperlink = ZeroOrMore("w:hyperlink") r = ZeroOrMore("w:r") @@ -92,8 +92,8 @@ def style(self, style: str | None): pPr = self.get_or_add_pPr() pPr.style = style - @property # pyright: ignore[reportIncompatibleVariableOverride] - def text(self): + @property + def text(self): # pyright: ignore[reportIncompatibleMethodOverride] """The textual content of this paragraph. 
Inner-content child elements like `w:r` and `w:hyperlink` are translated to diff --git a/src/docx/oxml/text/parfmt.py b/src/docx/oxml/text/parfmt.py index 49ea01003..de5609636 100644 --- a/src/docx/oxml/text/parfmt.py +++ b/src/docx/oxml/text/parfmt.py @@ -28,21 +28,32 @@ class CT_Ind(BaseOxmlElement): """```` element, specifying paragraph indentation.""" - left = OptionalAttribute("w:left", ST_SignedTwipsMeasure) - right = OptionalAttribute("w:right", ST_SignedTwipsMeasure) - firstLine = OptionalAttribute("w:firstLine", ST_TwipsMeasure) - hanging = OptionalAttribute("w:hanging", ST_TwipsMeasure) + left: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:left", ST_SignedTwipsMeasure + ) + right: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:right", ST_SignedTwipsMeasure + ) + firstLine: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:firstLine", ST_TwipsMeasure + ) + hanging: Length | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:hanging", ST_TwipsMeasure + ) class CT_Jc(BaseOxmlElement): """```` element, specifying paragraph justification.""" - val = RequiredAttribute("w:val", WD_ALIGN_PARAGRAPH) + val: WD_ALIGN_PARAGRAPH = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "w:val", WD_ALIGN_PARAGRAPH + ) class CT_PPr(BaseOxmlElement): """```` element, containing the properties for a paragraph.""" + get_or_add_ind: Callable[[], CT_Ind] get_or_add_pStyle: Callable[[], CT_String] _insert_sectPr: Callable[[CT_SectPr], None] _remove_pStyle: Callable[[], None] @@ -86,7 +97,7 @@ class CT_PPr(BaseOxmlElement): "w:sectPr", "w:pPrChange", ) - pStyle: CT_String | None = ZeroOrOne( # pyright: ignore[reportGeneralTypeIssues] + pStyle: CT_String | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] "w:pStyle", successors=_tag_seq[1:] ) keepNext = ZeroOrOne("w:keepNext", successors=_tag_seq[2:]) @@ -96,13 +107,15 @@ class 
CT_PPr(BaseOxmlElement): numPr = ZeroOrOne("w:numPr", successors=_tag_seq[7:]) tabs = ZeroOrOne("w:tabs", successors=_tag_seq[11:]) spacing = ZeroOrOne("w:spacing", successors=_tag_seq[22:]) - ind = ZeroOrOne("w:ind", successors=_tag_seq[23:]) + ind: CT_Ind | None = ZeroOrOne( # pyright: ignore[reportAssignmentType] + "w:ind", successors=_tag_seq[23:] + ) jc = ZeroOrOne("w:jc", successors=_tag_seq[27:]) sectPr = ZeroOrOne("w:sectPr", successors=_tag_seq[35:]) del _tag_seq @property - def first_line_indent(self): + def first_line_indent(self) -> Length | None: """A |Length| value calculated from the values of `w:ind/@w:firstLine` and `w:ind/@w:hanging`. @@ -120,7 +133,7 @@ def first_line_indent(self): return firstLine @first_line_indent.setter - def first_line_indent(self, value): + def first_line_indent(self, value: Length | None): if self.ind is None and value is None: return ind = self.get_or_add_ind() @@ -133,7 +146,7 @@ def first_line_indent(self, value): ind.firstLine = value @property - def ind_left(self): + def ind_left(self) -> Length | None: """The value of `w:ind/@w:left` or |None| if not present.""" ind = self.ind if ind is None: @@ -141,14 +154,14 @@ def ind_left(self): return ind.left @ind_left.setter - def ind_left(self, value): + def ind_left(self, value: Length | None): if value is None and self.ind is None: return ind = self.get_or_add_ind() ind.left = value @property - def ind_right(self): + def ind_right(self) -> Length | None: """The value of `w:ind/@w:right` or |None| if not present.""" ind = self.ind if ind is None: @@ -156,7 +169,7 @@ def ind_right(self): return ind.right @ind_right.setter - def ind_right(self, value): + def ind_right(self, value: Length | None): if value is None and self.ind is None: return ind = self.get_or_add_ind() @@ -338,9 +351,15 @@ class CT_TabStop(BaseOxmlElement): only needs a __str__ method. 
""" - val = RequiredAttribute("w:val", WD_TAB_ALIGNMENT) - leader = OptionalAttribute("w:leader", WD_TAB_LEADER, default=WD_TAB_LEADER.SPACES) - pos = RequiredAttribute("w:pos", ST_SignedTwipsMeasure) + val: WD_TAB_ALIGNMENT = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "w:val", WD_TAB_ALIGNMENT + ) + leader: WD_TAB_LEADER | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] + "w:leader", WD_TAB_LEADER, default=WD_TAB_LEADER.SPACES + ) + pos: Length = RequiredAttribute( # pyright: ignore[reportAssignmentType] + "w:pos", ST_SignedTwipsMeasure + ) def __str__(self) -> str: """Text equivalent of a `w:tab` element appearing in a run. diff --git a/src/docx/oxml/text/run.py b/src/docx/oxml/text/run.py index f17d33845..88efae83c 100644 --- a/src/docx/oxml/text/run.py +++ b/src/docx/oxml/text/run.py @@ -29,7 +29,7 @@ class CT_R(BaseOxmlElement): _add_drawing: Callable[[], CT_Drawing] _add_t: Callable[..., CT_Text] - rPr: CT_RPr | None = ZeroOrOne("w:rPr") # pyright: ignore[reportGeneralTypeIssues] + rPr: CT_RPr | None = ZeroOrOne("w:rPr") # pyright: ignore[reportAssignmentType] br = ZeroOrMore("w:br") cr = ZeroOrMore("w:cr") drawing = ZeroOrMore("w:drawing") @@ -120,12 +120,11 @@ def text(self) -> str: equivalent. 
""" return "".join( - str(e) - for e in self.xpath("w:br | w:cr | w:noBreakHyphen | w:ptab | w:t | w:tab") + str(e) for e in self.xpath("w:br | w:cr | w:noBreakHyphen | w:ptab | w:t | w:tab") ) - @text.setter # pyright: ignore[reportIncompatibleVariableOverride] - def text(self, text: str): + @text.setter + def text(self, text: str): # pyright: ignore[reportIncompatibleMethodOverride] self.clear_content() _RunContentAppender.append_to_run_from_text(self, text) @@ -141,12 +140,10 @@ def _insert_rPr(self, rPr: CT_RPr) -> CT_RPr: class CT_Br(BaseOxmlElement): """`` element, indicating a line, page, or column break in a run.""" - type: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] + type: str | None = OptionalAttribute( # pyright: ignore[reportAssignmentType] "w:type", ST_BrType, default="textWrapping" ) - clear: str | None = OptionalAttribute( # pyright: ignore[reportGeneralTypeIssues] - "w:clear", ST_BrClear - ) + clear: str | None = OptionalAttribute("w:clear", ST_BrClear) # pyright: ignore def __str__(self) -> str: """Text equivalent of this element. Actual value depends on break type. diff --git a/src/docx/oxml/xmlchemy.py b/src/docx/oxml/xmlchemy.py index d075f88f1..077bcd583 100644 --- a/src/docx/oxml/xmlchemy.py +++ b/src/docx/oxml/xmlchemy.py @@ -126,16 +126,12 @@ class BaseAttribute: Provides common methods. 
""" - def __init__( - self, attr_name: str, simple_type: Type[BaseXmlEnum] | Type[BaseSimpleType] - ): + def __init__(self, attr_name: str, simple_type: Type[BaseXmlEnum] | Type[BaseSimpleType]): super(BaseAttribute, self).__init__() self._attr_name = attr_name self._simple_type = simple_type - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" self._element_cls = element_cls self._prop_name = prop_name @@ -159,14 +155,12 @@ def _clark_name(self): return self._attr_name @property - def _getter(self) -> Callable[[BaseOxmlElement], Any | None]: - ... + def _getter(self) -> Callable[[BaseOxmlElement], Any | None]: ... @property def _setter( self, - ) -> Callable[[BaseOxmlElement, Any | None], None]: - ... + ) -> Callable[[BaseOxmlElement, Any | None], None]: ... class OptionalAttribute(BaseAttribute): @@ -181,7 +175,7 @@ def __init__( self, attr_name: str, simple_type: Type[BaseXmlEnum] | Type[BaseSimpleType], - default: BaseXmlEnum | BaseSimpleType | None = None, + default: BaseXmlEnum | BaseSimpleType | str | bool | None = None, ): super(OptionalAttribute, self).__init__(attr_name, simple_type) self._default = default @@ -259,8 +253,7 @@ def get_attr_value(obj: BaseOxmlElement) -> Any | None: attr_str_value = obj.get(self._clark_name) if attr_str_value is None: raise InvalidXmlError( - "required '%s' attribute not present on element %s" - % (self._attr_name, obj.tag) + "required '%s' attribute not present on element %s" % (self._attr_name, obj.tag) ) return self._simple_type.from_xml(attr_str_value) @@ -292,9 +285,7 @@ def __init__(self, nsptagname: str, successors: Tuple[str, ...] 
= ()): self._nsptagname = nsptagname self._successors = successors - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Baseline behavior for adding the appropriate methods to `element_cls`.""" self._element_cls = element_cls self._prop_name = prop_name @@ -508,9 +499,7 @@ class OneAndOnlyOne(_BaseChildElement): def __init__(self, nsptagname: str): super(OneAndOnlyOne, self).__init__(nsptagname, ()) - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(OneAndOnlyOne, self).populate_class_members(element_cls, prop_name) self._add_getter() @@ -528,9 +517,7 @@ def get_child_element(obj: BaseOxmlElement): ) return child - get_child_element.__doc__ = ( - "Required ``<%s>`` child element." % self._nsptagname - ) + get_child_element.__doc__ = "Required ``<%s>`` child element." 
% self._nsptagname return get_child_element @@ -538,9 +525,7 @@ class OneOrMore(_BaseChildElement): """Defines a repeating child element for MetaOxmlElement that must appear at least once.""" - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(OneOrMore, self).populate_class_members(element_cls, prop_name) self._add_list_getter() @@ -554,9 +539,7 @@ def populate_class_members( class ZeroOrMore(_BaseChildElement): """Defines an optional repeating child element for MetaOxmlElement.""" - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(ZeroOrMore, self).populate_class_members(element_cls, prop_name) self._add_list_getter() @@ -570,9 +553,7 @@ def populate_class_members( class ZeroOrOne(_BaseChildElement): """Defines an optional child element for MetaOxmlElement.""" - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(ZeroOrOne, self).populate_class_members(element_cls, prop_name) self._add_getter() @@ -604,9 +585,7 @@ def _add_remover(self): def _remove_child(obj: BaseOxmlElement): obj.remove_all(self._nsptagname) - _remove_child.__doc__ = ( - "Remove all ``<%s>`` child elements." - ) % self._nsptagname + _remove_child.__doc__ = ("Remove all ``<%s>`` child elements.") % self._nsptagname self._add_to_class(self._remove_method_name, _remove_child) @lazyproperty @@ -622,16 +601,12 @@ def __init__(self, choices: Sequence[Choice], successors: Tuple[str, ...] 
= ()): self._choices = choices self._successors = successors - def populate_class_members( - self, element_cls: MetaOxmlElement, prop_name: str - ) -> None: + def populate_class_members(self, element_cls: MetaOxmlElement, prop_name: str) -> None: """Add the appropriate methods to `element_cls`.""" super(ZeroOrOneChoice, self).populate_class_members(element_cls, prop_name) self._add_choice_getter() for choice in self._choices: - choice.populate_class_members( - element_cls, self._prop_name, self._successors - ) + choice.populate_class_members(element_cls, self._prop_name, self._successors) self._add_group_remover() def _add_choice_getter(self): @@ -649,9 +624,7 @@ def _remove_choice_group(obj: BaseOxmlElement): for tagname in self._member_nsptagnames: obj.remove_all(tagname) - _remove_choice_group.__doc__ = ( - "Remove the current choice group child element if present." - ) + _remove_choice_group.__doc__ = "Remove the current choice group child element if present." self._add_to_class(self._remove_choice_group_method_name, _remove_choice_group) @property @@ -680,9 +653,7 @@ def _remove_choice_group_method_name(self): # -- lxml typing isn't quite right here, just ignore this error on _Element -- -class BaseOxmlElement( # pyright: ignore[reportGeneralTypeIssues] - etree.ElementBase, metaclass=MetaOxmlElement -): +class BaseOxmlElement(etree.ElementBase, metaclass=MetaOxmlElement): """Effective base class for all custom element classes. Adds standardized behavior to all classes in one place. @@ -726,9 +697,7 @@ def xml(self) -> str: """ return serialize_for_reading(self) - def xpath( # pyright: ignore[reportIncompatibleMethodOverride] - self, xpath_str: str - ) -> Any: + def xpath(self, xpath_str: str) -> Any: # pyright: ignore[reportIncompatibleMethodOverride] """Override of `lxml` _Element.xpath() method. Provides standard Open XML namespace mapping (`nsmap`) in centralized location. 
diff --git a/src/docx/package.py b/src/docx/package.py index 12a166bf3..7ea47e6e1 100644 --- a/src/docx/package.py +++ b/src/docx/package.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import IO +from typing import IO, cast from docx.image.image import Image from docx.opc.constants import RELATIONSHIP_TYPE as RT @@ -44,16 +44,16 @@ def _gather_image_parts(self): continue if rel.target_part in self.image_parts: continue - self.image_parts.append(rel.target_part) + self.image_parts.append(cast("ImagePart", rel.target_part)) class ImageParts: """Collection of |ImagePart| objects corresponding to images in the package.""" def __init__(self): - self._image_parts = [] + self._image_parts: list[ImagePart] = [] - def __contains__(self, item): + def __contains__(self, item: object): return self._image_parts.__contains__(item) def __iter__(self): @@ -62,7 +62,7 @@ def __iter__(self): def __len__(self): return self._image_parts.__len__() - def append(self, item): + def append(self, item: ImagePart): self._image_parts.append(item) def get_or_add_image_part(self, image_descriptor: str | IO[bytes]) -> ImagePart: @@ -77,15 +77,14 @@ def get_or_add_image_part(self, image_descriptor: str | IO[bytes]) -> ImagePart: return matching_image_part return self._add_image_part(image) - def _add_image_part(self, image): - """Return an |ImagePart| instance newly created from image and appended to the - collection.""" + def _add_image_part(self, image: Image): + """Return |ImagePart| instance newly created from `image` and appended to the collection.""" partname = self._next_image_partname(image.ext) image_part = ImagePart.from_image(image, partname) self.append(image_part) return image_part - def _get_by_sha1(self, sha1): + def _get_by_sha1(self, sha1: str) -> ImagePart | None: """Return the image part in this collection having a SHA1 hash matching `sha1`, or |None| if not found.""" for image_part in self._image_parts: @@ -93,7 +92,7 @@ def _get_by_sha1(self, sha1): return 
image_part return None - def _next_image_partname(self, ext): + def _next_image_partname(self, ext: str) -> PackURI: """The next available image partname, starting from ``/word/media/image1.{ext}`` where unused numbers are reused. @@ -101,7 +100,7 @@ def _next_image_partname(self, ext): not include the leading period. """ - def image_partname(n): + def image_partname(n: int) -> PackURI: return PackURI("/word/media/image%d.%s" % (n, ext)) used_numbers = [image_part.partname.idx for image_part in self] diff --git a/src/docx/parts/document.py b/src/docx/parts/document.py index a157764b9..416bb1a27 100644 --- a/src/docx/parts/document.py +++ b/src/docx/parts/document.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, cast +from typing import IO, TYPE_CHECKING, cast from docx.document import Document from docx.enum.style import WD_STYLE_TYPE @@ -16,6 +16,8 @@ from docx.shared import lazyproperty if TYPE_CHECKING: + from docx.opc.coreprops import CoreProperties + from docx.settings import Settings from docx.styles.style import BaseStyle @@ -41,7 +43,7 @@ def add_header_part(self): return header_part, rId @property - def core_properties(self): + def core_properties(self) -> CoreProperties: """A |CoreProperties| object providing read/write access to the core properties of this document.""" return self.package.core_properties @@ -100,13 +102,13 @@ def numbering_part(self): self.relate_to(numbering_part, RT.NUMBERING) return numbering_part - def save(self, path_or_stream): + def save(self, path_or_stream: str | IO[bytes]): """Save this document to `path_or_stream`, which can be either a path to a filesystem location (a string) or a file-like object.""" self.package.save(path_or_stream) @property - def settings(self): + def settings(self) -> Settings: """A |Settings| object providing access to the settings in the settings part of this document.""" return self._settings_part.settings @@ -118,14 +120,14 @@ def styles(self): return 
self._styles_part.styles @property - def _settings_part(self): + def _settings_part(self) -> SettingsPart: """A |SettingsPart| object providing access to the document-level settings for this document. Creates a default settings part if one is not present. """ try: - return self.part_related_by(RT.SETTINGS) + return cast(SettingsPart, self.part_related_by(RT.SETTINGS)) except KeyError: settings_part = SettingsPart.default(self.package) self.relate_to(settings_part, RT.SETTINGS) diff --git a/src/docx/parts/hdrftr.py b/src/docx/parts/hdrftr.py index 46821d780..35113801c 100644 --- a/src/docx/parts/hdrftr.py +++ b/src/docx/parts/hdrftr.py @@ -1,17 +1,23 @@ """Header and footer part objects.""" +from __future__ import annotations + import os +from typing import TYPE_CHECKING from docx.opc.constants import CONTENT_TYPE as CT from docx.oxml.parser import parse_xml from docx.parts.story import StoryPart +if TYPE_CHECKING: + from docx.package import Package + class FooterPart(StoryPart): """Definition of a section footer.""" @classmethod - def new(cls, package): + def new(cls, package: Package): """Return newly created footer part.""" partname = package.next_partname("/word/footer%d.xml") content_type = CT.WML_FOOTER @@ -21,9 +27,7 @@ def new(cls, package): @classmethod def _default_footer_xml(cls): """Return bytes containing XML for a default footer part.""" - path = os.path.join( - os.path.split(__file__)[0], "..", "templates", "default-footer.xml" - ) + path = os.path.join(os.path.split(__file__)[0], "..", "templates", "default-footer.xml") with open(path, "rb") as f: xml_bytes = f.read() return xml_bytes @@ -33,7 +37,7 @@ class HeaderPart(StoryPart): """Definition of a section header.""" @classmethod - def new(cls, package): + def new(cls, package: Package): """Return newly created header part.""" partname = package.next_partname("/word/header%d.xml") content_type = CT.WML_HEADER @@ -43,9 +47,7 @@ def new(cls, package): @classmethod def _default_header_xml(cls): 
"""Return bytes containing XML for a default header part.""" - path = os.path.join( - os.path.split(__file__)[0], "..", "templates", "default-header.xml" - ) + path = os.path.join(os.path.split(__file__)[0], "..", "templates", "default-header.xml") with open(path, "rb") as f: xml_bytes = f.read() return xml_bytes diff --git a/src/docx/parts/image.py b/src/docx/parts/image.py index e4580df74..5aec07077 100644 --- a/src/docx/parts/image.py +++ b/src/docx/parts/image.py @@ -3,11 +3,16 @@ from __future__ import annotations import hashlib +from typing import TYPE_CHECKING from docx.image.image import Image from docx.opc.part import Part from docx.shared import Emu, Inches +if TYPE_CHECKING: + from docx.opc.package import OpcPackage + from docx.opc.packuri import PackURI + class ImagePart(Part): """An image part. @@ -16,7 +21,7 @@ class ImagePart(Part): """ def __init__( - self, partname: str, content_type: str, blob: bytes, image: Image | None = None + self, partname: PackURI, content_type: str, blob: bytes, image: Image | None = None ): super(ImagePart, self).__init__(partname, content_type, blob) self._image = image @@ -36,7 +41,7 @@ def default_cy(self): vertical dots per inch (dpi).""" px_height = self.image.px_height horz_dpi = self.image.horz_dpi - height_in_emu = 914400 * px_height / horz_dpi + height_in_emu = int(round(914400 * px_height / horz_dpi)) return Emu(height_in_emu) @property @@ -52,7 +57,7 @@ def filename(self): return "image.%s" % self.partname.ext @classmethod - def from_image(cls, image, partname): + def from_image(cls, image: Image, partname: PackURI): """Return an |ImagePart| instance newly created from `image` and assigned `partname`.""" return ImagePart(partname, image.content_type, image.blob, image) @@ -64,7 +69,7 @@ def image(self) -> Image: return self._image @classmethod - def load(cls, partname, content_type, blob, package): + def load(cls, partname: PackURI, content_type: str, blob: bytes, package: OpcPackage): """Called by 
``docx.opc.package.PartFactory`` to load an image part from a package being opened by ``Document(...)`` call.""" return cls(partname, content_type, blob) @@ -72,4 +77,4 @@ def load(cls, partname, content_type, blob, package): @property def sha1(self): """SHA1 hash digest of the blob of this image part.""" - return hashlib.sha1(self._blob).hexdigest() + return hashlib.sha1(self.blob).hexdigest() diff --git a/src/docx/parts/settings.py b/src/docx/parts/settings.py index d83c9d5ca..116facca2 100644 --- a/src/docx/parts/settings.py +++ b/src/docx/parts/settings.py @@ -1,6 +1,9 @@ """|SettingsPart| and closely related objects.""" +from __future__ import annotations + import os +from typing import TYPE_CHECKING, cast from docx.opc.constants import CONTENT_TYPE as CT from docx.opc.packuri import PackURI @@ -8,31 +11,41 @@ from docx.oxml.parser import parse_xml from docx.settings import Settings +if TYPE_CHECKING: + from docx.oxml.settings import CT_Settings + from docx.package import Package + class SettingsPart(XmlPart): """Document-level settings part of a WordprocessingML (WML) package.""" + def __init__( + self, partname: PackURI, content_type: str, element: CT_Settings, package: Package + ): + super().__init__(partname, content_type, element, package) + self._settings = element + @classmethod - def default(cls, package): + def default(cls, package: Package): """Return a newly created settings part, containing a default `w:settings` element tree.""" partname = PackURI("/word/settings.xml") content_type = CT.WML_SETTINGS - element = parse_xml(cls._default_settings_xml()) + element = cast("CT_Settings", parse_xml(cls._default_settings_xml())) return cls(partname, content_type, element, package) @property - def settings(self): - """A |Settings| proxy object for the `w:settings` element in this part, - containing the document-level settings for this document.""" - return Settings(self.element) + def settings(self) -> Settings: + """A |Settings| proxy object for the 
`w:settings` element in this part. + + Contains the document-level settings for this document. + """ + return Settings(self._settings) @classmethod def _default_settings_xml(cls): """Return a bytestream containing XML for a default settings part.""" - path = os.path.join( - os.path.split(__file__)[0], "..", "templates", "default-settings.xml" - ) + path = os.path.join(os.path.split(__file__)[0], "..", "templates", "default-settings.xml") with open(path, "rb") as f: xml_bytes = f.read() return xml_bytes diff --git a/src/docx/parts/story.py b/src/docx/parts/story.py index b5c8ac882..7482c91a8 100644 --- a/src/docx/parts/story.py +++ b/src/docx/parts/story.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import IO, TYPE_CHECKING, Tuple +from typing import IO, TYPE_CHECKING, Tuple, cast from docx.opc.constants import RELATIONSHIP_TYPE as RT from docx.opc.part import XmlPart @@ -60,8 +60,8 @@ def get_style_id( def new_pic_inline( self, image_descriptor: str | IO[bytes], - width: Length | None = None, - height: Length | None = None, + width: int | Length | None = None, + height: int | Length | None = None, ) -> CT_Inline: """Return a newly-created `w:inline` element. @@ -92,4 +92,4 @@ def _document_part(self) -> DocumentPart: """|DocumentPart| object for this package.""" package = self.package assert package is not None - return package.main_document_part + return cast("DocumentPart", package.main_document_part) diff --git a/src/docx/py.typed b/src/docx/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/src/docx/section.py b/src/docx/section.py index f72b60867..982a14370 100644 --- a/src/docx/section.py +++ b/src/docx/section.py @@ -160,11 +160,7 @@ def iter_inner_content(self) -> Iterator[Paragraph | Table]: Items appear in document order. 
""" for element in self._sectPr.iter_inner_content(): - yield ( - Paragraph(element, self) # pyright: ignore[reportGeneralTypeIssues] - if isinstance(element, CT_P) - else Table(element, self) - ) + yield (Paragraph(element, self) if isinstance(element, CT_P) else Table(element, self)) @property def left_margin(self) -> Length | None: @@ -269,12 +265,10 @@ def __init__(self, document_elm: CT_Document, document_part: DocumentPart): self._document_part = document_part @overload - def __getitem__(self, key: int) -> Section: - ... + def __getitem__(self, key: int) -> Section: ... @overload - def __getitem__(self, key: slice) -> List[Section]: - ... + def __getitem__(self, key: slice) -> List[Section]: ... def __getitem__(self, key: int | slice) -> Section | List[Section]: if isinstance(key, slice): diff --git a/src/docx/settings.py b/src/docx/settings.py index 78f816e87..0a5aa2f36 100644 --- a/src/docx/settings.py +++ b/src/docx/settings.py @@ -1,7 +1,16 @@ """Settings object, providing access to document-level settings.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + from docx.shared import ElementProxy +if TYPE_CHECKING: + import docx.types as t + from docx.oxml.settings import CT_Settings + from docx.oxml.xmlchemy import BaseOxmlElement + class Settings(ElementProxy): """Provides access to document-level settings for a document. @@ -9,14 +18,18 @@ class Settings(ElementProxy): Accessed using the :attr:`.Document.settings` property. """ + def __init__(self, element: BaseOxmlElement, parent: t.ProvidesXmlPart | None = None): + super().__init__(element, parent) + self._settings = cast("CT_Settings", element) + @property - def odd_and_even_pages_header_footer(self): + def odd_and_even_pages_header_footer(self) -> bool: """True if this document has distinct odd and even page headers and footers. Read/write. 
""" - return self._element.evenAndOddHeaders_val + return self._settings.evenAndOddHeaders_val @odd_and_even_pages_header_footer.setter - def odd_and_even_pages_header_footer(self, value): - self._element.evenAndOddHeaders_val = value + def odd_and_even_pages_header_footer(self, value: bool): + self._settings.evenAndOddHeaders_val = value diff --git a/src/docx/shape.py b/src/docx/shape.py index b91ecbf64..cd35deb35 100644 --- a/src/docx/shape.py +++ b/src/docx/shape.py @@ -3,26 +3,36 @@ A shape is a visual object that appears on the drawing layer of a document. """ +from __future__ import annotations + +from typing import TYPE_CHECKING + from docx.enum.shape import WD_INLINE_SHAPE from docx.oxml.ns import nsmap from docx.shared import Parented +if TYPE_CHECKING: + from docx.oxml.document import CT_Body + from docx.oxml.shape import CT_Inline + from docx.parts.story import StoryPart + from docx.shared import Length + class InlineShapes(Parented): - """Sequence of |InlineShape| instances, supporting len(), iteration, and indexed - access.""" + """Sequence of |InlineShape| instances, supporting len(), iteration, and indexed access.""" - def __init__(self, body_elm, parent): + def __init__(self, body_elm: CT_Body, parent: StoryPart): super(InlineShapes, self).__init__(parent) self._body = body_elm - def __getitem__(self, idx): + def __getitem__(self, idx: int): """Provide indexed access, e.g. 'inline_shapes[idx]'.""" try: inline = self._inline_lst[idx] except IndexError: msg = "inline shape index [%d] out of range" % idx raise IndexError(msg) + return InlineShape(inline) def __iter__(self): @@ -42,12 +52,12 @@ class InlineShape: """Proxy for an ```` element, representing the container for an inline graphical object.""" - def __init__(self, inline): + def __init__(self, inline: CT_Inline): super(InlineShape, self).__init__() self._inline = inline @property - def height(self): + def height(self) -> Length: """Read/write. 
The display height of this inline shape as an |Emu| instance. @@ -55,7 +65,7 @@ def height(self): return self._inline.extent.cy @height.setter - def height(self, cy): + def height(self, cy: Length): self._inline.extent.cy = cy self._inline.graphic.graphicData.pic.spPr.cy = cy @@ -88,6 +98,6 @@ def width(self): return self._inline.extent.cx @width.setter - def width(self, cx): + def width(self, cx: Length): self._inline.extent.cx = cx self._inline.graphic.graphicData.pic.spPr.cx = cx diff --git a/src/docx/shared.py b/src/docx/shared.py index 7b696202f..491d42741 100644 --- a/src/docx/shared.py +++ b/src/docx/shared.py @@ -16,7 +16,7 @@ ) if TYPE_CHECKING: - from docx import types as t + import docx.types as t from docx.opc.part import XmlPart from docx.oxml.xmlchemy import BaseOxmlElement from docx.parts.story import StoryPart @@ -284,9 +284,7 @@ class ElementProxy: common type of class in python-docx other than custom element (oxml) classes. """ - def __init__( - self, element: BaseOxmlElement, parent: t.ProvidesXmlPart | None = None - ): + def __init__(self, element: BaseOxmlElement, parent: t.ProvidesXmlPart | None = None): self._element = element self._parent = parent diff --git a/src/docx/table.py b/src/docx/table.py index 31372284c..545c46884 100644 --- a/src/docx/table.py +++ b/src/docx/table.py @@ -2,27 +2,37 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Tuple, overload +from typing import TYPE_CHECKING, Iterator, cast, overload + +from typing_extensions import TypeAlias from docx.blkcntnr import BlockItemContainer from docx.enum.style import WD_STYLE_TYPE +from docx.enum.table import WD_CELL_VERTICAL_ALIGNMENT from docx.oxml.simpletypes import ST_Merge -from docx.shared import Inches, Parented, lazyproperty +from docx.oxml.table import CT_TblGridCol +from docx.shared import Inches, Parented, StoryChild, lazyproperty if TYPE_CHECKING: - from docx import types as t - from docx.enum.table import WD_TABLE_ALIGNMENT, 
WD_TABLE_DIRECTION - from docx.oxml.table import CT_Tbl, CT_TblPr + import docx.types as t + from docx.enum.table import WD_ROW_HEIGHT_RULE, WD_TABLE_ALIGNMENT, WD_TABLE_DIRECTION + from docx.oxml.table import CT_Row, CT_Tbl, CT_TblPr, CT_Tc from docx.shared import Length - from docx.styles.style import _TableStyle # pyright: ignore[reportPrivateUsage] + from docx.styles.style import ( + ParagraphStyle, + _TableStyle, # pyright: ignore[reportPrivateUsage] + ) + +TableParent: TypeAlias = "Table | _Columns | _Rows" -class Table(Parented): +class Table(StoryChild): """Proxy class for a WordprocessingML ```` element.""" - def __init__(self, tbl: CT_Tbl, parent: t.StoryChild): + def __init__(self, tbl: CT_Tbl, parent: t.ProvidesStoryPart): super(Table, self).__init__(parent) - self._element = self._tbl = tbl + self._element = tbl + self._tbl = tbl def add_column(self, width: Length): """Return a |_Column| object of `width`, newly added rightmost to the table.""" @@ -40,7 +50,8 @@ def add_row(self): tr = tbl.add_tr() for gridCol in tbl.tblGrid.gridCol_lst: tc = tr.add_tc() - tc.width = gridCol.w + if gridCol.w is not None: + tc.width = gridCol.w return _Row(tr, self) @property @@ -79,7 +90,7 @@ def cell(self, row_idx: int, col_idx: int) -> _Cell: cell_idx = col_idx + (row_idx * self._column_count) return self._cells[cell_idx] - def column_cells(self, column_idx: int) -> List[_Cell]: + def column_cells(self, column_idx: int) -> list[_Cell]: """Sequence of cells in the column at `column_idx` in this table.""" cells = self._cells idxs = range(column_idx, len(cells), self._column_count) @@ -90,8 +101,11 @@ def columns(self): """|_Columns| instance representing the sequence of columns in this table.""" return _Columns(self._tbl, self) - def row_cells(self, row_idx: int) -> List[_Cell]: - """Sequence of cells in the row at `row_idx` in this table.""" + def row_cells(self, row_idx: int) -> list[_Cell]: + """DEPRECATED: Use `table.rows[row_idx].cells` instead. 
+ + Sequence of cells in the row at `row_idx` in this table. + """ column_count = self._column_count start = row_idx * column_count end = start + column_count @@ -116,10 +130,10 @@ def style(self) -> _TableStyle | None: `Light Shading - Accent 1` becomes `Light Shading Accent 1`. """ style_id = self._tbl.tblStyle_val - return self.part.get_style(style_id, WD_STYLE_TYPE.TABLE) + return cast("_TableStyle | None", self.part.get_style(style_id, WD_STYLE_TYPE.TABLE)) @style.setter - def style(self, style_or_name: _TableStyle | None): + def style(self, style_or_name: _TableStyle | str | None): style_id = self.part.get_style_id(style_or_name, WD_STYLE_TYPE.TABLE) self._tbl.tblStyle_val = style_id @@ -140,21 +154,21 @@ def table_direction(self) -> WD_TABLE_DIRECTION | None: For example: `WD_TABLE_DIRECTION.LTR`. |None| indicates the value is inherited from the style hierarchy. """ - return self._element.bidiVisual_val + return cast("WD_TABLE_DIRECTION | None", self._tbl.bidiVisual_val) @table_direction.setter def table_direction(self, value: WD_TABLE_DIRECTION | None): self._element.bidiVisual_val = value @property - def _cells(self) -> List[_Cell]: + def _cells(self) -> list[_Cell]: """A sequence of |_Cell| objects, one for each cell of the layout grid. If the table contains a span, one or more |_Cell| object references are repeated. 
""" col_count = self._column_count - cells = [] + cells: list[_Cell] = [] for tc in self._tbl.iter_tcs(): for grid_span_idx in range(tc.grid_span): if tc.vMerge == ST_Merge.CONTINUE: @@ -178,11 +192,12 @@ def _tblPr(self) -> CT_TblPr: class _Cell(BlockItemContainer): """Table cell.""" - def __init__(self, tc, parent): - super(_Cell, self).__init__(tc, parent) + def __init__(self, tc: CT_Tc, parent: TableParent): + super(_Cell, self).__init__(tc, cast("t.ProvidesStoryPart", parent)) + self._parent = parent self._tc = self._element = tc - def add_paragraph(self, text="", style=None): + def add_paragraph(self, text: str = "", style: str | ParagraphStyle | None = None): """Return a paragraph newly added to the end of the content in this cell. If present, `text` is added to the paragraph in a single run. If specified, the @@ -195,9 +210,12 @@ def add_paragraph(self, text="", style=None): """ return super(_Cell, self).add_paragraph(text, style) - def add_table(self, rows, cols): - """Return a table newly added to this cell after any existing cell content, - having `rows` rows and `cols` columns. + def add_table( # pyright: ignore[reportIncompatibleMethodOverride] + self, rows: int, cols: int + ) -> Table: + """Return a table newly added to this cell after any existing cell content. + + The new table will have `rows` rows and `cols` columns. An empty paragraph is added after the table because Word requires a paragraph element as the last element in every cell. @@ -207,7 +225,16 @@ def add_table(self, rows, cols): self.add_paragraph() return table - def merge(self, other_cell): + @property + def grid_span(self) -> int: + """Number of layout-grid cells this cell spans horizontally. + + A "normal" cell has a grid-span of 1. A horizontally merged cell has a grid-span of 2 or + more. + """ + return self._tc.grid_span + + def merge(self, other_cell: _Cell): """Return a merged cell created by spanning the rectangular region having this cell and `other_cell` as diagonal corners. 
@@ -244,7 +271,7 @@ def text(self) -> str: return "\n".join(p.text for p in self.paragraphs) @text.setter - def text(self, text): + def text(self, text: str): """Write-only. Set entire contents of cell to the string `text`. Any existing content or @@ -270,7 +297,7 @@ def vertical_alignment(self): return tcPr.vAlign_val @vertical_alignment.setter - def vertical_alignment(self, value): + def vertical_alignment(self, value: WD_CELL_VERTICAL_ALIGNMENT | None): tcPr = self._element.get_or_add_tcPr() tcPr.vAlign_val = value @@ -280,34 +307,35 @@ def width(self): return self._tc.width @width.setter - def width(self, value): + def width(self, value: Length): self._tc.width = value class _Column(Parented): """Table column.""" - def __init__(self, gridCol, parent): + def __init__(self, gridCol: CT_TblGridCol, parent: TableParent): super(_Column, self).__init__(parent) + self._parent = parent self._gridCol = gridCol @property - def cells(self): + def cells(self) -> tuple[_Cell, ...]: """Sequence of |_Cell| instances corresponding to cells in this column.""" return tuple(self.table.column_cells(self._index)) @property - def table(self): + def table(self) -> Table: """Reference to the |Table| object this column belongs to.""" return self._parent.table @property - def width(self): + def width(self) -> Length | None: """The width of this column in EMU, or |None| if no explicit width is set.""" return self._gridCol.w @width.setter - def width(self, value): + def width(self, value: Length | None): self._gridCol.w = value @property @@ -322,11 +350,12 @@ class _Columns(Parented): Supports ``len()``, iteration and indexed access. """ - def __init__(self, tbl, parent): + def __init__(self, tbl: CT_Tbl, parent: TableParent): super(_Columns, self).__init__(parent) + self._parent = parent self._tbl = tbl - def __getitem__(self, idx): + def __getitem__(self, idx: int): """Provide indexed access, e.g. 
'columns[0]'.""" try: gridCol = self._gridCol_lst[idx] @@ -343,7 +372,7 @@ def __len__(self): return len(self._gridCol_lst) @property - def table(self): + def table(self) -> Table: """Reference to the |Table| object this column collection belongs to.""" return self._parent.table @@ -358,42 +387,119 @@ def _gridCol_lst(self): class _Row(Parented): """Table row.""" - def __init__(self, tr, parent): + def __init__(self, tr: CT_Row, parent: TableParent): super(_Row, self).__init__(parent) + self._parent = parent self._tr = self._element = tr @property - def cells(self) -> Tuple[_Cell]: - """Sequence of |_Cell| instances corresponding to cells in this row.""" - return tuple(self.table.row_cells(self._index)) + def cells(self) -> tuple[_Cell, ...]: + """Sequence of |_Cell| instances corresponding to cells in this row. + + Note that Word allows table rows to start later than the first column and end before the + last column. + + - Only cells actually present are included in the return value. + - This implies the length of this cell sequence may differ between rows of the same table. + - If you are reading the cells from each row to form a rectangular "matrix" data structure + of the table cell values, you will need to account for empty leading and/or trailing + layout-grid positions using `.grid_cols_before` and `.grid_cols_after`. + + """ + + def iter_tc_cells(tc: CT_Tc) -> Iterator[_Cell]: + """Generate a cell object for each layout-grid cell in `tc`. + + In particular, a `<w:tc>` element with a horizontal "span" will generate the same cell + multiple times, one for each grid-cell being spanned. This approximates a row in a + "uniform" table, where each row has a cell for each column in the table. + """ + # -- a cell comprising the second or later row of a vertical span is indicated by + # -- tc.vMerge="continue" (the default value of the `w:vMerge` attribute, when it is + # -- present in the XML).
The `w:tc` element at the same grid-offset in the prior row + # -- is guaranteed to be the same width (gridSpan). So we can delegate content + # -- discovery to that prior-row `w:tc` element (recursively) until we arrive at the + # -- "root" cell -- for the vertical span. + if tc.vMerge == "continue": + yield from iter_tc_cells(tc._tc_above) # pyright: ignore[reportPrivateUsage] + return + + # -- Otherwise, vMerge is either "restart" or None, meaning this `tc` holds the actual + # -- content of the cell (whether it is vertically merged or not). + cell = _Cell(tc, self.table) + for _ in range(tc.grid_span): + yield cell + + def _iter_row_cells() -> Iterator[_Cell]: + """Generate `_Cell` instance for each populated layout-grid cell in this row.""" + for tc in self._tr.tc_lst: + yield from iter_tc_cells(tc) + + return tuple(_iter_row_cells()) @property - def height(self): + def grid_cols_after(self) -> int: + """Count of unpopulated grid-columns after the last cell in this row. + + Word allows a row to "end early", meaning that one or more cells are not present at the + end of that row. + + Note these are not simply "empty" cells. The renderer reads this value and "skips" this + many columns after drawing the last cell. + + Note this also implies that not all rows are guaranteed to have the same number of cells, + e.g. `_Row.cells` could have length `n` for one row and `n - m` for the next row in the same + table. Visually this appears as a column (at the beginning or end, not in the middle) with + one or more cells missing. + """ + return self._tr.grid_after + + @property + def grid_cols_before(self) -> int: + """Count of unpopulated grid-columns before the first cell in this row. + + Word allows a row to "start late", meaning that one or more cells are not present at the + beginning of that row. + + Note these are not simply "empty" cells. 
The renderer reads this value and skips forward to + the table layout-grid position of the first cell in this row; the renderer "skips" this many + columns before drawing the first cell. + + Note this also implies that not all rows are guaranteed to have the same number of cells, + e.g. `_Row.cells` could have length `n` for one row and `n - m` for the next row in the same + table. + """ + return self._tr.grid_before + + @property + def height(self) -> Length | None: """Return a |Length| object representing the height of this cell, or |None| if no explicit height is set.""" return self._tr.trHeight_val @height.setter - def height(self, value): + def height(self, value: Length | None): self._tr.trHeight_val = value @property - def height_rule(self): - """Return the height rule of this cell as a member of the :ref:`WdRowHeightRule` - enumeration, or |None| if no explicit height_rule is set.""" + def height_rule(self) -> WD_ROW_HEIGHT_RULE | None: + """Return the height rule of this cell as a member of the :ref:`WdRowHeightRule`. + + This value is |None| if no explicit height_rule is set. + """ return self._tr.trHeight_hRule @height_rule.setter - def height_rule(self, value): + def height_rule(self, value: WD_ROW_HEIGHT_RULE | None): self._tr.trHeight_hRule = value @property - def table(self): + def table(self) -> Table: """Reference to the |Table| object this row belongs to.""" return self._parent.table @property - def _index(self): + def _index(self) -> int: """Index of this row in its table, starting from zero.""" return self._tr.tr_idx @@ -404,19 +510,18 @@ class _Rows(Parented): Supports ``len()``, iteration, indexed access, and slicing. """ - def __init__(self, tbl, parent): + def __init__(self, tbl: CT_Tbl, parent: TableParent): super(_Rows, self).__init__(parent) + self._parent = parent self._tbl = tbl @overload - def __getitem__(self, idx: int) -> _Row: - ... + def __getitem__(self, idx: int) -> _Row: ... 
@overload - def __getitem__(self, idx: slice) -> List[_Row]: - ... + def __getitem__(self, idx: slice) -> list[_Row]: ... - def __getitem__(self, idx: int | slice) -> _Row | List[_Row]: + def __getitem__(self, idx: int | slice) -> _Row | list[_Row]: """Provide indexed access, (e.g. `rows[0]` or `rows[1:3]`)""" return list(self)[idx] @@ -427,6 +532,6 @@ def __len__(self): return len(self._tbl.tr_lst) @property - def table(self): + def table(self) -> Table: """Reference to the |Table| object this row collection belongs to.""" return self._parent.table diff --git a/src/docx/text/hyperlink.py b/src/docx/text/hyperlink.py index 705a97ee4..a23df1c74 100644 --- a/src/docx/text/hyperlink.py +++ b/src/docx/text/hyperlink.py @@ -7,13 +7,15 @@ from __future__ import annotations -from typing import List +from typing import TYPE_CHECKING -from docx import types as t -from docx.oxml.text.hyperlink import CT_Hyperlink from docx.shared import Parented from docx.text.run import Run +if TYPE_CHECKING: + import docx.types as t + from docx.oxml.text.hyperlink import CT_Hyperlink + class Hyperlink(Parented): """Proxy object wrapping a `<w:hyperlink>` element. @@ -78,7 +80,7 @@ def fragment(self) -> str: return self._hyperlink.anchor or "" @property - def runs(self) -> List[Run]: + def runs(self) -> list[Run]: """List of |Run| instances in this hyperlink. Together these define the visible text of the hyperlink.
The text of a hyperlink diff --git a/src/docx/text/pagebreak.py b/src/docx/text/pagebreak.py index a5e68b5aa..0977ccea9 100644 --- a/src/docx/text/pagebreak.py +++ b/src/docx/text/pagebreak.py @@ -4,11 +4,11 @@ from typing import TYPE_CHECKING -from docx import types as t from docx.oxml.text.pagebreak import CT_LastRenderedPageBreak from docx.shared import Parented if TYPE_CHECKING: + import docx.types as t from docx.text.paragraph import Paragraph diff --git a/src/docx/text/paragraph.py b/src/docx/text/paragraph.py index 0a5d67674..234ea66cb 100644 --- a/src/docx/text/paragraph.py +++ b/src/docx/text/paragraph.py @@ -4,9 +4,6 @@ from typing import TYPE_CHECKING, Iterator, List, cast -from typing_extensions import Self - -from docx import types as t from docx.enum.style import WD_STYLE_TYPE from docx.oxml.text.run import CT_R from docx.shared import StoryChild @@ -17,6 +14,7 @@ from docx.text.run import Run if TYPE_CHECKING: + import docx.types as t from docx.enum.text import WD_PARAGRAPH_ALIGNMENT from docx.oxml.text.paragraph import CT_P from docx.styles.style import CharacterStyle @@ -29,9 +27,7 @@ def __init__(self, p: CT_P, parent: t.ProvidesStoryPart): super(Paragraph, self).__init__(parent) self._p = self._element = p - def add_run( - self, text: str | None = None, style: str | CharacterStyle | None = None - ) -> Run: + def add_run(self, text: str | None = None, style: str | CharacterStyle | None = None) -> Run: """Append run containing `text` and having character-style `style`. `text` can contain tab (``\\t``) characters, which are converted to the @@ -82,7 +78,7 @@ def hyperlinks(self) -> List[Hyperlink]: def insert_paragraph_before( self, text: str | None = None, style: str | ParagraphStyle | None = None - ) -> Self: + ) -> Paragraph: """Return a newly created paragraph, inserted directly before this paragraph. If `text` is supplied, the new paragraph contains that text in a single run. 
If @@ -123,9 +119,7 @@ def rendered_page_breaks(self) -> List[RenderedPageBreak]: Most often an empty list, sometimes contains one page-break, but can contain more than one is rare or contrived cases. """ - return [ - RenderedPageBreak(lrpb, self) for lrpb in self._p.lastRenderedPageBreaks - ] + return [RenderedPageBreak(lrpb, self) for lrpb in self._p.lastRenderedPageBreaks] @property def runs(self) -> List[Run]: diff --git a/src/docx/text/run.py b/src/docx/text/run.py index 44c41c0fe..0e2f5bc17 100644 --- a/src/docx/text/run.py +++ b/src/docx/text/run.py @@ -4,7 +4,6 @@ from typing import IO, TYPE_CHECKING, Iterator, cast -from docx import types as t from docx.drawing import Drawing from docx.enum.style import WD_STYLE_TYPE from docx.enum.text import WD_BREAK @@ -17,6 +16,7 @@ from docx.text.pagebreak import RenderedPageBreak if TYPE_CHECKING: + import docx.types as t from docx.enum.text import WD_UNDERLINE from docx.oxml.text.run import CT_R, CT_Text from docx.shared import Length @@ -59,8 +59,8 @@ def add_break(self, break_type: WD_BREAK = WD_BREAK.LINE): def add_picture( self, image_path_or_stream: str | IO[bytes], - width: Length | None = None, - height: Length | None = None, + width: int | Length | None = None, + height: int | Length | None = None, ) -> InlineShape: """Return |InlineShape| containing image identified by `image_path_or_stream`. @@ -170,9 +170,7 @@ def iter_inner_content(self) -> Iterator[str | Drawing | RenderedPageBreak]: yield item elif isinstance(item, CT_LastRenderedPageBreak): yield RenderedPageBreak(item, self) - elif isinstance( # pyright: ignore[reportUnnecessaryIsInstance] - item, CT_Drawing - ): + elif isinstance(item, CT_Drawing): # pyright: ignore[reportUnnecessaryIsInstance] yield Drawing(item, self) @property @@ -185,9 +183,7 @@ def style(self) -> CharacterStyle: property to |None| removes any directly-applied character style. 
""" style_id = self._r.style - return cast( - CharacterStyle, self.part.get_style(style_id, WD_STYLE_TYPE.CHARACTER) - ) + return cast(CharacterStyle, self.part.get_style(style_id, WD_STYLE_TYPE.CHARACTER)) @style.setter def style(self, style_or_name: str | CharacterStyle | None): diff --git a/tests/opc/parts/test_coreprops.py b/tests/opc/parts/test_coreprops.py index 1db650353..b754d2d7e 100644 --- a/tests/opc/parts/test_coreprops.py +++ b/tests/opc/parts/test_coreprops.py @@ -1,47 +1,53 @@ """Unit test suite for the docx.opc.parts.coreprops module.""" -from datetime import datetime, timedelta +from __future__ import annotations + +import datetime as dt import pytest from docx.opc.coreprops import CoreProperties +from docx.opc.package import OpcPackage +from docx.opc.packuri import PackURI from docx.opc.parts.coreprops import CorePropertiesPart -from docx.oxml.coreprops import CT_CoreProperties -from ...unitutil.mock import class_mock, instance_mock +from ...unitutil.cxml import element +from ...unitutil.mock import FixtureRequest, Mock, class_mock, instance_mock class DescribeCorePropertiesPart: - def it_provides_access_to_its_core_props_object(self, coreprops_fixture): - core_properties_part, CoreProperties_ = coreprops_fixture + """Unit-test suite for `docx.opc.parts.coreprops.CorePropertiesPart` objects.""" + + def it_provides_access_to_its_core_props_object(self, CoreProperties_: Mock, package_: Mock): + core_properties_part = CorePropertiesPart( + PackURI("/part/name"), "content/type", element("cp:coreProperties"), package_ + ) + core_properties = core_properties_part.core_properties + CoreProperties_.assert_called_once_with(core_properties_part.element) assert isinstance(core_properties, CoreProperties) - def it_can_create_a_default_core_properties_part(self): - core_properties_part = CorePropertiesPart.default(None) + def it_can_create_a_default_core_properties_part(self, package_: Mock): + core_properties_part = CorePropertiesPart.default(package_) + 
assert isinstance(core_properties_part, CorePropertiesPart) + # -- core_properties = core_properties_part.core_properties assert core_properties.title == "Word Document" assert core_properties.last_modified_by == "python-docx" assert core_properties.revision == 1 - delta = datetime.utcnow() - core_properties.modified - max_expected_delta = timedelta(seconds=2) + assert core_properties.modified is not None + delta = dt.datetime.now(dt.timezone.utc) - core_properties.modified + max_expected_delta = dt.timedelta(seconds=2) assert delta < max_expected_delta # fixtures --------------------------------------------- @pytest.fixture - def coreprops_fixture(self, element_, CoreProperties_): - core_properties_part = CorePropertiesPart(None, None, element_, None) - return core_properties_part, CoreProperties_ - - # fixture components ----------------------------------- - - @pytest.fixture - def CoreProperties_(self, request): + def CoreProperties_(self, request: FixtureRequest): return class_mock(request, "docx.opc.parts.coreprops.CoreProperties") @pytest.fixture - def element_(self, request): - return instance_mock(request, CT_CoreProperties) + def package_(self, request: FixtureRequest): + return instance_mock(request, OpcPackage) diff --git a/tests/opc/test_coreprops.py b/tests/opc/test_coreprops.py index 2978ad5ae..5d9743397 100644 --- a/tests/opc/test_coreprops.py +++ b/tests/opc/test_coreprops.py @@ -1,160 +1,153 @@ +# pyright: reportPrivateUsage=false + """Unit test suite for the docx.opc.coreprops module.""" -from datetime import datetime +from __future__ import annotations + +import datetime as dt +from typing import TYPE_CHECKING, cast import pytest from docx.opc.coreprops import CoreProperties from docx.oxml.parser import parse_xml +if TYPE_CHECKING: + from docx.oxml.coreprops import CT_CoreProperties + class DescribeCoreProperties: - def it_knows_the_string_property_values(self, text_prop_get_fixture): - core_properties, prop_name, expected_value = 
text_prop_get_fixture + """Unit-test suite for `docx.opc.coreprops.CoreProperties` objects.""" + + @pytest.mark.parametrize( + ("prop_name", "expected_value"), + [ + ("author", "python-docx"), + ("category", ""), + ("comments", ""), + ("content_status", "DRAFT"), + ("identifier", "GXS 10.2.1ab"), + ("keywords", "foo bar baz"), + ("language", "US-EN"), + ("last_modified_by", "Steve Canny"), + ("subject", "Spam"), + ("title", "Word Document"), + ("version", "1.2.88"), + ], + ) + def it_knows_the_string_property_values( + self, prop_name: str, expected_value: str, core_properties: CoreProperties + ): actual_value = getattr(core_properties, prop_name) assert actual_value == expected_value - def it_can_change_the_string_property_values(self, text_prop_set_fixture): - core_properties, prop_name, value, expected_xml = text_prop_set_fixture - setattr(core_properties, prop_name, value) - assert core_properties._element.xml == expected_xml - - def it_knows_the_date_property_values(self, date_prop_get_fixture): - core_properties, prop_name, expected_datetime = date_prop_get_fixture - actual_datetime = getattr(core_properties, prop_name) - assert actual_datetime == expected_datetime + @pytest.mark.parametrize( + ("prop_name", "tagname", "value"), + [ + ("author", "dc:creator", "scanny"), + ("category", "cp:category", "silly stories"), + ("comments", "dc:description", "Bar foo to you"), + ("content_status", "cp:contentStatus", "FINAL"), + ("identifier", "dc:identifier", "GT 5.2.xab"), + ("keywords", "cp:keywords", "dog cat moo"), + ("language", "dc:language", "GB-EN"), + ("last_modified_by", "cp:lastModifiedBy", "Billy Bob"), + ("subject", "dc:subject", "Eggs"), + ("title", "dc:title", "Dissertation"), + ("version", "cp:version", "81.2.8"), + ], + ) + def it_can_change_the_string_property_values(self, prop_name: str, tagname: str, value: str): + coreProperties = self.coreProperties(tagname="", str_val="") + core_properties = CoreProperties(cast("CT_CoreProperties", 
parse_xml(coreProperties))) - def it_can_change_the_date_property_values(self, date_prop_set_fixture): - core_properties, prop_name, value, expected_xml = date_prop_set_fixture setattr(core_properties, prop_name, value) - assert core_properties._element.xml == expected_xml - - def it_knows_the_revision_number(self, revision_get_fixture): - core_properties, expected_revision = revision_get_fixture - assert core_properties.revision == expected_revision - - def it_can_change_the_revision_number(self, revision_set_fixture): - core_properties, revision, expected_xml = revision_set_fixture - core_properties.revision = revision - assert core_properties._element.xml == expected_xml - # fixtures ------------------------------------------------------- + assert core_properties._element.xml == self.coreProperties(tagname, value) - @pytest.fixture( - params=[ - ("created", datetime(2012, 11, 17, 16, 37, 40)), - ("last_printed", datetime(2014, 6, 4, 4, 28)), + @pytest.mark.parametrize( + ("prop_name", "expected_datetime"), + [ + ("created", dt.datetime(2012, 11, 17, 16, 37, 40, tzinfo=dt.timezone.utc)), + ("last_printed", dt.datetime(2014, 6, 4, 4, 28, tzinfo=dt.timezone.utc)), ("modified", None), - ] + ], ) - def date_prop_get_fixture(self, request, core_properties): - prop_name, expected_datetime = request.param - return core_properties, prop_name, expected_datetime + def it_knows_the_date_property_values( + self, prop_name: str, expected_datetime: dt.datetime, core_properties: CoreProperties + ): + actual_datetime = getattr(core_properties, prop_name) + assert actual_datetime == expected_datetime - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("prop_name", "tagname", "value", "str_val", "attrs"), + [ ( "created", "dcterms:created", - datetime(2001, 2, 3, 4, 5), + dt.datetime(2001, 2, 3, 4, 5), "2001-02-03T04:05:00Z", ' xsi:type="dcterms:W3CDTF"', ), ( "last_printed", "cp:lastPrinted", - datetime(2014, 6, 4, 4), + dt.datetime(2014, 6, 4, 4), 
"2014-06-04T04:00:00Z", "", ), ( "modified", "dcterms:modified", - datetime(2005, 4, 3, 2, 1), + dt.datetime(2005, 4, 3, 2, 1), "2005-04-03T02:01:00Z", ' xsi:type="dcterms:W3CDTF"', ), - ] + ], ) - def date_prop_set_fixture(self, request): - prop_name, tagname, value, str_val, attrs = request.param - coreProperties = self.coreProperties(None, None) - core_properties = CoreProperties(parse_xml(coreProperties)) + def it_can_change_the_date_property_values( + self, prop_name: str, tagname: str, value: dt.datetime, str_val: str, attrs: str + ): + coreProperties = self.coreProperties(tagname="", str_val="") + core_properties = CoreProperties(cast("CT_CoreProperties", parse_xml(coreProperties))) expected_xml = self.coreProperties(tagname, str_val, attrs) - return core_properties, prop_name, value, expected_xml - @pytest.fixture( - params=[("42", 42), (None, 0), ("foobar", 0), ("-17", 0), ("32.7", 0)] - ) - def revision_get_fixture(self, request): - str_val, expected_revision = request.param - tagname = "" if str_val is None else "cp:revision" - coreProperties = self.coreProperties(tagname, str_val) - core_properties = CoreProperties(parse_xml(coreProperties)) - return core_properties, expected_revision - - @pytest.fixture( - params=[ - (42, "42"), - ] + setattr(core_properties, prop_name, value) + + assert core_properties._element.xml == expected_xml + + @pytest.mark.parametrize( + ("str_val", "expected_value"), + [("42", 42), (None, 0), ("foobar", 0), ("-17", 0), ("32.7", 0)], ) - def revision_set_fixture(self, request): - value, str_val = request.param - coreProperties = self.coreProperties(None, None) - core_properties = CoreProperties(parse_xml(coreProperties)) + def it_knows_the_revision_number(self, str_val: str | None, expected_value: int): + tagname, str_val = ("cp:revision", str_val) if str_val else ("", "") + coreProperties = self.coreProperties(tagname, str_val or "") + core_properties = CoreProperties(cast("CT_CoreProperties", parse_xml(coreProperties))) + + 
assert core_properties.revision == expected_value + + @pytest.mark.parametrize(("value", "str_val"), [(42, "42")]) + def it_can_change_the_revision_number(self, value: int, str_val: str): + coreProperties = self.coreProperties(tagname="", str_val="") + core_properties = CoreProperties(cast("CT_CoreProperties", parse_xml(coreProperties))) expected_xml = self.coreProperties("cp:revision", str_val) - return core_properties, value, expected_xml - @pytest.fixture( - params=[ - ("author", "python-docx"), - ("category", ""), - ("comments", ""), - ("content_status", "DRAFT"), - ("identifier", "GXS 10.2.1ab"), - ("keywords", "foo bar baz"), - ("language", "US-EN"), - ("last_modified_by", "Steve Canny"), - ("subject", "Spam"), - ("title", "Word Document"), - ("version", "1.2.88"), - ] - ) - def text_prop_get_fixture(self, request, core_properties): - prop_name, expected_value = request.param - return core_properties, prop_name, expected_value + core_properties.revision = value - @pytest.fixture( - params=[ - ("author", "dc:creator", "scanny"), - ("category", "cp:category", "silly stories"), - ("comments", "dc:description", "Bar foo to you"), - ("content_status", "cp:contentStatus", "FINAL"), - ("identifier", "dc:identifier", "GT 5.2.xab"), - ("keywords", "cp:keywords", "dog cat moo"), - ("language", "dc:language", "GB-EN"), - ("last_modified_by", "cp:lastModifiedBy", "Billy Bob"), - ("subject", "dc:subject", "Eggs"), - ("title", "dc:title", "Dissertation"), - ("version", "cp:version", "81.2.8"), - ] - ) - def text_prop_set_fixture(self, request): - prop_name, tagname, value = request.param - coreProperties = self.coreProperties(None, None) - core_properties = CoreProperties(parse_xml(coreProperties)) - expected_xml = self.coreProperties(tagname, value) - return core_properties, prop_name, value, expected_xml + assert core_properties._element.xml == expected_xml - # fixture components --------------------------------------------- + # fixtures 
------------------------------------------------------- - def coreProperties(self, tagname, str_val, attrs=""): + def coreProperties(self, tagname: str, str_val: str, attrs: str = "") -> str: tmpl = ( - '%s\n' + "%s\n" ) if not tagname: child_element = "" @@ -166,27 +159,30 @@ def coreProperties(self, tagname, str_val, attrs=""): @pytest.fixture def core_properties(self): - element = parse_xml( - b"" - b'\n\n' - b" DRAFT\n" - b" python-docx\n" - b' 2012-11-17T11:07:' - b"40-05:30\n" - b" \n" - b" GXS 10.2.1ab\n" - b" US-EN\n" - b" 2014-06-04T04:28:00Z\n" - b" foo bar baz\n" - b" Steve Canny\n" - b" 4\n" - b" Spam\n" - b" Word Document\n" - b" 1.2.88\n" - b"\n" + element = cast( + "CT_CoreProperties", + parse_xml( + b"" + b'\n\n' + b" DRAFT\n" + b" python-docx\n" + b' 2012-11-17T11:07:' + b"40-05:30\n" + b" \n" + b" GXS 10.2.1ab\n" + b" US-EN\n" + b" 2014-06-04T04:28:00Z\n" + b" foo bar baz\n" + b" Steve Canny\n" + b" 4\n" + b" Spam\n" + b" Word Document\n" + b" 1.2.88\n" + b"\n" + ), ) return CoreProperties(element) diff --git a/tests/opc/test_package.py b/tests/opc/test_package.py index 7fdeaa422..d8fcef453 100644 --- a/tests/opc/test_package.py +++ b/tests/opc/test_package.py @@ -1,5 +1,9 @@ +# pyright: reportPrivateUsage=false + """Unit test suite for docx.opc.package module""" +from __future__ import annotations + import pytest from docx.opc.constants import RELATIONSHIP_TYPE as RT @@ -12,8 +16,8 @@ from docx.opc.rel import Relationships, _Relationship from ..unitutil.mock import ( + FixtureRequest, Mock, - PropertyMock, call, class_mock, instance_mock, @@ -25,6 +29,8 @@ class DescribeOpcPackage: + """Unit-test suite for `docx.opc.package.OpcPackage` objects.""" + def it_can_open_a_pkg_file(self, PackageReader_, PartFactory_, Unmarshaller_): # mockery ---------------------- pkg_file = Mock(name="pkg_file") @@ -42,19 +48,26 @@ def it_initializes_its_rels_collection_on_first_reference(self, Relationships_): 
Relationships_.assert_called_once_with(PACKAGE_URI.baseURI) assert rels == Relationships_.return_value - def it_can_add_a_relationship_to_a_part(self, pkg_with_rels_, rel_attrs_): - reltype, target, rId = rel_attrs_ - pkg = pkg_with_rels_ - # exercise --------------------- - pkg.load_rel(reltype, target, rId) - # verify ----------------------- - pkg._rels.add_relationship.assert_called_once_with(reltype, target, rId, False) + def it_can_add_a_relationship_to_a_part(self, rels_prop_: Mock, rels_: Mock, part_: Mock): + rels_prop_.return_value = rels_ + pkg = OpcPackage() + + pkg.load_rel("http://rel/type", part_, "rId99") - def it_can_establish_a_relationship_to_another_part(self, relate_to_part_fixture_): - pkg, part_, reltype, rId = relate_to_part_fixture_ - _rId = pkg.relate_to(part_, reltype) - pkg.rels.get_or_add.assert_called_once_with(reltype, part_) - assert _rId == rId + rels_.add_relationship.assert_called_once_with("http://rel/type", part_, "rId99", False) + + def it_can_establish_a_relationship_to_another_part( + self, rels_prop_: Mock, rels_: Mock, rel_: Mock, part_: Mock + ): + rel_.rId = "rId99" + rels_.get_or_add.return_value = rel_ + rels_prop_.return_value = rels_ + pkg = OpcPackage() + + rId = pkg.relate_to(part_, "http://rel/type") + + rels_.get_or_add.assert_called_once_with("http://rel/type", part_) + assert rId == "rId99" def it_can_provide_a_list_of_the_parts_it_contains(self): # mockery ---------------------- @@ -64,7 +77,7 @@ def it_can_provide_a_list_of_the_parts_it_contains(self): with patch.object(OpcPackage, "iter_parts", return_value=parts): assert pkg.parts == [parts[0], parts[1]] - def it_can_iterate_over_parts_by_walking_rels_graph(self): + def it_can_iterate_over_parts_by_walking_rels_graph(self, rels_prop_: Mock): # +----------+ +--------+ # | pkg_rels |-----> | part_1 | # +----------+ +--------+ @@ -77,7 +90,7 @@ def it_can_iterate_over_parts_by_walking_rels_graph(self): part1.rels = {1: Mock(name="rel1", is_external=False, 
target_part=part2)} part2.rels = {1: Mock(name="rel2", is_external=False, target_part=part1)} pkg = OpcPackage() - pkg._rels = { + rels_prop_.return_value = { 1: Mock(name="rel3", is_external=False, target_part=part1), 2: Mock(name="rel4", is_external=True), } @@ -106,21 +119,22 @@ def it_can_find_a_part_related_by_reltype(self, related_part_fixture_): pkg.rels.part_with_reltype.assert_called_once_with(reltype) assert related_part is related_part_ - def it_can_save_to_a_pkg_file(self, pkg_file_, PackageWriter_, parts, parts_): + def it_can_save_to_a_pkg_file( + self, pkg_file_: Mock, PackageWriter_: Mock, parts_prop_: Mock, parts_: list[Mock] + ): + parts_prop_.return_value = parts_ pkg = OpcPackage() pkg.save(pkg_file_) for part in parts_: part.before_marshal.assert_called_once_with() - PackageWriter_.write.assert_called_once_with(pkg_file_, pkg._rels, parts_) + PackageWriter_.write.assert_called_once_with(pkg_file_, pkg.rels, parts_) def it_provides_access_to_the_core_properties(self, core_props_fixture): opc_package, core_properties_ = core_props_fixture core_properties = opc_package.core_properties assert core_properties is core_properties_ - def it_provides_access_to_the_core_properties_part_to_help( - self, core_props_part_fixture - ): + def it_provides_access_to_the_core_properties_part_to_help(self, core_props_part_fixture): opc_package, core_properties_part_ = core_props_part_fixture core_properties_part = opc_package._core_properties_part assert core_properties_part is core_properties_part_ @@ -135,9 +149,7 @@ def it_creates_a_default_core_props_part_if_none_present( core_properties_part = opc_package._core_properties_part CorePropertiesPart_.default.assert_called_once_with(opc_package) - relate_to_.assert_called_once_with( - opc_package, core_properties_part_, RT.CORE_PROPERTIES - ) + relate_to_.assert_called_once_with(opc_package, core_properties_part_, RT.CORE_PROPERTIES) assert core_properties_part is core_properties_part_ # fixtures 
--------------------------------------------- @@ -161,134 +173,106 @@ def core_props_part_fixture(self, part_related_by_, core_properties_part_): def next_partname_fixture(self, request, iter_parts_): existing_partname_ns, next_partname_n = request.param parts_ = [ - instance_mock( - request, Part, name="part[%d]" % idx, partname="/foo/bar/baz%d.xml" % n - ) + instance_mock(request, Part, name="part[%d]" % idx, partname="/foo/bar/baz%d.xml" % n) for idx, n in enumerate(existing_partname_ns) ] expected_value = "/foo/bar/baz%d.xml" % next_partname_n return parts_, expected_value @pytest.fixture - def relate_to_part_fixture_(self, request, pkg, rels_, reltype): - rId = "rId99" - rel_ = instance_mock(request, _Relationship, name="rel_", rId=rId) - rels_.get_or_add.return_value = rel_ - pkg._rels = rels_ - part_ = instance_mock(request, Part, name="part_") - return pkg, part_, reltype, rId - - @pytest.fixture - def related_part_fixture_(self, request, rels_, reltype): + def related_part_fixture_(self, request: FixtureRequest, rels_prop_: Mock, rels_: Mock): related_part_ = instance_mock(request, Part, name="related_part_") rels_.part_with_reltype.return_value = related_part_ pkg = OpcPackage() - pkg._rels = rels_ - return pkg, reltype, related_part_ + rels_prop_.return_value = rels_ + return pkg, "http://rel/type", related_part_ # fixture components ----------------------------------- @pytest.fixture - def CorePropertiesPart_(self, request): + def CorePropertiesPart_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.CorePropertiesPart") @pytest.fixture - def core_properties_(self, request): + def core_properties_(self, request: FixtureRequest): return instance_mock(request, CoreProperties) @pytest.fixture - def core_properties_part_(self, request): + def core_properties_part_(self, request: FixtureRequest): return instance_mock(request, CorePropertiesPart) @pytest.fixture - def _core_properties_part_prop_(self, request): + def 
_core_properties_part_prop_(self, request: FixtureRequest): return property_mock(request, OpcPackage, "_core_properties_part") @pytest.fixture - def iter_parts_(self, request): + def iter_parts_(self, request: FixtureRequest): return method_mock(request, OpcPackage, "iter_parts") @pytest.fixture - def PackageReader_(self, request): + def PackageReader_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.PackageReader") @pytest.fixture - def PackURI_(self, request): + def PackURI_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.PackURI") @pytest.fixture - def packuri_(self, request): + def packuri_(self, request: FixtureRequest): return instance_mock(request, PackURI) @pytest.fixture - def PackageWriter_(self, request): + def PackageWriter_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.PackageWriter") @pytest.fixture - def PartFactory_(self, request): + def PartFactory_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.PartFactory") @pytest.fixture - def part_related_by_(self, request): - return method_mock(request, OpcPackage, "part_related_by") + def part_(self, request: FixtureRequest): + return instance_mock(request, Part) @pytest.fixture - def parts(self, parts_): - """ - Return a mock patching property OpcPackage.parts, reversing the - patch after each use. 
- """ - p = patch.object( - OpcPackage, "parts", new_callable=PropertyMock, return_value=parts_ - ) - yield p.start() - p.stop() + def part_related_by_(self, request: FixtureRequest): + return method_mock(request, OpcPackage, "part_related_by") @pytest.fixture - def parts_(self, request): + def parts_(self, request: FixtureRequest): part_ = instance_mock(request, Part, name="part_") part_2_ = instance_mock(request, Part, name="part_2_") return [part_, part_2_] @pytest.fixture - def pkg(self, request): - return OpcPackage() + def parts_prop_(self, request: FixtureRequest): + return property_mock(request, OpcPackage, "parts") @pytest.fixture - def pkg_file_(self, request): + def pkg_file_(self, request: FixtureRequest): return loose_mock(request) @pytest.fixture - def pkg_with_rels_(self, request, rels_): - pkg = OpcPackage() - pkg._rels = rels_ - return pkg - - @pytest.fixture - def Relationships_(self, request): + def Relationships_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.Relationships") @pytest.fixture - def rel_attrs_(self, request): - reltype = "http://rel/type" - target_ = instance_mock(request, Part, name="target_") - rId = "rId99" - return reltype, target_, rId + def rel_(self, request: FixtureRequest): + return instance_mock(request, _Relationship) @pytest.fixture - def relate_to_(self, request): + def relate_to_(self, request: FixtureRequest): return method_mock(request, OpcPackage, "relate_to") @pytest.fixture - def rels_(self, request): + def rels_(self, request: FixtureRequest): return instance_mock(request, Relationships) @pytest.fixture - def reltype(self, request): - return "http://rel/type" + def rels_prop_(self, request: FixtureRequest): + return property_mock(request, OpcPackage, "rels") @pytest.fixture - def Unmarshaller_(self, request): + def Unmarshaller_(self, request: FixtureRequest): return class_mock(request, "docx.opc.package.Unmarshaller") @@ -306,9 +290,7 @@ def it_can_unmarshal_from_a_pkg_reader( 
Unmarshaller.unmarshal(pkg_reader_, pkg_, part_factory_) _unmarshal_parts_.assert_called_once_with(pkg_reader_, pkg_, part_factory_) - _unmarshal_relationships_.assert_called_once_with( - pkg_reader_, pkg_, parts_dict_ - ) + _unmarshal_relationships_.assert_called_once_with(pkg_reader_, pkg_, parts_dict_) for part in parts_dict_.values(): part.after_unmarshal.assert_called_once_with() pkg_.after_unmarshal.assert_called_once_with() @@ -406,13 +388,13 @@ def it_can_unmarshal_relationships(self): # fixtures --------------------------------------------- @pytest.fixture - def blobs_(self, request): + def blobs_(self, request: FixtureRequest): blob_ = loose_mock(request, spec=str, name="blob_") blob_2_ = loose_mock(request, spec=str, name="blob_2_") return blob_, blob_2_ @pytest.fixture - def content_types_(self, request): + def content_types_(self, request: FixtureRequest): content_type_ = loose_mock(request, spec=str, name="content_type_") content_type_2_ = loose_mock(request, spec=str, name="content_type_2_") return content_type_, content_type_2_ @@ -424,13 +406,13 @@ def part_factory_(self, request, parts_): return part_factory_ @pytest.fixture - def partnames_(self, request): + def partnames_(self, request: FixtureRequest): partname_ = loose_mock(request, spec=str, name="partname_") partname_2_ = loose_mock(request, spec=str, name="partname_2_") return partname_, partname_2_ @pytest.fixture - def parts_(self, request): + def parts_(self, request: FixtureRequest): part_ = instance_mock(request, Part, name="part_") part_2_ = instance_mock(request, Part, name="part_2") return part_, part_2_ @@ -442,7 +424,7 @@ def parts_dict_(self, request, partnames_, parts_): return {partname_: part_, partname_2_: part_2_} @pytest.fixture - def pkg_(self, request): + def pkg_(self, request: FixtureRequest): return instance_mock(request, OpcPackage) @pytest.fixture @@ -460,17 +442,15 @@ def pkg_reader_(self, request, partnames_, content_types_, reltypes_, blobs_): return pkg_reader_ 
@pytest.fixture - def reltypes_(self, request): + def reltypes_(self, request: FixtureRequest): reltype_ = instance_mock(request, str, name="reltype_") reltype_2_ = instance_mock(request, str, name="reltype_2") return reltype_, reltype_2_ @pytest.fixture - def _unmarshal_parts_(self, request): + def _unmarshal_parts_(self, request: FixtureRequest): return method_mock(request, Unmarshaller, "_unmarshal_parts", autospec=False) @pytest.fixture - def _unmarshal_relationships_(self, request): - return method_mock( - request, Unmarshaller, "_unmarshal_relationships", autospec=False - ) + def _unmarshal_relationships_(self, request: FixtureRequest): + return method_mock(request, Unmarshaller, "_unmarshal_relationships", autospec=False) diff --git a/tests/opc/test_part.py b/tests/opc/test_part.py index 163912154..dbbcaf262 100644 --- a/tests/opc/test_part.py +++ b/tests/opc/test_part.py @@ -1,5 +1,9 @@ +# pyright: reportPrivateUsage=false + """Unit test suite for docx.opc.part module""" +from __future__ import annotations + import pytest from docx.opc.package import OpcPackage @@ -11,6 +15,7 @@ from ..unitutil.cxml import element from ..unitutil.mock import ( ANY, + FixtureRequest, Mock, class_mock, cls_attr_mock, @@ -18,249 +23,170 @@ initializer_mock, instance_mock, loose_mock, + property_mock, ) class DescribePart: - def it_can_be_constructed_by_PartFactory( - self, partname_, content_type_, blob_, package_, __init_ - ): - part = Part.load(partname_, content_type_, blob_, package_) + """Unit-test suite for `docx.opc.part.Part` objects.""" + + def it_can_be_constructed_by_PartFactory(self, package_: Mock, init__: Mock): + part = Part.load(PackURI("/part/name"), "content/type", b"1be2", package_) - __init_.assert_called_once_with(ANY, partname_, content_type_, blob_, package_) + init__.assert_called_once_with(ANY, "/part/name", "content/type", b"1be2", package_) assert isinstance(part, Part) - def it_knows_its_partname(self, partname_get_fixture): - part, 
expected_partname = partname_get_fixture - assert part.partname == expected_partname + def it_knows_its_partname(self): + part = Part(PackURI("/part/name"), "content/type") + assert part.partname == "/part/name" - def it_can_change_its_partname(self, partname_set_fixture): - part, new_partname = partname_set_fixture - part.partname = new_partname - assert part.partname == new_partname + def it_can_change_its_partname(self): + part = Part(PackURI("/old/part/name"), "content/type") + part.partname = PackURI("/new/part/name") + assert part.partname == "/new/part/name" - def it_knows_its_content_type(self, content_type_fixture): - part, expected_content_type = content_type_fixture - assert part.content_type == expected_content_type + def it_knows_its_content_type(self): + part = Part(PackURI("/part/name"), "content/type") + assert part.content_type == "content/type" - def it_knows_the_package_it_belongs_to(self, package_get_fixture): - part, expected_package = package_get_fixture - assert part.package == expected_package + def it_knows_the_package_it_belongs_to(self, package_: Mock): + part = Part(PackURI("/part/name"), "content/type", package=package_) + assert part.package is package_ - def it_can_be_notified_after_unmarshalling_is_complete(self, part): + def it_can_be_notified_after_unmarshalling_is_complete(self): + part = Part(PackURI("/part/name"), "content/type") part.after_unmarshal() - def it_can_be_notified_before_marshalling_is_started(self, part): + def it_can_be_notified_before_marshalling_is_started(self): + part = Part(PackURI("/part/name"), "content/type") part.before_marshal() - def it_uses_the_load_blob_as_its_blob(self, blob_fixture): - part, load_blob = blob_fixture - assert part.blob is load_blob + def it_uses_the_load_blob_as_its_blob(self): + blob = b"abcde" + part = Part(PackURI("/part/name"), "content/type", blob) + assert part.blob is blob # fixtures --------------------------------------------- @pytest.fixture - def blob_fixture(self, blob_): 
- part = Part(None, None, blob_, None) - return part, blob_ + def init__(self, request: FixtureRequest): + return initializer_mock(request, Part) @pytest.fixture - def content_type_fixture(self): - content_type = "content/type" - part = Part(None, content_type, None, None) - return part, content_type + def package_(self, request: FixtureRequest): + return instance_mock(request, OpcPackage) - @pytest.fixture - def package_get_fixture(self, package_): - part = Part(None, None, None, package_) - return part, package_ - @pytest.fixture - def part(self): - part = Part(None, None) - return part +class DescribePartRelationshipManagementInterface: + """Unit-test suite for `docx.opc.package.Part` relationship behaviors.""" - @pytest.fixture - def partname_get_fixture(self): - partname = PackURI("/part/name") - part = Part(partname, None, None, None) - return part, partname + def it_provides_access_to_its_relationships( + self, Relationships_: Mock, partname_: Mock, rels_: Mock + ): + Relationships_.return_value = rels_ + part = Part(partname_, "content_type") - @pytest.fixture - def partname_set_fixture(self): - old_partname = PackURI("/old/part/name") - new_partname = PackURI("/new/part/name") - part = Part(old_partname, None, None, None) - return part, new_partname + rels = part.rels - # fixture components --------------------------------------------- + Relationships_.assert_called_once_with(partname_.baseURI) + assert rels is rels_ - @pytest.fixture - def blob_(self, request): - return instance_mock(request, bytes) + def it_can_load_a_relationship(self, rels_prop_: Mock, rels_: Mock, other_part_: Mock): + rels_prop_.return_value = rels_ + part = Part("partname", "content_type") - @pytest.fixture - def content_type_(self, request): - return instance_mock(request, str) + part.load_rel("http://rel/type", other_part_, "rId42") - @pytest.fixture - def __init_(self, request): - return initializer_mock(request, Part) + rels_.add_relationship.assert_called_once_with( + 
"http://rel/type", other_part_, "rId42", False + ) - @pytest.fixture - def package_(self, request): - return instance_mock(request, OpcPackage) + def it_can_establish_a_relationship_to_another_part( + self, rels_prop_: Mock, rels_: Mock, rel_: Mock, other_part_: Mock + ): + rels_prop_.return_value = rels_ + rels_.get_or_add.return_value = rel_ + rel_.rId = "rId18" + part = Part("partname", "content_type") - @pytest.fixture - def partname_(self, request): - return instance_mock(request, PackURI) + rId = part.relate_to(other_part_, "http://rel/type") + rels_.get_or_add.assert_called_once_with("http://rel/type", other_part_) + assert rId == "rId18" -class DescribePartRelationshipManagementInterface: - def it_provides_access_to_its_relationships(self, rels_fixture): - part, Relationships_, partname_, rels_ = rels_fixture - rels = part.rels - Relationships_.assert_called_once_with(partname_.baseURI) - assert rels is rels_ + def it_can_establish_an_external_relationship(self, rels_prop_: Mock, rels_: Mock): + rels_prop_.return_value = rels_ + rels_.get_or_add_ext_rel.return_value = "rId27" + part = Part("partname", "content_type") - def it_can_load_a_relationship(self, load_rel_fixture): - part, rels_, reltype_, target_, rId_ = load_rel_fixture - part.load_rel(reltype_, target_, rId_) - rels_.add_relationship.assert_called_once_with(reltype_, target_, rId_, False) - - def it_can_establish_a_relationship_to_another_part(self, relate_to_part_fixture): - part, target_, reltype_, rId_ = relate_to_part_fixture - rId = part.relate_to(target_, reltype_) - part.rels.get_or_add.assert_called_once_with(reltype_, target_) - assert rId is rId_ - - def it_can_establish_an_external_relationship(self, relate_to_url_fixture): - part, url_, reltype_, rId_ = relate_to_url_fixture - rId = part.relate_to(url_, reltype_, is_external=True) - part.rels.get_or_add_ext_rel.assert_called_once_with(reltype_, url_) - assert rId is rId_ - - def it_can_drop_a_relationship(self, drop_rel_fixture): - 
part, rId, rel_should_be_gone = drop_rel_fixture - part.drop_rel(rId) - if rel_should_be_gone: - assert rId not in part.rels - else: - assert rId in part.rels - - def it_can_find_a_related_part_by_reltype(self, related_part_fixture): - part, reltype_, related_part_ = related_part_fixture - related_part = part.part_related_by(reltype_) - part.rels.part_with_reltype.assert_called_once_with(reltype_) - assert related_part is related_part_ - - def it_can_find_a_related_part_by_rId(self, related_parts_fixture): - part, related_parts_ = related_parts_fixture - assert part.related_parts is related_parts_ - - def it_can_find_the_uri_of_an_external_relationship(self, target_ref_fixture): - part, rId_, url_ = target_ref_fixture - url = part.target_ref(rId_) - assert url == url_ + rId = part.relate_to("https://hyper/link", "http://rel/type", is_external=True) - # fixtures --------------------------------------------- + rels_.get_or_add_ext_rel.assert_called_once_with("http://rel/type", "https://hyper/link") + assert rId == "rId27" - @pytest.fixture( - params=[ - ("w:p", True), - ("w:p/r:a{r:id=rId42}", True), - ("w:p/r:a{r:id=rId42}/r:b{r:id=rId42}", False), - ] - ) - def drop_rel_fixture(self, request, part): - part_cxml, rel_should_be_dropped = request.param - rId = "rId42" - part._element = element(part_cxml) - part._rels = {rId: None} - return part, rId, rel_should_be_dropped + def it_can_drop_a_relationship(self, rels_prop_: Mock): + rels_prop_.return_value = {"rId42": None} + part = Part(PackURI("/partname"), "content_type") - @pytest.fixture - def load_rel_fixture(self, part, rels_, reltype_, part_, rId_): - part._rels = rels_ - return part, rels_, reltype_, part_, rId_ + part.drop_rel("rId42") - @pytest.fixture - def relate_to_part_fixture(self, request, part, reltype_, part_, rels_, rId_): - part._rels = rels_ - target_ = part_ - return part, target_, reltype_, rId_ + assert "rId42" not in part.rels - @pytest.fixture - def relate_to_url_fixture(self, request, part, 
rels_, url_, reltype_, rId_): - part._rels = rels_ - return part, url_, reltype_, rId_ + def it_can_find_a_related_part_by_reltype( + self, rels_prop_: Mock, rels_: Mock, other_part_: Mock + ): + rels_prop_.return_value = rels_ + rels_.part_with_reltype.return_value = other_part_ + part = Part("partname", "content_type") - @pytest.fixture - def related_part_fixture(self, request, part, rels_, reltype_, part_): - part._rels = rels_ - return part, reltype_, part_ + related_part = part.part_related_by("http://rel/type") - @pytest.fixture - def related_parts_fixture(self, request, part, rels_, related_parts_): - part._rels = rels_ - return part, related_parts_ + rels_.part_with_reltype.assert_called_once_with("http://rel/type") + assert related_part is other_part_ - @pytest.fixture - def rels_fixture(self, Relationships_, partname_, rels_): - part = Part(partname_, None) - return part, Relationships_, partname_, rels_ + def it_can_find_a_related_part_by_rId(self, rels_prop_: Mock, rels_: Mock, other_part_: Mock): + rels_prop_.return_value = rels_ + rels_.related_parts = {"rId24": other_part_} + part = Part("partname", "content_type") - @pytest.fixture - def target_ref_fixture(self, request, part, rId_, rel_, url_): - part._rels = {rId_: rel_} - return part, rId_, url_ + assert part.related_parts["rId24"] is other_part_ - # fixture components --------------------------------------------- - - @pytest.fixture - def part(self): - return Part(None, None) + def it_can_find_the_uri_of_an_external_relationship( + self, rels_prop_: Mock, rel_: Mock, other_part_: Mock + ): + rels_prop_.return_value = {"rId7": rel_} + rel_.target_ref = "https://hyper/link" + part = Part("partname", "content_type") - @pytest.fixture - def part_(self, request): - return instance_mock(request, Part) + url = part.target_ref("rId7") - @pytest.fixture - def partname_(self, request): - return instance_mock(request, PackURI) + assert url == "https://hyper/link" - @pytest.fixture - def 
Relationships_(self, request, rels_): - return class_mock(request, "docx.opc.part.Relationships", return_value=rels_) + # fixtures --------------------------------------------- @pytest.fixture - def rel_(self, request, rId_, url_): - return instance_mock(request, _Relationship, rId=rId_, target_ref=url_) + def other_part_(self, request: FixtureRequest): + return instance_mock(request, Part) @pytest.fixture - def rels_(self, request, part_, rel_, rId_, related_parts_): - rels_ = instance_mock(request, Relationships) - rels_.part_with_reltype.return_value = part_ - rels_.get_or_add.return_value = rel_ - rels_.get_or_add_ext_rel.return_value = rId_ - rels_.related_parts = related_parts_ - return rels_ + def partname_(self, request: FixtureRequest): + return instance_mock(request, PackURI) @pytest.fixture - def related_parts_(self, request): - return instance_mock(request, dict) + def Relationships_(self, request: FixtureRequest): + return class_mock(request, "docx.opc.part.Relationships") @pytest.fixture - def reltype_(self, request): - return instance_mock(request, str) + def rel_(self, request: FixtureRequest): + return instance_mock(request, _Relationship) @pytest.fixture - def rId_(self, request): - return instance_mock(request, str) + def rels_(self, request: FixtureRequest): + return instance_mock(request, Relationships) @pytest.fixture - def url_(self, request): - return instance_mock(request, str) + def rels_prop_(self, request: FixtureRequest): + return property_mock(request, Part, "rels") class DescribePartFactory: @@ -278,9 +204,7 @@ def it_constructs_part_from_selector_if_defined(self, cls_selector_fixture): part = PartFactory(partname, content_type, reltype, blob, package) # verify ----------------------- cls_selector_fn_.assert_called_once_with(content_type, reltype) - CustomPartClass_.load.assert_called_once_with( - partname, content_type, blob, package - ) + CustomPartClass_.load.assert_called_once_with(partname, content_type, blob, package) assert 
part is part_of_custom_type_ def it_constructs_custom_part_type_for_registered_content_types( @@ -292,9 +216,7 @@ def it_constructs_custom_part_type_for_registered_content_types( PartFactory.part_type_for[content_type] = CustomPartClass_ part = PartFactory(partname, content_type, reltype, blob, package) # verify ----------------------- - CustomPartClass_.load.assert_called_once_with( - partname, content_type, blob, package - ) + CustomPartClass_.load.assert_called_once_with(partname, content_type, blob, package) assert part is part_of_custom_type_ def it_constructs_part_using_default_class_when_no_custom_registered( @@ -302,9 +224,7 @@ def it_constructs_part_using_default_class_when_no_custom_registered( ): partname, content_type, reltype, blob, package = part_args_2_ part = PartFactory(partname, content_type, reltype, blob, package) - DefaultPartClass_.load.assert_called_once_with( - partname, content_type, blob, package - ) + DefaultPartClass_.load.assert_called_once_with(partname, content_type, blob, package) assert part is part_of_default_type_ # fixtures --------------------------------------------- @@ -319,9 +239,7 @@ def blob_2_(self, request): @pytest.fixture def cls_method_fn_(self, request, cls_selector_fn_): - return function_mock( - request, "docx.opc.part.cls_method_fn", return_value=cls_selector_fn_ - ) + return function_mock(request, "docx.opc.part.cls_method_fn", return_value=cls_selector_fn_) @pytest.fixture def cls_selector_fixture( @@ -405,9 +323,7 @@ def part_args_(self, request, partname_, content_type_, reltype_, package_, blob return partname_, content_type_, reltype_, blob_, package_ @pytest.fixture - def part_args_2_( - self, request, partname_2_, content_type_2_, reltype_2_, package_2_, blob_2_ - ): + def part_args_2_(self, request, partname_2_, content_type_2_, reltype_2_, package_2_, blob_2_): return partname_2_, content_type_2_, reltype_2_, blob_2_, package_2_ @pytest.fixture @@ -426,9 +342,7 @@ def it_can_be_constructed_by_PartFactory( 
part = XmlPart.load(partname_, content_type_, blob_, package_) parse_xml_.assert_called_once_with(blob_) - __init_.assert_called_once_with( - ANY, partname_, content_type_, element_, package_ - ) + __init_.assert_called_once_with(ANY, partname_, content_type_, element_, package_) assert isinstance(part, XmlPart) def it_can_serialize_to_xml(self, blob_fixture): @@ -441,6 +355,24 @@ def it_knows_its_the_part_for_its_child_objects(self, part_fixture): xml_part = part_fixture assert xml_part.part is xml_part + @pytest.mark.parametrize( + ("part_cxml", "rel_should_be_dropped"), + [ + ("w:p", True), + ("w:p/r:a{r:id=rId42}", True), + ("w:p/r:a{r:id=rId42}/r:b{r:id=rId42}", False), + ], + ) + def it_only_drops_a_relationship_with_zero_reference_count( + self, part_cxml: str, rel_should_be_dropped: bool, rels_prop_: Mock, package_: Mock + ): + rels_prop_.return_value = {"rId42": None} + part = XmlPart(PackURI("/partname"), "content_type", element(part_cxml), package_) + + part.drop_rel("rId42") + + assert ("rId42" not in part.rels) is rel_should_be_dropped + # fixtures ------------------------------------------------------- @pytest.fixture @@ -482,6 +414,10 @@ def parse_xml_(self, request, element_): def partname_(self, request): return instance_mock(request, PackURI) + @pytest.fixture + def rels_prop_(self, request: FixtureRequest): + return property_mock(request, XmlPart, "rels") + @pytest.fixture def serialize_part_xml_(self, request): return function_mock(request, "docx.opc.part.serialize_part_xml") diff --git a/tests/opc/test_pkgwriter.py b/tests/opc/test_pkgwriter.py index 747300f82..aff8b22d9 100644 --- a/tests/opc/test_pkgwriter.py +++ b/tests/opc/test_pkgwriter.py @@ -1,5 +1,9 @@ +# pyright: reportPrivateUsage=false + """Test suite for opc.pkgwriter module.""" +from __future__ import annotations + import pytest from docx.opc.constants import CONTENT_TYPE as CT @@ -7,9 +11,10 @@ from docx.opc.part import Part from docx.opc.phys_pkg import _ZipPkgWriter from 
docx.opc.pkgwriter import PackageWriter, _ContentTypesItem +from docx.opc.rel import Relationships from ..unitutil.mock import ( - MagicMock, + FixtureRequest, Mock, call, class_mock, @@ -54,41 +59,48 @@ def it_can_write_a_pkg_rels_item(self): # verify ----------------------- phys_writer.write.assert_called_once_with("/_rels/.rels", pkg_rels.xml) - def it_can_write_a_list_of_parts(self): - # mockery ---------------------- - phys_writer = Mock(name="phys_writer") - rels = MagicMock(name="rels") - rels.__len__.return_value = 1 - part1 = Mock(name="part1", _rels=rels) - part2 = Mock(name="part2", _rels=[]) - # exercise --------------------- - PackageWriter._write_parts(phys_writer, [part1, part2]) - # verify ----------------------- + def it_can_write_a_list_of_parts( + self, phys_pkg_writer_: Mock, part_: Mock, part_2_: Mock, rels_: Mock + ): + rels_.__len__.return_value = 1 + part_.rels = rels_ + part_2_.rels = [] + + PackageWriter._write_parts(phys_pkg_writer_, [part_, part_2_]) + expected_calls = [ - call(part1.partname, part1.blob), - call(part1.partname.rels_uri, part1._rels.xml), - call(part2.partname, part2.blob), + call(part_.partname, part_.blob), + call(part_.partname.rels_uri, part_.rels.xml), + call(part_2_.partname, part_2_.blob), ] - assert phys_writer.write.mock_calls == expected_calls + assert phys_pkg_writer_.write.mock_calls == expected_calls # fixtures --------------------------------------------- @pytest.fixture - def blob_(self, request): + def blob_(self, request: FixtureRequest): return instance_mock(request, str) @pytest.fixture - def cti_(self, request, blob_): + def cti_(self, request: FixtureRequest, blob_): return instance_mock(request, _ContentTypesItem, blob=blob_) @pytest.fixture - def _ContentTypesItem_(self, request, cti_): + def _ContentTypesItem_(self, request: FixtureRequest, cti_): _ContentTypesItem_ = class_mock(request, "docx.opc.pkgwriter._ContentTypesItem") _ContentTypesItem_.from_parts.return_value = cti_ return 
_ContentTypesItem_ @pytest.fixture - def parts_(self, request): + def part_(self, request: FixtureRequest): + return instance_mock(request, Part) + + @pytest.fixture + def part_2_(self, request: FixtureRequest): + return instance_mock(request, Part) + + @pytest.fixture + def parts_(self, request: FixtureRequest): return instance_mock(request, list) @pytest.fixture @@ -98,9 +110,13 @@ def PhysPkgWriter_(self): p.stop() @pytest.fixture - def phys_pkg_writer_(self, request): + def phys_pkg_writer_(self, request: FixtureRequest): return instance_mock(request, _ZipPkgWriter) + @pytest.fixture + def rels_(self, request: FixtureRequest): + return instance_mock(request, Relationships) + @pytest.fixture def write_cti_fixture(self, _ContentTypesItem_, parts_, phys_pkg_writer_, blob_): return _ContentTypesItem_, parts_, phys_pkg_writer_, blob_ @@ -123,7 +139,7 @@ def _write_methods(self): patch3.stop() @pytest.fixture - def xml_for_(self, request): + def xml_for_(self, request: FixtureRequest): return method_mock(request, _ContentTypesItem, "xml_for") @@ -135,11 +151,9 @@ def it_can_compose_content_types_element(self, xml_for_fixture): # fixtures --------------------------------------------- - def _mock_part(self, request, name, partname_str, content_type): + def _mock_part(self, request: FixtureRequest, name, partname_str, content_type): partname = PackURI(partname_str) - return instance_mock( - request, Part, name=name, partname=partname, content_type=content_type - ) + return instance_mock(request, Part, name=name, partname=partname, content_type=content_type) @pytest.fixture( params=[ @@ -152,7 +166,7 @@ def _mock_part(self, request, name, partname_str, content_type): ("Override", "/zebra/foo.bar", "app/vnd.foobar"), ] ) - def xml_for_fixture(self, request): + def xml_for_fixture(self, request: FixtureRequest): elm_type, partname_str, content_type = request.param part_ = self._mock_part(request, "part_", partname_str, content_type) cti = 
_ContentTypesItem.from_parts([part_]) @@ -168,9 +182,7 @@ def xml_for_fixture(self, request): types_bldr.with_child( a_Default().with_Extension("rels").with_ContentType(CT.OPC_RELATIONSHIPS) ) - types_bldr.with_child( - a_Default().with_Extension("xml").with_ContentType(CT.XML) - ) + types_bldr.with_child(a_Default().with_Extension("xml").with_ContentType(CT.XML)) if elm_type == "Override": override_bldr = an_Override() diff --git a/tests/oxml/test_table.py b/tests/oxml/test_table.py index 395c812a6..46b2f4ed1 100644 --- a/tests/oxml/test_table.py +++ b/tests/oxml/test_table.py @@ -1,3 +1,5 @@ +# pyright: reportPrivateUsage=false + """Test suite for the docx.oxml.text module.""" from __future__ import annotations @@ -13,174 +15,71 @@ from ..unitutil.cxml import element, xml from ..unitutil.file import snippet_seq -from ..unitutil.mock import call, instance_mock, method_mock, property_mock +from ..unitutil.mock import FixtureRequest, Mock, call, instance_mock, method_mock, property_mock class DescribeCT_Row: - def it_can_add_a_trPr(self, add_trPr_fixture): - tr, expected_xml = add_trPr_fixture - tr._add_trPr() - assert tr.xml == expected_xml - def it_raises_on_tc_at_grid_col(self, tc_raise_fixture): - tr, idx = tc_raise_fixture - with pytest.raises(ValueError): # noqa: PT011 - tr.tc_at_grid_col(idx) - - # fixtures ------------------------------------------------------- - - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("tr_cxml", "expected_cxml"), + [ ("w:tr", "w:tr/w:trPr"), ("w:tr/w:tblPrEx", "w:tr/(w:tblPrEx,w:trPr)"), ("w:tr/w:tc", "w:tr/(w:trPr,w:tc)"), ("w:tr/(w:sdt,w:del,w:tc)", "w:tr/(w:trPr,w:sdt,w:del,w:tc)"), - ] + ], ) - def add_trPr_fixture(self, request): - tr_cxml, expected_cxml = request.param - tr = element(tr_cxml) - expected_xml = xml(expected_cxml) - return tr, expected_xml + def it_can_add_a_trPr(self, tr_cxml: str, expected_cxml: str): + tr = cast(CT_Row, element(tr_cxml)) + tr._add_trPr() + assert tr.xml == xml(expected_cxml) - 
@pytest.fixture(params=[(0, 0, 3), (1, 0, 1)]) - def tc_raise_fixture(self, request): - snippet_idx, row_idx, col_idx = request.param - tbl = parse_xml(snippet_seq("tbl-cells")[snippet_idx]) - tr = tbl.tr_lst[row_idx] - return tr, col_idx + @pytest.mark.parametrize(("snippet_idx", "row_idx", "col_idx"), [(0, 0, 3), (1, 0, 1)]) + def it_raises_on_tc_at_grid_col(self, snippet_idx: int, row_idx: int, col_idx: int): + tr = cast(CT_Tbl, parse_xml(snippet_seq("tbl-cells")[snippet_idx])).tr_lst[row_idx] + with pytest.raises(ValueError, match=f"no `tc` element at grid_offset={col_idx}"): + tr.tc_at_grid_offset(col_idx) class DescribeCT_Tc: + """Unit-test suite for `docx.oxml.table.CT_Tc` objects.""" + + @pytest.mark.parametrize( + ("tr_cxml", "tc_idx", "expected_value"), + [ + ("w:tr/(w:tc/w:p,w:tc/w:p)", 0, 0), + ("w:tr/(w:tc/w:p,w:tc/w:p)", 1, 1), + ("w:tr/(w:trPr/w:gridBefore{w:val=2},w:tc/w:p,w:tc/w:p)", 0, 2), + ("w:tr/(w:trPr/w:gridBefore{w:val=2},w:tc/w:p,w:tc/w:p)", 1, 3), + ("w:tr/(w:trPr/w:gridBefore{w:val=4},w:tc/w:p,w:tc/w:p,w:tc/w:p,w:tc/w:p)", 2, 6), + ], + ) + def it_knows_its_grid_offset(self, tr_cxml: str, tc_idx: int, expected_value: int): + tr = cast(CT_Row, element(tr_cxml)) + tc = tr.tc_lst[tc_idx] + + assert tc.grid_offset == expected_value + def it_can_merge_to_another_tc( - self, tr_, _span_dimensions_, _tbl_, _grow_to_, top_tc_ + self, tr_: Mock, _span_dimensions_: Mock, _tbl_: Mock, _grow_to_: Mock, top_tc_: Mock ): top_tr_ = tr_ - tc, other_tc = element("w:tc"), element("w:tc") + tc, other_tc = cast(CT_Tc, element("w:tc")), cast(CT_Tc, element("w:tc")) top, left, height, width = 0, 1, 2, 3 _span_dimensions_.return_value = top, left, height, width _tbl_.return_value.tr_lst = [tr_] - tr_.tc_at_grid_col.return_value = top_tc_ + tr_.tc_at_grid_offset.return_value = top_tc_ merged_tc = tc.merge(other_tc) _span_dimensions_.assert_called_once_with(tc, other_tc) - top_tr_.tc_at_grid_col.assert_called_once_with(left) + 
top_tr_.tc_at_grid_offset.assert_called_once_with(left) top_tc_._grow_to.assert_called_once_with(width, height) assert merged_tc is top_tc_ - def it_knows_its_extents_to_help(self, extents_fixture): - tc, attr_name, expected_value = extents_fixture - extent = getattr(tc, attr_name) - assert extent == expected_value - - def it_calculates_the_dimensions_of_a_span_to_help(self, span_fixture): - tc, other_tc, expected_dimensions = span_fixture - dimensions = tc._span_dimensions(other_tc) - assert dimensions == expected_dimensions - - def it_raises_on_invalid_span(self, span_raise_fixture): - tc, other_tc = span_raise_fixture - with pytest.raises(InvalidSpanError): - tc._span_dimensions(other_tc) - - def it_can_grow_itself_to_help_merge(self, grow_to_fixture): - tc, width, height, top_tc, expected_calls = grow_to_fixture - tc._grow_to(width, height, top_tc) - assert tc._span_to_width.call_args_list == expected_calls - - def it_can_extend_its_horz_span_to_help_merge( - self, top_tc_, grid_span_, _move_content_to_, _swallow_next_tc_ - ): - grid_span_.side_effect = [1, 3, 4] - grid_width, vMerge = 4, "continue" - tc = element("w:tc") - - tc._span_to_width(grid_width, top_tc_, vMerge) - - _move_content_to_.assert_called_once_with(tc, top_tc_) - assert _swallow_next_tc_.call_args_list == [ - call(tc, grid_width, top_tc_), - call(tc, grid_width, top_tc_), - ] - assert tc.vMerge == vMerge - - def it_knows_its_inner_content_block_item_elements(self): - tc = cast(CT_Tc, element("w:tc/(w:p,w:tbl,w:p)")) - assert [type(e) for e in tc.inner_content_elements] == [CT_P, CT_Tbl, CT_P] - - def it_can_swallow_the_next_tc_help_merge(self, swallow_fixture): - tc, grid_width, top_tc, tr, expected_xml = swallow_fixture - tc._swallow_next_tc(grid_width, top_tc) - assert tr.xml == expected_xml - - def it_adds_cell_widths_on_swallow(self, add_width_fixture): - tc, grid_width, top_tc, tr, expected_xml = add_width_fixture - tc._swallow_next_tc(grid_width, top_tc) - assert tr.xml == expected_xml 
- - def it_raises_on_invalid_swallow(self, swallow_raise_fixture): - tc, grid_width, top_tc, tr = swallow_raise_fixture - with pytest.raises(InvalidSpanError): - tc._swallow_next_tc(grid_width, top_tc) - - def it_can_move_its_content_to_help_merge(self, move_fixture): - tc, tc_2, expected_tc_xml, expected_tc_2_xml = move_fixture - tc._move_content_to(tc_2) - assert tc.xml == expected_tc_xml - assert tc_2.xml == expected_tc_2_xml - - def it_raises_on_tr_above(self, tr_above_raise_fixture): - tc = tr_above_raise_fixture - with pytest.raises(ValueError, match="no tr above topmost tr"): - tc._tr_above - - # fixtures ------------------------------------------------------- - - @pytest.fixture( - params=[ - # both cells have a width - ( - "w:tr/(w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p)," - "w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p))", - 0, - 2, - "w:tr/(w:tc/(w:tcPr/(w:tcW{w:w=2880,w:type=dxa}," - "w:gridSpan{w:val=2}),w:p))", - ), - # neither have a width - ( - "w:tr/(w:tc/w:p,w:tc/w:p)", - 0, - 2, - "w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p))", - ), - # only second one has a width - ( - "w:tr/(w:tc/w:p," "w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p))", - 0, - 2, - "w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p))", - ), - # only first one has a width - ( - "w:tr/(w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p)," "w:tc/w:p)", - 0, - 2, - "w:tr/(w:tc/(w:tcPr/(w:tcW{w:w=1440,w:type=dxa}," - "w:gridSpan{w:val=2}),w:p))", - ), - ] - ) - def add_width_fixture(self, request): - tr_cxml, tc_idx, grid_width, expected_tr_cxml = request.param - tr = element(tr_cxml) - tc = top_tc = tr[tc_idx] - expected_tr_xml = xml(expected_tr_cxml) - return tc, grid_width, top_tc, tr, expected_tr_xml - - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("snippet_idx", "row", "col", "attr_name", "expected_value"), + [ (0, 0, 0, "top", 0), (2, 0, 1, "top", 0), (2, 1, 1, "top", 0), @@ -195,63 +94,22 @@ def add_width_fixture(self, request): (4, 1, 1, "bottom", 3), (0, 0, 0, "right", 1), 
(1, 0, 0, "right", 2), - (0, 0, 0, "right", 1), (4, 2, 1, "right", 3), - ] + ], ) - def extents_fixture(self, request): - snippet_idx, row, col, attr_name, expected_value = request.param + def it_knows_its_extents_to_help( + self, snippet_idx: int, row: int, col: int, attr_name: str, expected_value: int + ): tbl = self._snippet_tbl(snippet_idx) tc = tbl.tr_lst[row].tc_lst[col] - return tc, attr_name, expected_value - @pytest.fixture( - params=[ - (0, 0, 0, 2, 1), - (0, 0, 1, 1, 2), - (0, 1, 1, 2, 2), - (1, 0, 0, 2, 2), - (2, 0, 0, 2, 2), - (2, 1, 2, 1, 2), - ] - ) - def grow_to_fixture(self, request, _span_to_width_): - snippet_idx, row, col, width, height = request.param - tbl = self._snippet_tbl(snippet_idx) - tc = tbl.tr_lst[row].tc_lst[col] - start = 0 if height == 1 else 1 - end = start + height - expected_calls = [ - call(width, tc, None), - call(width, tc, "restart"), - call(width, tc, "continue"), - call(width, tc, "continue"), - ][start:end] - return tc, width, height, None, expected_calls - - @pytest.fixture( - params=[ - ("w:tc/w:p", "w:tc/w:p", "w:tc/w:p", "w:tc/w:p"), - ("w:tc/w:p", "w:tc/w:p/w:r", "w:tc/w:p", "w:tc/w:p/w:r"), - ("w:tc/w:p/w:r", "w:tc/w:p", "w:tc/w:p", "w:tc/w:p/w:r"), - ("w:tc/(w:p/w:r,w:sdt)", "w:tc/w:p", "w:tc/w:p", "w:tc/(w:p/w:r,w:sdt)"), - ( - "w:tc/(w:p/w:r,w:sdt)", - "w:tc/(w:tbl,w:p)", - "w:tc/w:p", - "w:tc/(w:tbl,w:p/w:r,w:sdt)", - ), - ] - ) - def move_fixture(self, request): - tc_cxml, tc_2_cxml, expected_tc_cxml, expected_tc_2_cxml = request.param - tc, tc_2 = element(tc_cxml), element(tc_2_cxml) - expected_tc_xml = xml(expected_tc_cxml) - expected_tc_2_xml = xml(expected_tc_2_cxml) - return tc, tc_2, expected_tc_xml, expected_tc_2_xml - - @pytest.fixture( - params=[ + extent = getattr(tc, attr_name) + + assert extent == expected_value + + @pytest.mark.parametrize( + ("snippet_idx", "row", "col", "row_2", "col_2", "expected_value"), + [ (0, 0, 0, 0, 1, (0, 0, 1, 2)), (0, 0, 1, 2, 1, (0, 1, 3, 1)), (0, 2, 2, 1, 1, (1, 1, 2, 
2)), @@ -262,17 +120,28 @@ def move_fixture(self, request): (2, 0, 1, 1, 0, (0, 0, 2, 2)), (2, 1, 2, 0, 1, (0, 1, 2, 2)), (4, 0, 1, 0, 0, (0, 0, 1, 3)), - ] + ], ) - def span_fixture(self, request): - snippet_idx, row, col, row_2, col_2, expected_value = request.param + def it_calculates_the_dimensions_of_a_span_to_help( + self, + snippet_idx: int, + row: int, + col: int, + row_2: int, + col_2: int, + expected_value: tuple[int, int, int, int], + ): tbl = self._snippet_tbl(snippet_idx) tc = tbl.tr_lst[row].tc_lst[col] - tc_2 = tbl.tr_lst[row_2].tc_lst[col_2] - return tc, tc_2, expected_value + other_tc = tbl.tr_lst[row_2].tc_lst[col_2] + + dimensions = tc._span_dimensions(other_tc) - @pytest.fixture( - params=[ + assert dimensions == expected_value + + @pytest.mark.parametrize( + ("snippet_idx", "row", "col", "row_2", "col_2"), + [ (1, 0, 0, 1, 0), # inverted-L horz (1, 1, 0, 0, 0), # same in opposite order (2, 0, 2, 0, 1), # inverted-L vert @@ -280,17 +149,72 @@ def span_fixture(self, request): (5, 1, 0, 2, 1), # same, opposite side (6, 1, 0, 0, 1), # tee-shape vert bar (6, 0, 1, 1, 2), # same, opposite side - ] + ], ) - def span_raise_fixture(self, request): - snippet_idx, row, col, row_2, col_2 = request.param + def it_raises_on_invalid_span( + self, snippet_idx: int, row: int, col: int, row_2: int, col_2: int + ): tbl = self._snippet_tbl(snippet_idx) tc = tbl.tr_lst[row].tc_lst[col] - tc_2 = tbl.tr_lst[row_2].tc_lst[col_2] - return tc, tc_2 + other_tc = tbl.tr_lst[row_2].tc_lst[col_2] + + with pytest.raises(InvalidSpanError): + tc._span_dimensions(other_tc) + + @pytest.mark.parametrize( + ("snippet_idx", "row", "col", "width", "height"), + [ + (0, 0, 0, 2, 1), + (0, 0, 1, 1, 2), + (0, 1, 1, 2, 2), + (1, 0, 0, 2, 2), + (2, 0, 0, 2, 2), + (2, 1, 2, 1, 2), + ], + ) + def it_can_grow_itself_to_help_merge( + self, snippet_idx: int, row: int, col: int, width: int, height: int, _span_to_width_: Mock + ): + tbl = self._snippet_tbl(snippet_idx) + tc = 
tbl.tr_lst[row].tc_lst[col] + start = 0 if height == 1 else 1 + end = start + height + + tc._grow_to(width, height, None) - @pytest.fixture( - params=[ + assert ( + _span_to_width_.call_args_list + == [ + call(width, tc, None), + call(width, tc, "restart"), + call(width, tc, "continue"), + call(width, tc, "continue"), + ][start:end] + ) + + def it_can_extend_its_horz_span_to_help_merge( + self, top_tc_: Mock, grid_span_: Mock, _move_content_to_: Mock, _swallow_next_tc_: Mock + ): + grid_span_.side_effect = [1, 3, 4] + grid_width, vMerge = 4, "continue" + tc = cast(CT_Tc, element("w:tc")) + + tc._span_to_width(grid_width, top_tc_, vMerge) + + _move_content_to_.assert_called_once_with(tc, top_tc_) + assert _swallow_next_tc_.call_args_list == [ + call(tc, grid_width, top_tc_), + call(tc, grid_width, top_tc_), + ] + assert tc.vMerge == vMerge + + def it_knows_its_inner_content_block_item_elements(self): + tc = cast(CT_Tc, element("w:tc/(w:p,w:tbl,w:p)")) + assert [type(e) for e in tc.inner_content_elements] == [CT_P, CT_Tbl, CT_P] + + @pytest.mark.parametrize( + ("tr_cxml", "tc_idx", "grid_width", "expected_cxml"), + [ ( "w:tr/(w:tc/w:p,w:tc/w:p)", 0, @@ -307,8 +231,7 @@ def span_raise_fixture(self, request): 'w:tr/(w:tc/w:p/w:r/w:t"a",w:tc/w:p/w:r/w:t"b")', 0, 2, - 'w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p/w:r/w:t"a",' - 'w:p/w:r/w:t"b"))', + 'w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p/w:r/w:t"a",' 'w:p/w:r/w:t"b"))', ), ( "w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p),w:tc/w:p)", @@ -322,75 +245,145 @@ def span_raise_fixture(self, request): 3, "w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=3},w:p))", ), - ] + ], + ) + def it_can_swallow_the_next_tc_help_merge( + self, tr_cxml: str, tc_idx: int, grid_width: int, expected_cxml: str + ): + tr = cast(CT_Row, element(tr_cxml)) + tc = top_tc = tr.tc_lst[tc_idx] + + tc._swallow_next_tc(grid_width, top_tc) + + assert tr.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("tr_cxml", "tc_idx", "grid_width", "expected_cxml"), + [ 
+ # both cells have a width + ( + "w:tr/(w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p)," + "w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p))", + 0, + 2, + "w:tr/(w:tc/(w:tcPr/(w:tcW{w:w=2880,w:type=dxa}," "w:gridSpan{w:val=2}),w:p))", + ), + # neither have a width + ( + "w:tr/(w:tc/w:p,w:tc/w:p)", + 0, + 2, + "w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p))", + ), + # only second one has a width + ( + "w:tr/(w:tc/w:p," "w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p))", + 0, + 2, + "w:tr/(w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p))", + ), + # only first one has a width + ( + "w:tr/(w:tc/(w:tcPr/w:tcW{w:w=1440,w:type=dxa},w:p)," "w:tc/w:p)", + 0, + 2, + "w:tr/(w:tc/(w:tcPr/(w:tcW{w:w=1440,w:type=dxa}," "w:gridSpan{w:val=2}),w:p))", + ), + ], ) - def swallow_fixture(self, request): - tr_cxml, tc_idx, grid_width, expected_tr_cxml = request.param - tr = element(tr_cxml) - tc = top_tc = tr[tc_idx] - expected_tr_xml = xml(expected_tr_cxml) - return tc, grid_width, top_tc, tr, expected_tr_xml - - @pytest.fixture( - params=[ + def it_adds_cell_widths_on_swallow( + self, tr_cxml: str, tc_idx: int, grid_width: int, expected_cxml: str + ): + tr = cast(CT_Row, element(tr_cxml)) + tc = top_tc = tr.tc_lst[tc_idx] + tc._swallow_next_tc(grid_width, top_tc) + assert tr.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("tr_cxml", "tc_idx", "grid_width"), + [ ("w:tr/w:tc/w:p", 0, 2), ("w:tr/(w:tc/w:p,w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p))", 0, 2), - ] + ], ) - def swallow_raise_fixture(self, request): - tr_cxml, tc_idx, grid_width = request.param - tr = element(tr_cxml) - tc = top_tc = tr[tc_idx] - return tc, grid_width, top_tc, tr - - @pytest.fixture(params=[(0, 0, 0), (4, 0, 0)]) - def tr_above_raise_fixture(self, request): - snippet_idx, row_idx, col_idx = request.param - tbl = parse_xml(snippet_seq("tbl-cells")[snippet_idx]) + def it_raises_on_invalid_swallow(self, tr_cxml: str, tc_idx: int, grid_width: int): + tr = cast(CT_Row, element(tr_cxml)) + tc = top_tc = tr.tc_lst[tc_idx] + + 
with pytest.raises(InvalidSpanError): + tc._swallow_next_tc(grid_width, top_tc) + + @pytest.mark.parametrize( + ("tc_cxml", "tc_2_cxml", "expected_tc_cxml", "expected_tc_2_cxml"), + [ + ("w:tc/w:p", "w:tc/w:p", "w:tc/w:p", "w:tc/w:p"), + ("w:tc/w:p", "w:tc/w:p/w:r", "w:tc/w:p", "w:tc/w:p/w:r"), + ("w:tc/w:p/w:r", "w:tc/w:p", "w:tc/w:p", "w:tc/w:p/w:r"), + ("w:tc/(w:p/w:r,w:sdt)", "w:tc/w:p", "w:tc/w:p", "w:tc/(w:p/w:r,w:sdt)"), + ( + "w:tc/(w:p/w:r,w:sdt)", + "w:tc/(w:tbl,w:p)", + "w:tc/w:p", + "w:tc/(w:tbl,w:p/w:r,w:sdt)", + ), + ], + ) + def it_can_move_its_content_to_help_merge( + self, tc_cxml: str, tc_2_cxml: str, expected_tc_cxml: str, expected_tc_2_cxml: str + ): + tc, tc_2 = cast(CT_Tc, element(tc_cxml)), cast(CT_Tc, element(tc_2_cxml)) + + tc._move_content_to(tc_2) + + assert tc.xml == xml(expected_tc_cxml) + assert tc_2.xml == xml(expected_tc_2_cxml) + + @pytest.mark.parametrize(("snippet_idx", "row_idx", "col_idx"), [(0, 0, 0), (4, 0, 0)]) + def it_raises_on_tr_above(self, snippet_idx: int, row_idx: int, col_idx: int): + tbl = cast(CT_Tbl, parse_xml(snippet_seq("tbl-cells")[snippet_idx])) tc = tbl.tr_lst[row_idx].tc_lst[col_idx] - return tc - # fixture components --------------------------------------------- + with pytest.raises(ValueError, match="no tr above topmost tr"): + tc._tr_above + + # fixtures ------------------------------------------------------- @pytest.fixture - def grid_span_(self, request): + def grid_span_(self, request: FixtureRequest): return property_mock(request, CT_Tc, "grid_span") @pytest.fixture - def _grow_to_(self, request): + def _grow_to_(self, request: FixtureRequest): return method_mock(request, CT_Tc, "_grow_to") @pytest.fixture - def _move_content_to_(self, request): + def _move_content_to_(self, request: FixtureRequest): return method_mock(request, CT_Tc, "_move_content_to") @pytest.fixture - def _span_dimensions_(self, request): + def _span_dimensions_(self, request: FixtureRequest): return method_mock(request, CT_Tc, 
"_span_dimensions") @pytest.fixture - def _span_to_width_(self, request): + def _span_to_width_(self, request: FixtureRequest): return method_mock(request, CT_Tc, "_span_to_width", autospec=False) - def _snippet_tbl(self, idx): - """ - Return a element for snippet at `idx` in 'tbl-cells' snippet - file. - """ - return parse_xml(snippet_seq("tbl-cells")[idx]) + def _snippet_tbl(self, idx: int) -> CT_Tbl: + """A element for snippet at `idx` in 'tbl-cells' snippet file.""" + return cast(CT_Tbl, parse_xml(snippet_seq("tbl-cells")[idx])) @pytest.fixture - def _swallow_next_tc_(self, request): + def _swallow_next_tc_(self, request: FixtureRequest): return method_mock(request, CT_Tc, "_swallow_next_tc") @pytest.fixture - def _tbl_(self, request): + def _tbl_(self, request: FixtureRequest): return property_mock(request, CT_Tc, "_tbl") @pytest.fixture - def top_tc_(self, request): + def top_tc_(self, request: FixtureRequest): return instance_mock(request, CT_Tc) @pytest.fixture - def tr_(self, request): + def tr_(self, request: FixtureRequest): return instance_mock(request, CT_Row) diff --git a/tests/oxml/unitdata/table.py b/tests/oxml/unitdata/table.py deleted file mode 100644 index 4f760c1a8..000000000 --- a/tests/oxml/unitdata/table.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Test data builders for text XML elements.""" - -from ...unitdata import BaseBuilder -from .shared import CT_StringBuilder - - -class CT_RowBuilder(BaseBuilder): - __tag__ = "w:tr" - __nspfxs__ = ("w",) - __attrs__ = ("w:w",) - - -class CT_TblBuilder(BaseBuilder): - __tag__ = "w:tbl" - __nspfxs__ = ("w",) - __attrs__ = () - - -class CT_TblGridBuilder(BaseBuilder): - __tag__ = "w:tblGrid" - __nspfxs__ = ("w",) - __attrs__ = ("w:w",) - - -class CT_TblGridColBuilder(BaseBuilder): - __tag__ = "w:gridCol" - __nspfxs__ = ("w",) - __attrs__ = ("w:w",) - - -class CT_TblPrBuilder(BaseBuilder): - __tag__ = "w:tblPr" - __nspfxs__ = ("w",) - __attrs__ = () - - -class CT_TblWidthBuilder(BaseBuilder): - __tag__ = 
"w:tblW" - __nspfxs__ = ("w",) - __attrs__ = ("w:w", "w:type") - - -class CT_TcBuilder(BaseBuilder): - __tag__ = "w:tc" - __nspfxs__ = ("w",) - __attrs__ = ("w:id",) - - -class CT_TcPrBuilder(BaseBuilder): - __tag__ = "w:tcPr" - __nspfxs__ = ("w",) - __attrs__ = () - - -def a_gridCol(): - return CT_TblGridColBuilder() - - -def a_tbl(): - return CT_TblBuilder() - - -def a_tblGrid(): - return CT_TblGridBuilder() - - -def a_tblPr(): - return CT_TblPrBuilder() - - -def a_tblStyle(): - return CT_StringBuilder("w:tblStyle") - - -def a_tblW(): - return CT_TblWidthBuilder() - - -def a_tc(): - return CT_TcBuilder() - - -def a_tcPr(): - return CT_TcPrBuilder() - - -def a_tr(): - return CT_RowBuilder() diff --git a/tests/test_table.py b/tests/test_table.py index 0ef273e3f..479d670c6 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -1,7 +1,14 @@ +# pyright: reportPrivateUsage=false + """Test suite for the docx.table module.""" +from __future__ import annotations + +from typing import cast + import pytest +from docx.document import Document from docx.enum.style import WD_STYLE_TYPE from docx.enum.table import ( WD_ALIGN_VERTICAL, @@ -10,37 +17,45 @@ WD_TABLE_DIRECTION, ) from docx.oxml.parser import parse_xml -from docx.oxml.table import CT_Tc +from docx.oxml.table import CT_Row, CT_Tbl, CT_TblGridCol, CT_Tc from docx.parts.document import DocumentPart -from docx.shared import Inches +from docx.shared import Emu, Inches, Length from docx.table import Table, _Cell, _Column, _Columns, _Row, _Rows from docx.text.paragraph import Paragraph -from .oxml.unitdata.table import a_gridCol, a_tbl, a_tblGrid, a_tc, a_tr -from .oxml.unitdata.text import a_p from .unitutil.cxml import element, xml from .unitutil.file import snippet_seq -from .unitutil.mock import instance_mock, property_mock +from .unitutil.mock import FixtureRequest, Mock, instance_mock, property_mock class DescribeTable: - def it_can_add_a_row(self, add_row_fixture): - table, expected_xml = add_row_fixture + 
"""Unit-test suite for `docx.table._Rows` objects.""" + + def it_can_add_a_row(self, document_: Mock): + snippets = snippet_seq("add-row-col") + tbl = cast(CT_Tbl, parse_xml(snippets[0])) + table = Table(tbl, document_) + row = table.add_row() - assert table._tbl.xml == expected_xml + + assert table._tbl.xml == snippets[1] assert isinstance(row, _Row) assert row._tr is table._tbl.tr_lst[-1] assert row._parent is table - def it_can_add_a_column(self, add_column_fixture): - table, width, expected_xml = add_column_fixture - column = table.add_column(width) - assert table._tbl.xml == expected_xml + def it_can_add_a_column(self, document_: Mock): + snippets = snippet_seq("add-row-col") + tbl = cast(CT_Tbl, parse_xml(snippets[0])) + table = Table(tbl, document_) + + column = table.add_column(Inches(1.5)) + + assert table._tbl.xml == snippets[2] assert isinstance(column, _Column) assert column._gridCol is table._tbl.tblGrid.gridCol_lst[-1] assert column._parent is table - def it_provides_access_to_a_cell_by_row_and_col_indices(self, table): + def it_provides_access_to_a_cell_by_row_and_col_indices(self, table: Table): for row_idx in range(2): for col_idx in range(2): cell = table.cell(row_idx, col_idx) @@ -49,153 +64,95 @@ def it_provides_access_to_a_cell_by_row_and_col_indices(self, table): tc = tr.tc_lst[col_idx] assert tc is cell._tc - def it_provides_access_to_the_table_rows(self, table): + def it_provides_access_to_the_table_rows(self, table: Table): rows = table.rows assert isinstance(rows, _Rows) - def it_provides_access_to_the_table_columns(self, table): + def it_provides_access_to_the_table_columns(self, table: Table): columns = table.columns assert isinstance(columns, _Columns) - def it_provides_access_to_the_cells_in_a_column(self, col_cells_fixture): - table, column_idx, expected_cells = col_cells_fixture - column_cells = table.column_cells(column_idx) - assert column_cells == expected_cells - - def it_provides_access_to_the_cells_in_a_row(self, 
row_cells_fixture): - table, row_idx, expected_cells = row_cells_fixture - row_cells = table.row_cells(row_idx) - assert row_cells == expected_cells - - def it_knows_its_alignment_setting(self, alignment_get_fixture): - table, expected_value = alignment_get_fixture - assert table.alignment == expected_value - - def it_can_change_its_alignment_setting(self, alignment_set_fixture): - table, new_value, expected_xml = alignment_set_fixture - table.alignment = new_value - assert table._tbl.xml == expected_xml - - def it_knows_whether_it_should_autofit(self, autofit_get_fixture): - table, expected_value = autofit_get_fixture - assert table.autofit is expected_value - - def it_can_change_its_autofit_setting(self, autofit_set_fixture): - table, new_value, expected_xml = autofit_set_fixture - table.autofit = new_value - assert table._tbl.xml == expected_xml - - def it_knows_it_is_the_table_its_children_belong_to(self, table_fixture): - table = table_fixture - assert table.table is table - - def it_knows_its_direction(self, direction_get_fixture): - table, expected_value = direction_get_fixture - assert table.table_direction == expected_value - - def it_can_change_its_direction(self, direction_set_fixture): - table, new_value, expected_xml = direction_set_fixture - table.table_direction = new_value - assert table._element.xml == expected_xml - - def it_knows_its_table_style(self, style_get_fixture): - table, style_id_, style_ = style_get_fixture - style = table.style - table.part.get_style.assert_called_once_with(style_id_, WD_STYLE_TYPE.TABLE) - assert style is style_ - - def it_can_change_its_table_style(self, style_set_fixture): - table, value, expected_xml = style_set_fixture - table.style = value - table.part.get_style_id.assert_called_once_with(value, WD_STYLE_TYPE.TABLE) - assert table._tbl.xml == expected_xml + def it_provides_access_to_the_cells_in_a_column( + self, _cells_: Mock, _column_count_: Mock, document_: Mock + ): + table = Table(cast(CT_Tbl, 
element("w:tbl")), document_) + _cells_.return_value = [0, 1, 2, 3, 4, 5, 6, 7, 8] + _column_count_.return_value = 3 + column_idx = 1 - def it_provides_access_to_its_cells_to_help(self, cells_fixture): - table, cell_count, unique_count, matches = cells_fixture - cells = table._cells - assert len(cells) == cell_count - assert len(set(cells)) == unique_count - for matching_idxs in matches: - comparator_idx = matching_idxs[0] - for idx in matching_idxs[1:]: - assert cells[idx] is cells[comparator_idx] + column_cells = table.column_cells(column_idx) - def it_knows_its_column_count_to_help(self, column_count_fixture): - table, expected_value = column_count_fixture - column_count = table._column_count - assert column_count == expected_value + assert column_cells == [1, 4, 7] - # fixtures ------------------------------------------------------- + def it_provides_access_to_the_cells_in_a_row( + self, _cells_: Mock, _column_count_: Mock, document_: Mock + ): + table = Table(cast(CT_Tbl, element("w:tbl")), document_) + _cells_.return_value = [0, 1, 2, 3, 4, 5, 6, 7, 8] + _column_count_.return_value = 3 - @pytest.fixture - def add_column_fixture(self): - snippets = snippet_seq("add-row-col") - tbl = parse_xml(snippets[0]) - table = Table(tbl, None) - width = Inches(1.5) - expected_xml = snippets[2] - return table, width, expected_xml + row_cells = table.row_cells(1) - @pytest.fixture - def add_row_fixture(self): - snippets = snippet_seq("add-row-col") - tbl = parse_xml(snippets[0]) - table = Table(tbl, None) - expected_xml = snippets[1] - return table, expected_xml + assert row_cells == [3, 4, 5] - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("tbl_cxml", "expected_value"), + [ ("w:tbl/w:tblPr", None), ("w:tbl/w:tblPr/w:jc{w:val=center}", WD_TABLE_ALIGNMENT.CENTER), ("w:tbl/w:tblPr/w:jc{w:val=right}", WD_TABLE_ALIGNMENT.RIGHT), ("w:tbl/w:tblPr/w:jc{w:val=left}", WD_TABLE_ALIGNMENT.LEFT), - ] + ], ) - def alignment_get_fixture(self, request): - tbl_cxml, 
expected_value = request.param - table = Table(element(tbl_cxml), None) - return table, expected_value + def it_knows_its_alignment_setting( + self, tbl_cxml: str, expected_value: WD_TABLE_ALIGNMENT | None, document_: Mock + ): + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + assert table.alignment == expected_value - @pytest.fixture( - params=[ - ( - "w:tbl/w:tblPr", - WD_TABLE_ALIGNMENT.LEFT, - "w:tbl/w:tblPr/w:jc{w:val=left}", - ), + @pytest.mark.parametrize( + ("tbl_cxml", "new_value", "expected_cxml"), + [ + ("w:tbl/w:tblPr", WD_TABLE_ALIGNMENT.LEFT, "w:tbl/w:tblPr/w:jc{w:val=left}"), ( "w:tbl/w:tblPr/w:jc{w:val=left}", WD_TABLE_ALIGNMENT.RIGHT, "w:tbl/w:tblPr/w:jc{w:val=right}", ), ("w:tbl/w:tblPr/w:jc{w:val=right}", None, "w:tbl/w:tblPr"), - ] + ], ) - def alignment_set_fixture(self, request): - tbl_cxml, new_value, expected_tbl_cxml = request.param - table = Table(element(tbl_cxml), None) - expected_xml = xml(expected_tbl_cxml) - return table, new_value, expected_xml - - @pytest.fixture( - params=[ + def it_can_change_its_alignment_setting( + self, + tbl_cxml: str, + new_value: WD_TABLE_ALIGNMENT | None, + expected_cxml: str, + document_: Mock, + ): + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + table.alignment = new_value + assert table._tbl.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("tbl_cxml", "expected_value"), + [ ("w:tbl/w:tblPr", True), ("w:tbl/w:tblPr/w:tblLayout", True), ("w:tbl/w:tblPr/w:tblLayout{w:type=autofit}", True), ("w:tbl/w:tblPr/w:tblLayout{w:type=fixed}", False), - ] + ], ) - def autofit_get_fixture(self, request): - tbl_cxml, expected_autofit = request.param - table = Table(element(tbl_cxml), None) - return table, expected_autofit + def it_knows_whether_it_should_autofit( + self, tbl_cxml: str, expected_value: bool, document_: Mock + ): + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + assert table.autofit is expected_value - @pytest.fixture( - params=[ + @pytest.mark.parametrize( 
+ ("tbl_cxml", "new_value", "expected_cxml"), + [ ("w:tbl/w:tblPr", True, "w:tbl/w:tblPr/w:tblLayout{w:type=autofit}"), ("w:tbl/w:tblPr", False, "w:tbl/w:tblPr/w:tblLayout{w:type=fixed}"), - ("w:tbl/w:tblPr", None, "w:tbl/w:tblPr/w:tblLayout{w:type=fixed}"), ( "w:tbl/w:tblPr/w:tblLayout{w:type=fixed}", True, @@ -206,60 +163,36 @@ def autofit_get_fixture(self, request): False, "w:tbl/w:tblPr/w:tblLayout{w:type=fixed}", ), - ] + ], ) - def autofit_set_fixture(self, request): - tbl_cxml, new_value, expected_tbl_cxml = request.param - table = Table(element(tbl_cxml), None) - expected_xml = xml(expected_tbl_cxml) - return table, new_value, expected_xml - - @pytest.fixture( - params=[ - (0, 9, 9, ()), - (1, 9, 8, ((0, 1),)), - (2, 9, 8, ((1, 4),)), - (3, 9, 6, ((0, 1, 3, 4),)), - (4, 9, 4, ((0, 1), (3, 6), (4, 5, 7, 8))), - ] - ) - def cells_fixture(self, request): - snippet_idx, cell_count, unique_count, matches = request.param - tbl_xml = snippet_seq("tbl-cells")[snippet_idx] - table = Table(parse_xml(tbl_xml), None) - return table, cell_count, unique_count, matches - - @pytest.fixture - def col_cells_fixture(self, _cells_, _column_count_): - table = Table(None, None) - _cells_.return_value = [0, 1, 2, 3, 4, 5, 6, 7, 8] - _column_count_.return_value = 3 - column_idx = 1 - expected_cells = [1, 4, 7] - return table, column_idx, expected_cells + def it_can_change_its_autofit_setting( + self, tbl_cxml: str, new_value: bool, expected_cxml: str, document_: Mock + ): + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + table.autofit = new_value + assert table._tbl.xml == xml(expected_cxml) - @pytest.fixture - def column_count_fixture(self): - tbl_cxml = "w:tbl/w:tblGrid/(w:gridCol,w:gridCol,w:gridCol)" - expected_value = 3 - table = Table(element(tbl_cxml), None) - return table, expected_value + def it_knows_it_is_the_table_its_children_belong_to(self, table: Table): + assert table.table is table - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + 
("tbl_cxml", "expected_value"), + [ ("w:tbl/w:tblPr", None), ("w:tbl/w:tblPr/w:bidiVisual", WD_TABLE_DIRECTION.RTL), ("w:tbl/w:tblPr/w:bidiVisual{w:val=0}", WD_TABLE_DIRECTION.LTR), ("w:tbl/w:tblPr/w:bidiVisual{w:val=on}", WD_TABLE_DIRECTION.RTL), - ] + ], ) - def direction_get_fixture(self, request): - tbl_cxml, expected_value = request.param - table = Table(element(tbl_cxml), None) - return table, expected_value - - @pytest.fixture( - params=[ + def it_knows_its_direction( + self, tbl_cxml: str, expected_value: WD_TABLE_DIRECTION | None, document_: Mock + ): + tbl = cast(CT_Tbl, element(tbl_cxml)) + assert Table(tbl, document_).table_direction == expected_value + + @pytest.mark.parametrize( + ("tbl_cxml", "new_value", "expected_cxml"), + [ ("w:tbl/w:tblPr", WD_TABLE_DIRECTION.RTL, "w:tbl/w:tblPr/w:bidiVisual"), ( "w:tbl/w:tblPr/w:bidiVisual", @@ -272,33 +205,28 @@ def direction_get_fixture(self, request): "w:tbl/w:tblPr/w:bidiVisual", ), ("w:tbl/w:tblPr/w:bidiVisual{w:val=1}", None, "w:tbl/w:tblPr"), - ] + ], ) - def direction_set_fixture(self, request): - tbl_cxml, new_value, expected_cxml = request.param - table = Table(element(tbl_cxml), None) - expected_xml = xml(expected_cxml) - return table, new_value, expected_xml + def it_can_change_its_direction( + self, tbl_cxml: str, new_value: WD_TABLE_DIRECTION, expected_cxml: str, document_: Mock + ): + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + table.table_direction = new_value + assert table._element.xml == xml(expected_cxml) - @pytest.fixture - def row_cells_fixture(self, _cells_, _column_count_): - table = Table(None, None) - _cells_.return_value = [0, 1, 2, 3, 4, 5, 6, 7, 8] - _column_count_.return_value = 3 - row_idx = 1 - expected_cells = [3, 4, 5] - return table, row_idx, expected_cells + def it_knows_its_table_style(self, part_prop_: Mock, document_part_: Mock, document_: Mock): + part_prop_.return_value = document_part_ + style_ = document_part_.get_style.return_value + table = 
Table(cast(CT_Tbl, element("w:tbl/w:tblPr/w:tblStyle{w:val=BarBaz}")), document_) - @pytest.fixture - def style_get_fixture(self, part_prop_): - style_id = "Barbaz" - tbl_cxml = "w:tbl/w:tblPr/w:tblStyle{w:val=%s}" % style_id - table = Table(element(tbl_cxml), None) - style_ = part_prop_.return_value.get_style.return_value - return table, style_id, style_ - - @pytest.fixture( - params=[ + style = table.style + + document_part_.get_style.assert_called_once_with("BarBaz", WD_STYLE_TYPE.TABLE) + assert style is style_ + + @pytest.mark.parametrize( + ("tbl_cxml", "new_value", "style_id", "expected_cxml"), + [ ("w:tbl/w:tblPr", "Tbl A", "TblA", "w:tbl/w:tblPr/w:tblStyle{w:val=TblA}"), ( "w:tbl/w:tblPr/w:tblStyle{w:val=TblA}", @@ -307,155 +235,166 @@ def style_get_fixture(self, part_prop_): "w:tbl/w:tblPr/w:tblStyle{w:val=TblB}", ), ("w:tbl/w:tblPr/w:tblStyle{w:val=TblB}", None, None, "w:tbl/w:tblPr"), - ] + ], + ) + def it_can_change_its_table_style( + self, + tbl_cxml: str, + new_value: str | None, + style_id: str | None, + expected_cxml: str, + document_: Mock, + part_prop_: Mock, + document_part_: Mock, + ): + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + part_prop_.return_value = document_part_ + document_part_.get_style_id.return_value = style_id + + table.style = new_value + + document_part_.get_style_id.assert_called_once_with(new_value, WD_STYLE_TYPE.TABLE) + assert table._tbl.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("snippet_idx", "cell_count", "unique_count", "matches"), + [ + (0, 9, 9, ()), + (1, 9, 8, ((0, 1),)), + (2, 9, 8, ((1, 4),)), + (3, 9, 6, ((0, 1, 3, 4),)), + (4, 9, 4, ((0, 1), (3, 6), (4, 5, 7, 8))), + ], ) - def style_set_fixture(self, request, part_prop_): - tbl_cxml, value, style_id, expected_cxml = request.param - table = Table(element(tbl_cxml), None) - part_prop_.return_value.get_style_id.return_value = style_id - expected_xml = xml(expected_cxml) - return table, value, expected_xml + def 
it_provides_access_to_its_cells_to_help( + self, + snippet_idx: int, + cell_count: int, + unique_count: int, + matches: tuple[tuple[int, ...]], + document_: Mock, + ): + tbl_xml = snippet_seq("tbl-cells")[snippet_idx] + table = Table(cast(CT_Tbl, parse_xml(tbl_xml)), document_) - @pytest.fixture - def table_fixture(self): - table = Table(None, None) - return table + cells = table._cells - # fixture components --------------------------------------------- + assert len(cells) == cell_count + assert len(set(cells)) == unique_count + for matching_idxs in matches: + comparator_idx = matching_idxs[0] + for idx in matching_idxs[1:]: + assert cells[idx] is cells[comparator_idx] + + def it_knows_its_column_count_to_help(self, document_: Mock): + tbl_cxml = "w:tbl/w:tblGrid/(w:gridCol,w:gridCol,w:gridCol)" + expected_value = 3 + table = Table(cast(CT_Tbl, element(tbl_cxml)), document_) + + column_count = table._column_count + + assert column_count == expected_value + + # fixtures ------------------------------------------------------- @pytest.fixture - def _cells_(self, request): + def _cells_(self, request: FixtureRequest): return property_mock(request, Table, "_cells") @pytest.fixture - def _column_count_(self, request): + def _column_count_(self, request: FixtureRequest): return property_mock(request, Table, "_column_count") @pytest.fixture - def document_part_(self, request): + def document_(self, request: FixtureRequest): + return instance_mock(request, Document) + + @pytest.fixture + def document_part_(self, request: FixtureRequest): return instance_mock(request, DocumentPart) @pytest.fixture - def part_prop_(self, request, document_part_): - return property_mock(request, Table, "part", return_value=document_part_) + def part_prop_(self, request: FixtureRequest): + return property_mock(request, Table, "part") @pytest.fixture - def table(self): - tbl = _tbl_bldr(rows=2, cols=2).element - table = Table(tbl, None) - return table + def table(self, document_: Mock): + 
tbl_cxml = "w:tbl/(w:tblGrid/(w:gridCol,w:gridCol),w:tr/(w:tc,w:tc),w:tr/(w:tc,w:tc))" + return Table(cast(CT_Tbl, element(tbl_cxml)), document_) class Describe_Cell: - def it_knows_what_text_it_contains(self, text_get_fixture): - cell, expected_text = text_get_fixture + """Unit-test suite for `docx.table._Cell` objects.""" + + @pytest.mark.parametrize( + ("tc_cxml", "expected_value"), + [ + ("w:tc", 1), + ("w:tc/w:tcPr", 1), + ("w:tc/w:tcPr/w:gridSpan{w:val=1}", 1), + ("w:tc/w:tcPr/w:gridSpan{w:val=4}", 4), + ], + ) + def it_knows_its_grid_span(self, tc_cxml: str, expected_value: int, parent_: Mock): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + assert cell.grid_span == expected_value + + @pytest.mark.parametrize( + ("tc_cxml", "expected_text"), + [ + ("w:tc", ""), + ('w:tc/w:p/w:r/w:t"foobar"', "foobar"), + ('w:tc/(w:p/w:r/w:t"foo",w:p/w:r/w:t"bar")', "foo\nbar"), + ('w:tc/(w:tcPr,w:p/w:r/w:t"foobar")', "foobar"), + ('w:tc/w:p/w:r/(w:t"fo",w:tab,w:t"ob",w:br,w:t"ar",w:br)', "fo\tob\nar\n"), + ], + ) + def it_knows_what_text_it_contains(self, tc_cxml: str, expected_text: str, parent_: Mock): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) text = cell.text assert text == expected_text - def it_can_replace_its_content_with_a_string_of_text(self, text_set_fixture): - cell, text, expected_xml = text_set_fixture - cell.text = text - assert cell._tc.xml == expected_xml - - def it_knows_its_vertical_alignment(self, alignment_get_fixture): - cell, expected_value = alignment_get_fixture - vertical_alignment = cell.vertical_alignment - assert vertical_alignment == expected_value - - def it_can_change_its_vertical_alignment(self, alignment_set_fixture): - cell, new_value, expected_xml = alignment_set_fixture - cell.vertical_alignment = new_value - assert cell._element.xml == expected_xml - - def it_knows_its_width_in_EMU(self, width_get_fixture): - cell, expected_width = width_get_fixture - assert cell.width == expected_width - - def 
it_can_change_its_width(self, width_set_fixture): - cell, value, expected_xml = width_set_fixture - cell.width = value - assert cell.width == value - assert cell._tc.xml == expected_xml - - def it_provides_access_to_the_paragraphs_it_contains(self, paragraphs_fixture): - cell = paragraphs_fixture - paragraphs = cell.paragraphs - assert len(paragraphs) == 2 - count = 0 - for idx, paragraph in enumerate(paragraphs): - assert isinstance(paragraph, Paragraph) - assert paragraph is paragraphs[idx] - count += 1 - assert count == 2 - - def it_provides_access_to_the_tables_it_contains(self, tables_fixture): - # test len(), iterable, and indexed access - cell, expected_count = tables_fixture - tables = cell.tables - assert len(tables) == expected_count - count = 0 - for idx, table in enumerate(tables): - assert isinstance(table, Table) - assert tables[idx] is table - count += 1 - assert count == expected_count - - def it_can_add_a_paragraph(self, add_paragraph_fixture): - cell, expected_xml = add_paragraph_fixture - p = cell.add_paragraph() - assert cell._tc.xml == expected_xml - assert isinstance(p, Paragraph) - - def it_can_add_a_table(self, add_table_fixture): - cell, expected_xml = add_table_fixture - table = cell.add_table(rows=2, cols=2) - assert cell._element.xml == expected_xml - assert isinstance(table, Table) - - def it_can_merge_itself_with_other_cells(self, merge_fixture): - cell, other_cell, merged_tc_ = merge_fixture - merged_cell = cell.merge(other_cell) - cell._tc.merge.assert_called_once_with(other_cell._tc) - assert isinstance(merged_cell, _Cell) - assert merged_cell._tc is merged_tc_ - assert merged_cell._parent is cell._parent - - # fixtures ------------------------------------------------------- - - @pytest.fixture( - params=[ - ("w:tc", "w:tc/w:p"), - ("w:tc/w:p", "w:tc/(w:p, w:p)"), - ("w:tc/w:tbl", "w:tc/(w:tbl, w:p)"), - ] + @pytest.mark.parametrize( + ("tc_cxml", "new_text", "expected_cxml"), + [ + ("w:tc/w:p", "foobar", 
'w:tc/w:p/w:r/w:t"foobar"'), + ( + "w:tc/w:p", + "fo\tob\rar\n", + 'w:tc/w:p/w:r/(w:t"fo",w:tab,w:t"ob",w:br,w:t"ar",w:br)', + ), + ( + "w:tc/(w:tcPr, w:p, w:tbl, w:p)", + "foobar", + 'w:tc/(w:tcPr, w:p/w:r/w:t"foobar")', + ), + ], ) - def add_paragraph_fixture(self, request): - tc_cxml, after_tc_cxml = request.param - cell = _Cell(element(tc_cxml), None) - expected_xml = xml(after_tc_cxml) - return cell, expected_xml - - @pytest.fixture - def add_table_fixture(self, request): - cell = _Cell(element("w:tc/w:p"), None) - expected_xml = snippet_seq("new-tbl")[1] - return cell, expected_xml - - @pytest.fixture( - params=[ + def it_can_replace_its_content_with_a_string_of_text( + self, tc_cxml: str, new_text: str, expected_cxml: str, parent_: Mock + ): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + cell.text = new_text + assert cell._tc.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("tc_cxml", "expected_value"), + [ ("w:tc", None), ("w:tc/w:tcPr", None), ("w:tc/w:tcPr/w:vAlign{w:val=bottom}", WD_ALIGN_VERTICAL.BOTTOM), ("w:tc/w:tcPr/w:vAlign{w:val=top}", WD_ALIGN_VERTICAL.TOP), - ] + ], ) - def alignment_get_fixture(self, request): - tc_cxml, expected_value = request.param - cell = _Cell(element(tc_cxml), None) - return cell, expected_value - - @pytest.fixture( - params=[ + def it_knows_its_vertical_alignment( + self, tc_cxml: str, expected_value: WD_ALIGN_VERTICAL | None, parent_: Mock + ): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + assert cell.vertical_alignment == expected_value + + @pytest.mark.parametrize( + ("tc_cxml", "new_value", "expected_cxml"), + [ ("w:tc", WD_ALIGN_VERTICAL.TOP, "w:tc/w:tcPr/w:vAlign{w:val=top}"), ( "w:tc/w:tcPr", @@ -470,330 +409,300 @@ def alignment_get_fixture(self, request): ("w:tc/w:tcPr/w:vAlign{w:val=center}", None, "w:tc/w:tcPr"), ("w:tc", None, "w:tc/w:tcPr"), ("w:tc/w:tcPr", None, "w:tc/w:tcPr"), - ] - ) - def alignment_set_fixture(self, request): - cxml, new_value, expected_cxml = request.param 
- cell = _Cell(element(cxml), None) - expected_xml = xml(expected_cxml) - return cell, new_value, expected_xml - - @pytest.fixture - def merge_fixture(self, tc_, tc_2_, parent_, merged_tc_): - cell, other_cell = _Cell(tc_, parent_), _Cell(tc_2_, parent_) - tc_.merge.return_value = merged_tc_ - return cell, other_cell, merged_tc_ - - @pytest.fixture - def paragraphs_fixture(self): - return _Cell(element("w:tc/(w:p, w:p)"), None) - - @pytest.fixture( - params=[ - ("w:tc", 0), - ("w:tc/w:tbl", 1), - ("w:tc/(w:tbl,w:tbl)", 2), - ("w:tc/(w:p,w:tbl)", 1), - ("w:tc/(w:tbl,w:tbl,w:p)", 2), - ] + ], ) - def tables_fixture(self, request): - cell_cxml, expected_count = request.param - cell = _Cell(element(cell_cxml), None) - return cell, expected_count - - @pytest.fixture( - params=[ - ("w:tc", ""), - ('w:tc/w:p/w:r/w:t"foobar"', "foobar"), - ('w:tc/(w:p/w:r/w:t"foo",w:p/w:r/w:t"bar")', "foo\nbar"), - ('w:tc/(w:tcPr,w:p/w:r/w:t"foobar")', "foobar"), - ('w:tc/w:p/w:r/(w:t"fo",w:tab,w:t"ob",w:br,w:t"ar",w:br)', "fo\tob\nar\n"), - ] - ) - def text_get_fixture(self, request): - tc_cxml, expected_text = request.param - cell = _Cell(element(tc_cxml), None) - return cell, expected_text + def it_can_change_its_vertical_alignment( + self, tc_cxml: str, new_value: WD_ALIGN_VERTICAL | None, expected_cxml: str, parent_: Mock + ): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + cell.vertical_alignment = new_value + assert cell._element.xml == xml(expected_cxml) - @pytest.fixture( - params=[ - ("w:tc/w:p", "foobar", 'w:tc/w:p/w:r/w:t"foobar"'), - ( - "w:tc/w:p", - "fo\tob\rar\n", - 'w:tc/w:p/w:r/(w:t"fo",w:tab,w:t"ob",w:br,w:t"ar",w:br)', - ), - ( - "w:tc/(w:tcPr, w:p, w:tbl, w:p)", - "foobar", - 'w:tc/(w:tcPr, w:p/w:r/w:t"foobar")', - ), - ] - ) - def text_set_fixture(self, request): - tc_cxml, new_text, expected_cxml = request.param - cell = _Cell(element(tc_cxml), None) - expected_xml = xml(expected_cxml) - return cell, new_text, expected_xml - - @pytest.fixture( - params=[ + 
@pytest.mark.parametrize( + ("tc_cxml", "expected_value"), + [ ("w:tc", None), ("w:tc/w:tcPr", None), ("w:tc/w:tcPr/w:tcW{w:w=25%,w:type=pct}", None), ("w:tc/w:tcPr/w:tcW{w:w=1440,w:type=dxa}", 914400), - ] + ], ) - def width_get_fixture(self, request): - tc_cxml, expected_width = request.param - cell = _Cell(element(tc_cxml), None) - return cell, expected_width + def it_knows_its_width_in_EMU(self, tc_cxml: str, expected_value: int | None, parent_: Mock): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + assert cell.width == expected_value - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("tc_cxml", "new_value", "expected_cxml"), + [ ("w:tc", Inches(1), "w:tc/w:tcPr/w:tcW{w:w=1440,w:type=dxa}"), ( "w:tc/w:tcPr/w:tcW{w:w=25%,w:type=pct}", Inches(2), "w:tc/w:tcPr/w:tcW{w:w=2880,w:type=dxa}", ), - ] + ], + ) + def it_can_change_its_width( + self, tc_cxml: str, new_value: Length, expected_cxml: str, parent_: Mock + ): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + cell.width = new_value + assert cell.width == new_value + assert cell._tc.xml == xml(expected_cxml) + + def it_provides_access_to_the_paragraphs_it_contains(self, parent_: Mock): + cell = _Cell(cast(CT_Tc, element("w:tc/(w:p, w:p)")), parent_) + + paragraphs = cell.paragraphs + + # -- every w:p produces a Paragraph instance -- + assert len(paragraphs) == 2 + assert all(isinstance(p, Paragraph) for p in paragraphs) + # -- the return value is iterable and indexable -- + assert all(p is paragraphs[idx] for idx, p in enumerate(paragraphs)) + + @pytest.mark.parametrize( + ("tc_cxml", "expected_table_count"), + [ + ("w:tc", 0), + ("w:tc/w:tbl", 1), + ("w:tc/(w:tbl,w:tbl)", 2), + ("w:tc/(w:p,w:tbl)", 1), + ("w:tc/(w:tbl,w:tbl,w:p)", 2), + ], + ) + def it_provides_access_to_the_tables_it_contains( + self, tc_cxml: str, expected_table_count: int, parent_: Mock + ): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) + + tables = cell.tables + + # --- test len(), iterable, and indexed 
access + assert len(tables) == expected_table_count + assert all(isinstance(t, Table) for t in tables) + assert all(t is tables[idx] for idx, t in enumerate(tables)) + + @pytest.mark.parametrize( + ("tc_cxml", "expected_cxml"), + [ + ("w:tc", "w:tc/w:p"), + ("w:tc/w:p", "w:tc/(w:p, w:p)"), + ("w:tc/w:tbl", "w:tc/(w:tbl, w:p)"), + ], ) - def width_set_fixture(self, request): - tc_cxml, new_value, expected_cxml = request.param - cell = _Cell(element(tc_cxml), None) - expected_xml = xml(expected_cxml) - return cell, new_value, expected_xml + def it_can_add_a_paragraph(self, tc_cxml: str, expected_cxml: str, parent_: Mock): + cell = _Cell(cast(CT_Tc, element(tc_cxml)), parent_) - # fixture components --------------------------------------------- + p = cell.add_paragraph() + + assert isinstance(p, Paragraph) + assert cell._tc.xml == xml(expected_cxml) + + def it_can_add_a_table(self, parent_: Mock): + cell = _Cell(cast(CT_Tc, element("w:tc/w:p")), parent_) + + table = cell.add_table(rows=2, cols=2) + + assert isinstance(table, Table) + assert cell._element.xml == snippet_seq("new-tbl")[1] + + def it_can_merge_itself_with_other_cells( + self, tc_: Mock, tc_2_: Mock, parent_: Mock, merged_tc_: Mock + ): + cell, other_cell = _Cell(tc_, parent_), _Cell(tc_2_, parent_) + tc_.merge.return_value = merged_tc_ + + merged_cell = cell.merge(other_cell) + + assert isinstance(merged_cell, _Cell) + tc_.merge.assert_called_once_with(other_cell._tc) + assert merged_cell._tc is merged_tc_ + assert merged_cell._parent is cell._parent + + # fixtures ------------------------------------------------------- @pytest.fixture - def merged_tc_(self, request): + def merged_tc_(self, request: FixtureRequest): return instance_mock(request, CT_Tc) @pytest.fixture - def parent_(self, request): + def parent_(self, request: FixtureRequest): return instance_mock(request, Table) @pytest.fixture - def tc_(self, request): + def tc_(self, request: FixtureRequest): return instance_mock(request, CT_Tc) 
@pytest.fixture - def tc_2_(self, request): + def tc_2_(self, request: FixtureRequest): return instance_mock(request, CT_Tc) class Describe_Column: - def it_provides_access_to_its_cells(self, cells_fixture): - column, column_idx, expected_cells = cells_fixture - cells = column.cells - column.table.column_cells.assert_called_once_with(column_idx) - assert cells == expected_cells - - def it_provides_access_to_the_table_it_belongs_to(self, table_fixture): - column, table_ = table_fixture - assert column.table is table_ - - def it_knows_its_width_in_EMU(self, width_get_fixture): - column, expected_width = width_get_fixture - assert column.width == expected_width + """Unit-test suite for `docx.table._Column` objects.""" - def it_can_change_its_width(self, width_set_fixture): - column, value, expected_xml = width_set_fixture - column.width = value - assert column.width == value - assert column._gridCol.xml == expected_xml + def it_provides_access_to_its_cells(self, _index_prop_: Mock, table_prop_: Mock, table_: Mock): + table_prop_.return_value = table_ + _index_prop_.return_value = 4 + column = _Column(cast(CT_TblGridCol, element("w:gridCol{w:w=500}")), table_) + table_.column_cells.return_value = [3, 2, 1] - def it_knows_its_index_in_table_to_help(self, index_fixture): - column, expected_idx = index_fixture - assert column._index == expected_idx + cells = column.cells - # fixtures ------------------------------------------------------- + table_.column_cells.assert_called_once_with(4) + assert cells == (3, 2, 1) - @pytest.fixture - def cells_fixture(self, _index_, table_prop_, table_): - column = _Column(None, None) - _index_.return_value = column_idx = 4 - expected_cells = (3, 2, 1) - table_.column_cells.return_value = list(expected_cells) - return column, column_idx, expected_cells - - @pytest.fixture - def index_fixture(self): - tbl = element("w:tbl/w:tblGrid/(w:gridCol,w:gridCol,w:gridCol)") - gridCol, expected_idx = tbl.tblGrid[1], 1 - column = _Column(gridCol, None)
- return column, expected_idx + def it_provides_access_to_the_table_it_belongs_to(self, table_: Mock): + table_.table = table_ + column = _Column(cast(CT_TblGridCol, element("w:gridCol{w:w=500}")), table_) - @pytest.fixture - def table_fixture(self, parent_, table_): - column = _Column(None, parent_) - parent_.table = table_ - return column, table_ + assert column.table is table_ - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("gridCol_cxml", "expected_width"), + [ ("w:gridCol{w:w=4242}", 2693670), ("w:gridCol{w:w=1440}", 914400), ("w:gridCol{w:w=2.54cm}", 914400), ("w:gridCol{w:w=54mm}", 1944000), ("w:gridCol{w:w=12.5pt}", 158750), ("w:gridCol", None), - ] + ], ) - def width_get_fixture(self, request): - gridCol_cxml, expected_width = request.param - column = _Column(element(gridCol_cxml), None) - return column, expected_width - - @pytest.fixture( - params=[ - ("w:gridCol", 914400, "w:gridCol{w:w=1440}"), - ("w:gridCol{w:w=4242}", 457200, "w:gridCol{w:w=720}"), + def it_knows_its_width_in_EMU( + self, gridCol_cxml: str, expected_width: int | None, table_: Mock + ): + column = _Column(cast(CT_TblGridCol, element(gridCol_cxml)), table_) + assert column.width == expected_width + + @pytest.mark.parametrize( + ("gridCol_cxml", "new_value", "expected_cxml"), + [ + ("w:gridCol", Emu(914400), "w:gridCol{w:w=1440}"), + ("w:gridCol{w:w=4242}", Inches(0.5), "w:gridCol{w:w=720}"), ("w:gridCol{w:w=4242}", None, "w:gridCol"), ("w:gridCol", None, "w:gridCol"), - ] + ], ) - def width_set_fixture(self, request): - gridCol_cxml, new_value, expected_cxml = request.param - column = _Column(element(gridCol_cxml), None) - expected_xml = xml(expected_cxml) - return column, new_value, expected_xml + def it_can_change_its_width( + self, gridCol_cxml: str, new_value: Length | None, expected_cxml: str, table_: Mock + ): + column = _Column(cast(CT_TblGridCol, element(gridCol_cxml)), table_) + + column.width = new_value + + assert column.width == new_value + assert 
column._gridCol.xml == xml(expected_cxml) - # fixture components --------------------------------------------- + def it_knows_its_index_in_table_to_help(self, table_: Mock): + tbl = cast(CT_Tbl, element("w:tbl/w:tblGrid/(w:gridCol,w:gridCol,w:gridCol)")) + gridCol = tbl.tblGrid.gridCol_lst[1] + column = _Column(gridCol, table_) + assert column._index == 1 + + # fixtures ------------------------------------------------------- @pytest.fixture - def _index_(self, request): + def _index_prop_(self, request: FixtureRequest): return property_mock(request, _Column, "_index") @pytest.fixture - def parent_(self, request): + def parent_(self, request: FixtureRequest): return instance_mock(request, Table) @pytest.fixture - def table_(self, request): + def table_(self, request: FixtureRequest): return instance_mock(request, Table) @pytest.fixture - def table_prop_(self, request, table_): - return property_mock(request, _Column, "table", return_value=table_) + def table_prop_(self, request: FixtureRequest): + return property_mock(request, _Column, "table") class Describe_Columns: - def it_knows_how_many_columns_it_contains(self, columns_fixture): - columns, column_count = columns_fixture - assert len(columns) == column_count - - def it_can_interate_over_its__Column_instances(self, columns_fixture): - columns, column_count = columns_fixture - actual_count = 0 - for column in columns: - assert isinstance(column, _Column) - actual_count += 1 - assert actual_count == column_count - - def it_provides_indexed_access_to_columns(self, columns_fixture): - columns, column_count = columns_fixture - for idx in range(-column_count, column_count): - column = columns[idx] - assert isinstance(column, _Column) - - def it_raises_on_indexed_access_out_of_range(self, columns_fixture): - columns, column_count = columns_fixture - too_low = -1 - column_count - too_high = column_count - with pytest.raises(IndexError): - columns[too_low] - with pytest.raises(IndexError): - columns[too_high] + 
"""Unit-test suite for `docx.table._Columns` objects.""" - def it_provides_access_to_the_table_it_belongs_to(self, table_fixture): - columns, table_ = table_fixture - assert columns.table is table_ + def it_has_sequence_behaviors(self, table_: Mock): + columns = _Columns(cast(CT_Tbl, element("w:tbl/w:tblGrid/(w:gridCol,w:gridCol)")), table_) - # fixtures ------------------------------------------------------- + # -- it supports len() -- + assert len(columns) == 2 + # -- it is iterable -- + assert len(tuple(c for c in columns)) == 2 + assert all(type(c) is _Column for c in columns) + # -- it is indexable -- + assert all(type(columns[i]) is _Column for i in range(2)) - @pytest.fixture - def columns_fixture(self): - column_count = 2 - tbl = _tbl_bldr(rows=2, cols=column_count).element - columns = _Columns(tbl, None) - return columns, column_count + def it_raises_on_indexed_access_out_of_range(self, table_: Mock): + columns = _Columns(cast(CT_Tbl, element("w:tbl/w:tblGrid/(w:gridCol,w:gridCol)")), table_) - @pytest.fixture - def table_fixture(self, table_): - columns = _Columns(None, table_) + with pytest.raises(IndexError): + columns[2] + with pytest.raises(IndexError): + columns[-3] + + def it_provides_access_to_the_table_it_belongs_to(self, table_: Mock): + columns = _Columns(cast(CT_Tbl, element("w:tbl")), table_) table_.table = table_ - return columns, table_ - # fixture components --------------------------------------------- + assert columns.table is table_ + + # fixtures ------------------------------------------------------- @pytest.fixture - def table_(self, request): + def table_(self, request: FixtureRequest): return instance_mock(request, Table) class Describe_Row: - def it_knows_its_height(self, height_get_fixture): - row, expected_height = height_get_fixture - assert row.height == expected_height - - def it_can_change_its_height(self, height_set_fixture): - row, value, expected_xml = height_set_fixture - row.height = value - assert row._tr.xml == 
expected_xml - - def it_knows_its_height_rule(self, height_rule_get_fixture): - row, expected_rule = height_rule_get_fixture - assert row.height_rule == expected_rule - - def it_can_change_its_height_rule(self, height_rule_set_fixture): - row, rule, expected_xml = height_rule_set_fixture - row.height_rule = rule - assert row._tr.xml == expected_xml - - def it_provides_access_to_its_cells(self, cells_fixture): - row, row_idx, expected_cells = cells_fixture - cells = row.cells - row.table.row_cells.assert_called_once_with(row_idx) - assert cells == expected_cells - - def it_provides_access_to_the_table_it_belongs_to(self, table_fixture): - row, table_ = table_fixture - assert row.table is table_ - - def it_knows_its_index_in_table_to_help(self, idx_fixture): - row, expected_idx = idx_fixture - assert row._index == expected_idx - - # fixtures ------------------------------------------------------- - - @pytest.fixture - def cells_fixture(self, _index_, table_prop_, table_): - row = _Row(None, None) - _index_.return_value = row_idx = 6 - expected_cells = (1, 2, 3) - table_.row_cells.return_value = list(expected_cells) - return row, row_idx, expected_cells - - @pytest.fixture( - params=[ + """Unit-test suite for `docx.table._Row` objects.""" + + @pytest.mark.parametrize( + ("tr_cxml", "expected_value"), + [ + ("w:tr", 0), + ("w:tr/w:trPr", 0), + ("w:tr/w:trPr/w:gridAfter{w:val=0}", 0), + ("w:tr/w:trPr/w:gridAfter{w:val=4}", 4), + ], + ) + def it_knows_its_grid_cols_after(self, tr_cxml: str, expected_value: int | None, parent_: Mock): + row = _Row(cast(CT_Row, element(tr_cxml)), parent_) + assert row.grid_cols_after == expected_value + + @pytest.mark.parametrize( + ("tr_cxml", "expected_value"), + [ + ("w:tr", 0), + ("w:tr/w:trPr", 0), + ("w:tr/w:trPr/w:gridBefore{w:val=0}", 0), + ("w:tr/w:trPr/w:gridBefore{w:val=3}", 3), + ], + ) + def it_knows_its_grid_cols_before( + self, tr_cxml: str, expected_value: int | None, parent_: Mock + ): + row = _Row(cast(CT_Row, 
element(tr_cxml)), parent_) + assert row.grid_cols_before == expected_value + + @pytest.mark.parametrize( + ("tr_cxml", "expected_value"), + [ ("w:tr", None), ("w:tr/w:trPr", None), ("w:tr/w:trPr/w:trHeight", None), ("w:tr/w:trPr/w:trHeight{w:val=0}", 0), ("w:tr/w:trPr/w:trHeight{w:val=1440}", 914400), - ] + ], ) - def height_get_fixture(self, request): - tr_cxml, expected_height = request.param - row = _Row(element(tr_cxml), None) - return row, expected_height + def it_knows_its_height(self, tr_cxml: str, expected_value: int | None, parent_: Mock): + row = _Row(cast(CT_Row, element(tr_cxml)), parent_) + assert row.height == expected_value - @pytest.fixture( - params=[ + @pytest.mark.parametrize( + ("tr_cxml", "new_value", "expected_cxml"), + [ ("w:tr", Inches(1), "w:tr/w:trPr/w:trHeight{w:val=1440}"), ("w:tr/w:trPr", Inches(1), "w:tr/w:trPr/w:trHeight{w:val=1440}"), ("w:tr/w:trPr/w:trHeight", Inches(1), "w:tr/w:trPr/w:trHeight{w:val=1440}"), @@ -806,16 +715,18 @@ def height_get_fixture(self, request): ("w:tr", None, "w:tr/w:trPr"), ("w:tr/w:trPr", None, "w:tr/w:trPr"), ("w:tr/w:trPr/w:trHeight", None, "w:tr/w:trPr/w:trHeight"), - ] + ], ) - def height_set_fixture(self, request): - tr_cxml, new_value, expected_cxml = request.param - row = _Row(element(tr_cxml), None) - expected_xml = xml(expected_cxml) - return row, new_value, expected_xml - - @pytest.fixture( - params=[ + def it_can_change_its_height( + self, tr_cxml: str, new_value: Length | None, expected_cxml: str, parent_: Mock + ): + row = _Row(cast(CT_Row, element(tr_cxml)), parent_) + row.height = new_value + assert row._tr.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("tr_cxml", "expected_value"), + [ ("w:tr", None), ("w:tr/w:trPr", None), ("w:tr/w:trPr/w:trHeight{w:val=0, w:hRule=auto}", WD_ROW_HEIGHT.AUTO), @@ -827,15 +738,17 @@ def height_set_fixture(self, request): "w:tr/w:trPr/w:trHeight{w:val=2880, w:hRule=exact}", WD_ROW_HEIGHT.EXACTLY, ), - ] + ], ) - def height_rule_get_fixture(self, 
request): - tr_cxml, expected_rule = request.param - row = _Row(element(tr_cxml), None) - return row, expected_rule - - @pytest.fixture( - params=[ + def it_knows_its_height_rule( + self, tr_cxml: str, expected_value: WD_ROW_HEIGHT | None, parent_: Mock + ): + row = _Row(cast(CT_Row, element(tr_cxml)), parent_) + assert row.height_rule == expected_value + + @pytest.mark.parametrize( + ("tr_cxml", "new_value", "expected_cxml"), + [ ("w:tr", WD_ROW_HEIGHT.AUTO, "w:tr/w:trPr/w:trHeight{w:hRule=auto}"), ( "w:tr/w:trPr", @@ -860,143 +773,148 @@ def height_rule_get_fixture(self, request): ("w:tr", None, "w:tr/w:trPr"), ("w:tr/w:trPr", None, "w:tr/w:trPr"), ("w:tr/w:trPr/w:trHeight", None, "w:tr/w:trPr/w:trHeight"), - ] + ], + ) + def it_can_change_its_height_rule( + self, tr_cxml: str, new_value: WD_ROW_HEIGHT | None, expected_cxml: str, parent_: Mock + ): + row = _Row(cast(CT_Row, element(tr_cxml)), parent_) + row.height_rule = new_value + assert row._tr.xml == xml(expected_cxml) + + @pytest.mark.parametrize( + ("tbl_cxml", "row_idx", "expected_len"), + [ + # -- cell corresponds to single layout-grid cell -- + ("w:tbl/w:tr/w:tc/w:p", 0, 1), + # -- cell has a horizontal span -- + ("w:tbl/w:tr/w:tc/(w:tcPr/w:gridSpan{w:val=2},w:p)", 0, 2), + # -- cell is in latter row of vertical span -- + ( + "w:tbl/(w:tr/w:tc/(w:tcPr/w:vMerge{w:val=restart},w:p)," + "w:tr/w:tc/(w:tcPr/w:vMerge,w:p))", + 1, + 1, + ), + # -- cell both has horizontal span and is latter row of vertical span -- + ( + "w:tbl/(w:tr/w:tc/(w:tcPr/(w:gridSpan{w:val=2},w:vMerge{w:val=restart}),w:p)," + "w:tr/w:tc/(w:tcPr/(w:gridSpan{w:val=2},w:vMerge),w:p))", + 1, + 2, + ), + ], ) - def height_rule_set_fixture(self, request): - tr_cxml, new_rule, expected_cxml = request.param - row = _Row(element(tr_cxml), None) - expected_xml = xml(expected_cxml) - return row, new_rule, expected_xml + def it_provides_access_to_its_cells( + self, tbl_cxml: str, row_idx: int, expected_len: int, parent_: Mock + ): + tbl = 
cast(CT_Tbl, element(tbl_cxml)) + tr = tbl.tr_lst[row_idx] + table = Table(tbl, parent_) + row = _Row(tr, table) - @pytest.fixture - def idx_fixture(self): - tbl = element("w:tbl/(w:tr,w:tr,w:tr)") - tr, expected_idx = tbl[1], 1 - row = _Row(tr, None) - return row, expected_idx + cells = row.cells - @pytest.fixture - def table_fixture(self, parent_, table_): - row = _Row(None, parent_) + assert len(cells) == expected_len + assert all(type(c) is _Cell for c in cells) + + def it_provides_access_to_the_table_it_belongs_to(self, parent_: Mock, table_: Mock): parent_.table = table_ - return row, table_ + row = _Row(cast(CT_Row, element("w:tr")), parent_) + assert row.table is table_ - # fixture components --------------------------------------------- + def it_knows_its_index_in_table_to_help(self, parent_: Mock): + tbl = element("w:tbl/(w:tr,w:tr,w:tr)") + row = _Row(cast(CT_Row, tbl[1]), parent_) + assert row._index == 1 + + # fixtures ------------------------------------------------------- @pytest.fixture - def _index_(self, request): + def _index_prop_(self, request: FixtureRequest): return property_mock(request, _Row, "_index") @pytest.fixture - def parent_(self, request): + def parent_(self, request: FixtureRequest): return instance_mock(request, Table) @pytest.fixture - def table_(self, request): + def table_(self, request: FixtureRequest): return instance_mock(request, Table) @pytest.fixture - def table_prop_(self, request, table_): - return property_mock(request, _Row, "table", return_value=table_) + def table_prop_(self, request: FixtureRequest, table_: Mock): + return property_mock(request, _Row, "table") class Describe_Rows: - def it_knows_how_many_rows_it_contains(self, rows_fixture): - rows, row_count = rows_fixture - assert len(rows) == row_count - - def it_can_iterate_over_its__Row_instances(self, rows_fixture): - rows, row_count = rows_fixture - actual_count = 0 - for row in rows: - assert isinstance(row, _Row) - actual_count += 1 - assert actual_count 
== row_count + """Unit-test suite for `docx.table._Rows` objects.""" + + @pytest.mark.parametrize( + ("tbl_cxml", "expected_len"), + [ + ("w:tbl", 0), + ("w:tbl/w:tr", 1), + ("w:tbl/(w:tr,w:tr)", 2), + ("w:tbl/(w:tr,w:tr,w:tr)", 3), + ], + ) + def it_has_sequence_behaviors(self, tbl_cxml: str, expected_len: int, parent_: Mock): + tbl = cast(CT_Tbl, element(tbl_cxml)) + table = Table(tbl, parent_) + rows = _Rows(tbl, table) + + # -- it supports len() -- + assert len(rows) == expected_len + # -- it is iterable -- + assert len(tuple(r for r in rows)) == expected_len + assert all(type(r) is _Row for r in rows) + # -- it is indexable -- + assert all(type(rows[i]) is _Row for i in range(expected_len)) + + @pytest.mark.parametrize( + ("tbl_cxml", "out_of_range_idx"), + [ + ("w:tbl", 0), + ("w:tbl", 1), + ("w:tbl", -1), + ("w:tbl/w:tr", 1), + ("w:tbl/w:tr", -2), + ("w:tbl/(w:tr,w:tr,w:tr)", 3), + ("w:tbl/(w:tr,w:tr,w:tr)", -4), + ], + ) + def it_raises_on_indexed_access_out_of_range( + self, tbl_cxml: str, out_of_range_idx: int, parent_: Mock + ): + rows = _Rows(cast(CT_Tbl, element(tbl_cxml)), parent_) - def it_provides_indexed_access_to_rows(self, rows_fixture): - rows, row_count = rows_fixture - for idx in range(-row_count, row_count): - row = rows[idx] - assert isinstance(row, _Row) + with pytest.raises(IndexError, match="list index out of range"): + rows[out_of_range_idx] + + @pytest.mark.parametrize(("start", "end", "expected_len"), [(1, 3, 2), (0, -1, 2)]) + def it_provides_sliced_access_to_rows( + self, start: int, end: int, expected_len: int, parent_: Mock + ): + tbl = cast(CT_Tbl, element("w:tbl/(w:tr,w:tr,w:tr)")) + rows = _Rows(tbl, parent_) - def it_provides_sliced_access_to_rows(self, slice_fixture): - rows, start, end, expected_count = slice_fixture slice_of_rows = rows[start:end] - assert len(slice_of_rows) == expected_count - tr_lst = rows._tbl.tr_lst + + assert len(slice_of_rows) == expected_len for idx, row in enumerate(slice_of_rows): - assert 
tr_lst.index(row._tr) == start + idx + assert tbl.tr_lst.index(row._tr) == start + idx assert isinstance(row, _Row) - def it_raises_on_indexed_access_out_of_range(self, rows_fixture): - rows, row_count = rows_fixture - too_low = -1 - row_count - too_high = row_count + def it_provides_access_to_the_table_it_belongs_to(self, parent_: Mock): + tbl = cast(CT_Tbl, element("w:tbl")) + table = Table(tbl, parent_) + rows = _Rows(tbl, table) - with pytest.raises(IndexError, match="list index out of range"): - rows[too_low] - with pytest.raises(IndexError, match="list index out of range"): - rows[too_high] - - def it_provides_access_to_the_table_it_belongs_to(self, table_fixture): - rows, table_ = table_fixture - assert rows.table is table_ + assert rows.table is table # fixtures ------------------------------------------------------- @pytest.fixture - def rows_fixture(self): - row_count = 2 - tbl = _tbl_bldr(rows=row_count, cols=2).element - rows = _Rows(tbl, None) - return rows, row_count - - @pytest.fixture( - params=[ - (3, 1, 3, 2), - (3, 0, -1, 2), - ] - ) - def slice_fixture(self, request): - row_count, start, end, expected_count = request.param - tbl = _tbl_bldr(rows=row_count, cols=2).element - rows = _Rows(tbl, None) - return rows, start, end, expected_count - - @pytest.fixture - def table_fixture(self, table_): - rows = _Rows(None, table_) - table_.table = table_ - return rows, table_ - - # fixture components --------------------------------------------- - - @pytest.fixture - def table_(self, request): - return instance_mock(request, Table) - - -# fixtures ----------------------------------------------------------- - - -def _tbl_bldr(rows, cols): - tblGrid_bldr = a_tblGrid() - for i in range(cols): - tblGrid_bldr.with_child(a_gridCol()) - tbl_bldr = a_tbl().with_nsdecls().with_child(tblGrid_bldr) - for i in range(rows): - tr_bldr = _tr_bldr(cols) - tbl_bldr.with_child(tr_bldr) - return tbl_bldr - - -def _tc_bldr(): - return a_tc().with_child(a_p()) - - -def 
_tr_bldr(cols): - tr_bldr = a_tr() - for i in range(cols): - tc_bldr = _tc_bldr() - tr_bldr.with_child(tc_bldr) - return tr_bldr + def parent_(self, request: FixtureRequest): + return instance_mock(request, Document) diff --git a/tests/unitutil/cxml.py b/tests/unitutil/cxml.py index c7b7d172c..e76cabd74 100644 --- a/tests/unitutil/cxml.py +++ b/tests/unitutil/cxml.py @@ -89,7 +89,7 @@ def from_token(cls, token): Return an ``Element`` object constructed from a parser element token. """ tagname = token.tagname - attrs = [(name, value) for name, value in token.attr_list] + attrs = [tuple(a) for a in token.attr_list] text = token.text return cls(tagname, attrs, text) @@ -263,9 +263,7 @@ def grammar(): child_node_list << (open_paren + delimitedList(node) + close_paren | node) root_node = ( - element("element") - + Group(Optional(slash + child_node_list))("child_node_list") - + stringEnd + element("element") + Group(Optional(slash + child_node_list))("child_node_list") + stringEnd ).setParseAction(connect_root_node_children) return root_node diff --git a/tox.ini b/tox.ini index 1c4e3aea7..37acaa5fa 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py37, py38, py39, py310, py311 +envlist = py38, py39, py310, py311, py312 [testenv] deps = -rrequirements-test.txt