From 4f97207eb860bd9d619006866711f291d0bfa66b Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 5 Mar 2025 10:30:39 +1100 Subject: [PATCH] flexibe types --- .../test_docstring_parameters_consistency.py | 19 +++++- sklearn/utils/_testing.py | 61 +++++++++++++++++-- 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/sklearn/tests/test_docstring_parameters_consistency.py b/sklearn/tests/test_docstring_parameters_consistency.py index 73c7ca2655374..989980b579145 100644 --- a/sklearn/tests/test_docstring_parameters_consistency.py +++ b/sklearn/tests/test_docstring_parameters_consistency.py @@ -4,7 +4,13 @@ import pytest from sklearn import metrics -from sklearn.ensemble import StackingClassifier, StackingRegressor +from sklearn.ensemble import ( + BaggingClassifier, + BaggingRegressor, + IsolationForest, + StackingClassifier, + StackingRegressor, +) from sklearn.utils._testing import assert_docstring_consistency, skip_if_no_numpydoc CLASS_DOCSTRING_CONSISTENCY_CASES = [ @@ -18,6 +24,17 @@ "exclude_returns": None, "descr_regex_pattern": None, }, + { + "objects": [BaggingClassifier, BaggingRegressor, IsolationForest], + "include_params": ["max_samples"], + "exclude_params": None, + "include_attrs": False, + "exclude_attrs": None, + "include_returns": False, + "exclude_returns": None, + "descr_regex_pattern": r"The number of samples to draw from X to train each.*", + "flexible_types": {"max_samples": 2}, + }, ] FUNCTION_DOCSTRING_CONSISTENCY_CASES = [ diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index 5028818d0697f..b0db9c5f4fa88 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -688,7 +688,12 @@ def _get_diff_msg(docstrings_grouped): def _check_consistency_items( - items_docs, type_or_desc, section, n_objects, descr_regex_pattern="" + items_docs, + type_or_desc, + section, + n_objects, + descr_regex_pattern="", + flexible_types={}, ): """Helper to check docstring consistency of all `items_docs`. @@ -700,8 +705,9 @@ def _check_consistency_items( # If item not found in all objects, skip if sum([len(objs) for objs in docstrings_grouped.values()]) < n_objects: skipped.append(item_name) + continue # If regex provided, match to all descriptions - elif type_or_desc == "description" and descr_regex_pattern: + if type_or_desc == "description" and descr_regex_pattern: not_matched = [] for docstring, group in docstrings_grouped.items(): if not re.search(descr_regex_pattern, docstring): @@ -712,8 +718,42 @@ def _check_consistency_items( f" does not match 'descr_regex_pattern': {descr_regex_pattern} " ) raise AssertionError(msg) + continue + # If `flexible_types` provided, ignore default, and check only specified + # number of defaults + if type_or_desc == "type specification" and item_name in flexible_types: + n_types = flexible_types[item_name] + docstrings_regrouped = {} + type_sets = [] + for types, group in docstrings_grouped.items(): + # Split types, remove default + types_split = [ + t for t in types.split(", ") if not t.startswith("default") + ] + if n_types == 0: + new_types = ", ".join(types_split) + docstrings_regrouped.setdefault(new_types, []).extend(group) + else: + # Remove "or " from last type + if types_split[-1].startswith("or "): + types_split[-1] = types_split[-1][3:] + # Create list of type sets + type_sets.append(set(types_split)) + # If only removing default, update `docstrings_grouped` and check as usual + if docstrings_regrouped: + docstrings_grouped = docstrings_regrouped + # Check min number of same types, raising if required + else: + type_union = set().union(*types) + if n_common := len(type_union) < n_types: + raise AssertionError( + f"Expected {section[:-1]} '{item_name}' to have at least " + f"{n_types} in common, got {n_common}: ({type_union})" + ) + continue + # Otherwise, if more than one key, docstrings not consistent between objects - elif len(docstrings_grouped.keys()) > 1: + if len(docstrings_grouped.keys()) > 1: msg_diff = _get_diff_msg(docstrings_grouped) obj_groups = " and ".join( str(group) for group in docstrings_grouped.values() @@ -740,6 +780,7 @@ def assert_docstring_consistency( include_returns=False, exclude_returns=None, descr_regex_pattern=None, + flexible_types={}, ): r"""Check consistency between docstring parameters/attributes/returns of objects. @@ -788,6 +829,12 @@ def assert_docstring_consistency( parameters/attributes/returns. If None, will revert to default behavior of comparing descriptions between objects. + flexible_types : dict, default={} + Dictionary parameter/attribute/return name to integer indicating the number + of types that should be consistent between the objects. The default value + will also be ignored for that parameter/attribute/return. + Set value to 0 to ignore default value only. + Examples -------- >>> from sklearn.metrics import (accuracy_score, classification_report, @@ -851,7 +898,13 @@ def _create_args(include, exclude, arg_name, section_name): type_items[item_name][type_def].append(obj_name) desc_items[item_name][desc].append(obj_name) - _check_consistency_items(type_items, "type specification", section, n_objects) + _check_consistency_items( + type_items, + "type specification", + section, + n_objects, + flexible_types=flexible_types, + ) _check_consistency_items( desc_items, "description",