diff --git a/maint_tools/test_docstrings.py b/maint_tools/test_docstrings.py index 1a8c8b3297110..2861303c796b4 100644 --- a/maint_tools/test_docstrings.py +++ b/maint_tools/test_docstrings.py @@ -34,15 +34,22 @@ def get_all_methods(): estimators = all_estimators() - for name, estimator in estimators: + for name, Estimator in estimators: if name.startswith("_"): # skip private classes continue - methods = [el for el in dir(estimator) if not el.startswith("_")] + methods = [] + for name in dir(Estimator): + if name.startswith("_"): + continue + method_obj = getattr(Estimator, name) + if (hasattr(method_obj, '__call__') + or isinstance(method_obj, property)): + methods.append(name) methods.append(None) for method in sorted(methods, key=lambda x: str(x)): - yield estimator, method + yield Estimator, method def filter_errors(errors, method): @@ -102,7 +109,16 @@ def repr_errors(res, estimator=None, method: Optional[str] = None) -> str: raise NotImplementedError if estimator is not None: - obj_signature = signature(getattr(estimator, method)) + obj = getattr(estimator, method) + try: + obj_signature = signature(obj) + except TypeError: + # In particular we can't parse the signature of properties + obj_signature = ( + "\nParsing of the method signature failed, " + "possibly because this is a property." + ) + obj_name = estimator.__name__ + "." + method else: obj_signature = "" @@ -110,7 +126,7 @@ def repr_errors(res, estimator=None, method: Optional[str] = None) -> str: msg = "\n\n" + "\n\n".join( [ - res["file"], + str(res["file"]), obj_name + str(obj_signature), res["docstring"], "# Errors", @@ -123,10 +139,10 @@ def repr_errors(res, estimator=None, method: Optional[str] = None) -> str: return msg -@pytest.mark.parametrize("estimator, method", get_all_methods()) -def test_docstring(estimator, method, request): - base_import_path = estimator.__module__ - import_path = [base_import_path, estimator.__name__] +@pytest.mark.parametrize("Estimator, method", get_all_methods()) +def test_docstring(Estimator, method, request): + base_import_path = Estimator.__module__ + import_path = [base_import_path, Estimator.__name__] if method is not None: import_path.append(method) @@ -144,7 +160,7 @@ def test_docstring(estimator, method, request): res["errors"] = list(filter_errors(res["errors"], method)) if res["errors"]: - msg = repr_errors(res, estimator, method) + msg = repr_errors(res, Estimator, method) raise ValueError(msg)