Skip to content

MNT Add flexible_types to assert_docstring_consistency #30943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion sklearn/tests/test_docstring_parameters_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import pytest

from sklearn import metrics
from sklearn.ensemble import StackingClassifier, StackingRegressor
from sklearn.ensemble import (
BaggingClassifier,
BaggingRegressor,
IsolationForest,
StackingClassifier,
StackingRegressor,
)
from sklearn.utils._testing import assert_docstring_consistency, skip_if_no_numpydoc

CLASS_DOCSTRING_CONSISTENCY_CASES = [
Expand All @@ -18,6 +24,17 @@
"exclude_returns": None,
"descr_regex_pattern": None,
},
{
"objects": [BaggingClassifier, BaggingRegressor, IsolationForest],
"include_params": ["max_samples"],
"exclude_params": None,
"include_attrs": False,
"exclude_attrs": None,
"include_returns": False,
"exclude_returns": None,
"descr_regex_pattern": r"The number of samples to draw from X to train each.*",
"flexible_types": {"max_samples": 2},
},
]

FUNCTION_DOCSTRING_CONSISTENCY_CASES = [
Expand Down
61 changes: 57 additions & 4 deletions sklearn/utils/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,12 @@


def _check_consistency_items(
items_docs, type_or_desc, section, n_objects, descr_regex_pattern=""
items_docs,
type_or_desc,
section,
n_objects,
descr_regex_pattern="",
flexible_types={},
):
"""Helper to check docstring consistency of all `items_docs`.

Expand All @@ -700,8 +705,9 @@
# If item not found in all objects, skip
if sum([len(objs) for objs in docstrings_grouped.values()]) < n_objects:
skipped.append(item_name)
continue
# If regex provided, match to all descriptions
elif type_or_desc == "description" and descr_regex_pattern:
if type_or_desc == "description" and descr_regex_pattern:
not_matched = []
for docstring, group in docstrings_grouped.items():
if not re.search(descr_regex_pattern, docstring):
Expand All @@ -712,8 +718,42 @@
f" does not match 'descr_regex_pattern': {descr_regex_pattern} "
)
raise AssertionError(msg)
continue
# If `flexible_types` is provided for this item, ignore the default value and
# check only that the specified minimum number of types is consistent
if type_or_desc == "type specification" and item_name in flexible_types:
n_types = flexible_types[item_name]
docstrings_regrouped = {}
type_sets = []
for types, group in docstrings_grouped.items():
# Split types, remove default
types_split = [
t for t in types.split(", ") if not t.startswith("default")
]
if n_types == 0:
new_types = ", ".join(types_split)
docstrings_regrouped.setdefault(new_types, []).extend(group)

Check warning on line 735 in sklearn/utils/_testing.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/_testing.py#L734-L735

Added lines #L734 - L735 were not covered by tests
else:
# Remove "or " from last type
if types_split[-1].startswith("or "):
types_split[-1] = types_split[-1][3:]

Check warning on line 739 in sklearn/utils/_testing.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/_testing.py#L739

Added line #L739 was not covered by tests
# Create list of type sets
type_sets.append(set(types_split))
# If only removing default, update `docstrings_grouped` and check as usual
if docstrings_regrouped:
docstrings_grouped = docstrings_regrouped

Check warning on line 744 in sklearn/utils/_testing.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/_testing.py#L744

Added line #L744 was not covered by tests
# Check min number of same types, raising if required
else:
type_union = set().union(*types)
if n_common := len(type_union) < n_types:
raise AssertionError(

Check warning on line 749 in sklearn/utils/_testing.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/_testing.py#L749

Added line #L749 was not covered by tests
f"Expected {section[:-1]} '{item_name}' to have at least "
f"{n_types} in common, got {n_common}: ({type_union})"
)
continue

# Otherwise, if more than one key, docstrings not consistent between objects
elif len(docstrings_grouped.keys()) > 1:
if len(docstrings_grouped.keys()) > 1:
msg_diff = _get_diff_msg(docstrings_grouped)
obj_groups = " and ".join(
str(group) for group in docstrings_grouped.values()
Expand All @@ -740,6 +780,7 @@
include_returns=False,
exclude_returns=None,
descr_regex_pattern=None,
flexible_types={},
):
r"""Check consistency between docstring parameters/attributes/returns of objects.

Expand Down Expand Up @@ -788,6 +829,12 @@
parameters/attributes/returns. If None, will revert to default behavior
of comparing descriptions between objects.

flexible_types : dict, default={}
Dictionary mapping a parameter/attribute/return name to an integer indicating
the minimum number of types that should be consistent between the objects. The
default value will also be ignored for that parameter/attribute/return.
Set the value to 0 to ignore the default value only.

Examples
--------
>>> from sklearn.metrics import (accuracy_score, classification_report,
Expand Down Expand Up @@ -851,7 +898,13 @@
type_items[item_name][type_def].append(obj_name)
desc_items[item_name][desc].append(obj_name)

_check_consistency_items(type_items, "type specification", section, n_objects)
_check_consistency_items(
type_items,
"type specification",
section,
n_objects,
flexible_types=flexible_types,
)
_check_consistency_items(
desc_items,
"description",
Expand Down