Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,9 @@ Changelog
- |FIX| :func:`utils.multiclass.type_of_target` now properly handles sparse matrices.
:pr:`14862` by :user:`Léonard Binet <leonardbinet>`.

- |Fix| HTML representation no longer errors when an estimator class is a value in
`get_params`. :pr:`24512` by `Thomas Fan`_.

- |Feature| A new module exposes development tools to discover estimators (i.e.
:func:`utils.discovery.all_estimators`), displays (i.e.
:func:`utils.discovery.all_displays`) and functions (i.e.
Expand Down
5 changes: 3 additions & 2 deletions sklearn/utils/_estimator_html_repr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import closing
from io import StringIO
from inspect import isclass
from string import Template
import html

Expand Down Expand Up @@ -121,11 +122,11 @@ def _get_visual_block(estimator):
return _VisualBlock("single", estimator, names="None", name_details="None")

# check if estimator looks like a meta estimator wraps estimators
if hasattr(estimator, "get_params"):
if hasattr(estimator, "get_params") and not isclass(estimator):
estimators = [
(key, est)
for key, est in estimator.get_params(deep=False).items()
if hasattr(est, "get_params") and hasattr(est, "fit")
if hasattr(est, "get_params") and hasattr(est, "fit") and not isclass(est)
]
if estimators:
return _VisualBlock(
Expand Down
11 changes: 11 additions & 0 deletions sklearn/utils/tests/test_estimator_html_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,14 @@ def test_invalid_parameters_in_stacking():

html_output = estimator_html_repr(stacker)
assert html.escape(str(stacker)) in html_output


def test_estimator_get_params_return_cls():
"""Check HTML repr works where a value in get_params is a class."""

class MyEstimator:
def get_params(self, deep=False):
return {"inner_cls": LogisticRegression}

est = MyEstimator()
assert "MyEstimator" in estimator_html_repr(est)