Skip to content

ENH Display fitted attributes in HTML representation #31442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
aef3dd4
wip
DeaMariaLeon May 21, 2025
01692a5
found fitted attr and arrays shapes-values
DeaMariaLeon May 22, 2025
43c6c8a
wip
DeaMariaLeon May 22, 2025
af43abe
only works on single
DeaMariaLeon May 26, 2025
308a2b5
clean up
DeaMariaLeon May 26, 2025
549d062
merged
DeaMariaLeon May 26, 2025
c41ccb1
Use dict for attributes in base
DeaMariaLeon May 28, 2025
3cde64b
showing column transformer elements attributes - when parallel
DeaMariaLeon May 30, 2025
c10c2e7
wip
DeaMariaLeon Jun 2, 2025
ba67d5f
fixing columntransformer again
DeaMariaLeon Jun 2, 2025
ddb30c6
remove Fitted attributes table title when not fitted
DeaMariaLeon Jun 2, 2025
a864592
work on css
DeaMariaLeon Jun 2, 2025
43459f3
columnTransformer works with remainder - not with drop
DeaMariaLeon Jun 3, 2025
07a9363
wip
DeaMariaLeon Jun 3, 2025
ba7585a
fixing pipeline
DeaMariaLeon Jun 3, 2025
eb8a89b
tests
DeaMariaLeon Jun 4, 2025
a4b6e64
body-table title
DeaMariaLeon Jun 4, 2025
0672f9d
skip inspect when Pipeline has no steps
DeaMariaLeon Jun 4, 2025
e03daeb
fixing hover
DeaMariaLeon Jun 4, 2025
2e9207d
[doc build] trigger doc
DeaMariaLeon Jun 4, 2025
72a85ea
[doc build] remove transformers_ from fitted attr
DeaMariaLeon Jun 4, 2025
4ed0d08
modified test and transformers_ back
DeaMariaLeon Jun 5, 2025
872f461
adding tests
DeaMariaLeon Jun 5, 2025
cf90e3c
Merge remote-tracking branch 'upstream/main' into fitted-att
DeaMariaLeon Jun 5, 2025
5541d51
[doc build] more tests
DeaMariaLeon Jun 5, 2025
62d4f89
[doc build] trigger doc
DeaMariaLeon Jun 5, 2025
9844ec7
[doc build] format attr row with class default
DeaMariaLeon Jun 6, 2025
20c5920
_column_transformer improved
DeaMariaLeon Jun 10, 2025
77f733e
moved AttrDict to params
DeaMariaLeon Jun 10, 2025
ee4e079
fix value__name__ for arrays
DeaMariaLeon Jun 10, 2025
2496632
Modified attr when None or Sequence
DeaMariaLeon Jun 11, 2025
234e383
changed test_get_fitted_attr_html
DeaMariaLeon Jun 11, 2025
3e765ae
[doc build] trigger doc
DeaMariaLeon Jun 11, 2025
a661db1
[doc build] triggering doc again - maybe
DeaMariaLeon Jun 11, 2025
7ce4499
Merge remote-tracking branch 'upstream/main' into fitted-att
DeaMariaLeon Jun 13, 2025
963f889
[doc build] forgot to trigger doc again
DeaMariaLeon Jun 13, 2025
1c609e3
[doc build] after merge conflict
DeaMariaLeon Jun 19, 2025
e4f3400
removed unused param
DeaMariaLeon Jul 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import re
import warnings
from collections import defaultdict
from collections.abc import Sequence

import numpy as np

Expand All @@ -21,7 +22,7 @@
from .utils._param_validation import validate_parameter_constraints
from .utils._repr_html.base import ReprHTMLMixin, _HTMLDocumentationLinkMixin
from .utils._repr_html.estimator import estimator_html_repr
from .utils._repr_html.params import ParamsDict
from .utils._repr_html.params import AttrsDict, ParamsDict
from .utils._set_output import _SetOutputMixin
from .utils._tags import (
ClassifierTags,
Expand All @@ -35,6 +36,7 @@
from .utils.validation import (
_check_feature_names_in,
_generate_get_feature_names_out,
_is_arraylike_not_scalar,
_is_fitted,
check_array,
check_is_fitted,
Expand Down Expand Up @@ -199,6 +201,49 @@

_html_repr = estimator_html_repr

def _get_fitted_attr_html(self):
"""Get fitted attributes of the estimator."""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
init = getattr(self.__init__, "deprecated_original", self)
if init is object.__init__:
# No explicit constructor to introspect
return []

Check warning on line 211 in sklearn/base.py

View check run for this annotation

Codecov / codecov/patch

sklearn/base.py#L211

Added line #L211 was not covered by tests

# It raises when inspecting an empty Pipeline. So we need
# to check that a Pipeline is not empty.
if hasattr(init, "steps") and not len(init.steps):
return AttrsDict("")

attributes = inspect.getmembers(init)

fitted_attributes = {
name: value
for name, value in attributes
if not name.startswith("_") and name.endswith("_")
}

cleaned_fitted_attr = {
name: "None"
if value is None
else f"{type(value).__name__} of lenght {len(value)}"
for name, value in fitted_attributes.items()
if value is None or isinstance(value, Sequence)
}

arrays_attr = {
name: f"{type(value).__name__} of shape {value.shape}, dtype={value.dtype}"
for name, value in fitted_attributes.items()
if _is_arraylike_not_scalar(value) and hasattr(value, "shape")
}

fitted_attributes = {
key: type(value).__name__ for key, value in fitted_attributes.items()
}
out = fitted_attributes | cleaned_fitted_attr | arrays_attr

return AttrsDict(out)

@classmethod
def _get_param_names(cls):
"""Get parameter names for the estimator"""
Expand Down
39 changes: 23 additions & 16 deletions sklearn/compose/_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,25 +1220,32 @@
return np.hstack(Xs)

def _sk_visual_block_(self):
if isinstance(self.remainder, str) and self.remainder == "drop":
transformers = self.transformers
elif hasattr(self, "_remainder"):
remainder_columns = self._remainder[2]
if (
hasattr(self, "feature_names_in_")
and remainder_columns
and not all(isinstance(col, str) for col in remainder_columns)
):
remainder_columns = self.feature_names_in_[remainder_columns].tolist()
transformers = chain(
self.transformers, [("remainder", self.remainder, remainder_columns)]
transformers = getattr(self, "transformers_", self.transformers)
filtered_transformers = [tr for tr in transformers if "remainder" not in tr]

if not (isinstance(self.remainder, str) and self.remainder == "drop"):
# We can find the columns of remainder only when it's fitted
# because only when it's fitted it has a remainder
if hasattr(self, "_remainder"):
remainder_columns = self._remainder[2]
if (
hasattr(self, "feature_names_in_")
and remainder_columns
and not all(isinstance(col, str) for col in remainder_columns)
):
remainder_columns = self.feature_names_in_[

Check warning on line 1236 in sklearn/compose/_column_transformer.py

View check run for this annotation

Codecov / codecov/patch

sklearn/compose/_column_transformer.py#L1236

Added line #L1236 was not covered by tests
remainder_columns
].tolist()
else:
remainder_columns = ""
filtered_transformers = chain(
filtered_transformers,
[("remainder", self.remainder, remainder_columns)],
)
else:
transformers = chain(self.transformers, [("remainder", self.remainder, "")])
names, filtered_transformers, name_details = zip(*filtered_transformers)

names, transformers, name_details = zip(*transformers)
return _VisualBlock(
"parallel", transformers, names=names, name_details=name_details
"parallel", filtered_transformers, names=names, name_details=name_details
)

def __getitem__(self, key):
Expand Down
6 changes: 4 additions & 2 deletions sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,7 +1565,8 @@ def test_sk_visual_block_remainder_fitted_pandas(remainder):
visual_block = ct._sk_visual_block_()
assert visual_block.names == ("ohe", "remainder")
assert visual_block.name_details == (["col1", "col2"], ["col3", "col4"])
assert visual_block.estimators == (ohe, remainder)
assert isinstance(visual_block.estimators[0], OneHotEncoder)
assert visual_block.estimators[1] == remainder


@pytest.mark.parametrize("remainder", ["passthrough", StandardScaler()])
Expand All @@ -1580,7 +1581,8 @@ def test_sk_visual_block_remainder_fitted_numpy(remainder):
visual_block = ct._sk_visual_block_()
assert visual_block.names == ("scale", "remainder")
assert visual_block.name_details == ([0, 2], [1])
assert visual_block.estimators == (scaler, remainder)
assert isinstance(visual_block.estimators[0], StandardScaler)
assert visual_block.estimators[1] == remainder


@pytest.mark.parametrize("explicit_colname", ["first", "second", 0, 1])
Expand Down
17 changes: 17 additions & 0 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,23 @@ def test_get_params_html():
assert est._get_params_html().non_default == ("empty",)


def test_get_fitted_attr_html():
"""Check the behaviour of the `_get_fitted_attr_html` method."""
est = MyEstimator()
est.n_features_in_ = 2
est._not_a_fitted_attr = "x"
est._not_a_fitted_attr_either_ = "y"

assert est._get_fitted_attr_html() == {"n_features_in_": "int"}

X = np.array([[-1, -1], [-2, -1], [-3, -2]])
pca = PCA().fit(X)

fitted_attr_html = pca._get_fitted_attr_html()
assert len(fitted_attr_html) == 9
assert fitted_attr_html["components_"] == "ndarray of shape (2, 2), dtype=float64"


def make_estimator_with_param(default_value):
class DynamicEstimator(BaseEstimator):
def __init__(self, param=default_value):
Expand Down
2 changes: 1 addition & 1 deletion sklearn/utils/_repr_html/estimator.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function copyToClipboard(text, element) {
return false;
}

document.querySelectorAll('.fa-regular.fa-copy').forEach(function(element) {
document.querySelectorAll('.copy-paste-icon').forEach(function(element) {
const toggleableContent = element.closest('.sk-toggleable__content');
const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';
const paramName = element.parentElement.nextElementSibling.textContent.trim();
Expand Down
24 changes: 23 additions & 1 deletion sklearn/utils/_repr_html/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _sk_visual_block_(self):
def _write_label_html(
out,
params,
attrs,
name,
name_details,
name_caption=None,
Expand All @@ -130,6 +131,9 @@ def _write_label_html(
If estimator has `get_params` method, this is the HTML representation
of the estimator's parameters and their values. When the estimator
does not have `get_params`, it is an empty string.
attrs: str
If estimator is fitted, this is the HTML representation of its
the fitted attributes.
name : str
The label for the estimator. It corresponds either to the estimator class name
for a simple estimator or in the case of a `Pipeline` and `ColumnTransformer`,
Expand Down Expand Up @@ -210,7 +214,7 @@ def _write_label_html(
)

if params:
fmt_str = "".join([fmt_str, f"{params}</div>"])
fmt_str = "".join([fmt_str, f"{params}{attrs}</div>"])
elif name_details and ("Pipeline" not in name):
fmt_str = "".join([fmt_str, f"<pre>{name_details}</pre></div>"])

Expand Down Expand Up @@ -306,6 +310,7 @@ def _write_estimator_html(
The prefix to prepend to parameter names for nested estimators.
For example, in a pipeline this might be "pipeline__stepname__".
"""

if first_call:
est_block = _get_visual_block(estimator)
else:
Expand All @@ -327,12 +332,21 @@ def _write_estimator_html(
estimator, "_get_params_html"
):
params = estimator._get_params_html(deep=False)._repr_html_inner()

else:
params = ""
if (
hasattr(estimator, "_get_fitted_attr_html")
and is_fitted_css_class == "fitted"
):
attrs = estimator._get_fitted_attr_html()._repr_html_inner()
else:
attrs = ""

_write_label_html(
out,
params,
attrs,
estimator_label,
estimator_label_details,
doc_link=doc_link,
Expand Down Expand Up @@ -386,10 +400,18 @@ def _write_estimator_html(
params = estimator._get_params_html()._repr_html_inner()
else:
params = ""
if (
hasattr(estimator, "_get_fitted_attr_html")
and is_fitted_css_class == "fitted"
):
attrs = estimator._get_fitted_attr_html()._repr_html_inner()
else:
attrs = ""

_write_label_html(
out,
params,
attrs,
est_block.names,
est_block.name_details,
est_block.name_caption,
Expand Down
8 changes: 4 additions & 4 deletions sklearn/utils/_repr_html/params.css
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
padding-bottom: 0.3rem;
}

.estimator-table .parameters-table {
.estimator-table .body-table {
margin-left: auto !important;
margin-right: auto !important;
}

.estimator-table .parameters-table tr:nth-child(odd) {
.estimator-table .body-table tr:nth-child(odd) {
background-color: #fff;
}

.estimator-table .parameters-table tr:nth-child(even) {
.estimator-table .body-table tr:nth-child(even) {
background-color: #f6f6f6;
}

.estimator-table .parameters-table tr:hover {
.estimator-table .body-table tr:hover {
background-color: #e0e0e0;
}

Expand Down
53 changes: 52 additions & 1 deletion sklearn/utils/_repr_html/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _params_html_repr(params):
<div class="estimator-table">
<details>
<summary>Parameters</summary>
<table class="parameters-table">
<table class="body-table">
<tbody>
{rows}
</tbody>
Expand Down Expand Up @@ -81,3 +81,54 @@ class ParamsDict(ReprHTMLMixin, UserDict):
def __init__(self, params=None, non_default=tuple()):
super().__init__(params or {})
self.non_default = non_default


def _fitted_attr_html_repr(fitted_attributes):
"""Generate HTML representation of estimator fitted attributes.

Creates an HTML table with fitted attribute names and values
wrapped in a collapsible details element. When attributes are arrays,
shape is shown.
"""

HTML_TEMPLATE = """
<div class="estimator-table">
<details>
<summary>Fitted attributes</summary>
<table class="body-table">
<tbody>
{rows}
</tbody>
</table>
</details>
</div>
"""
ROW_TEMPLATE = """
<tr class="default">
<td>{name}&nbsp;</td>
<td>{value}</td>
</tr>
"""

rows = [
ROW_TEMPLATE.format(name=name, value=value)
for name, value in fitted_attributes.items()
]

return HTML_TEMPLATE.format(rows="\n".join(rows))


class AttrsDict(ReprHTMLMixin, dict):
"""Dictionary-like class to store and provide an HTML representation.

It builds an HTML structure to be used with Jupyter notebooks or similar
environments.

Parameters
----------
fitted_attributes : dict, default=None
Dictionary of fitted attributes and their values. When this is
an array, it includes its size.
"""

_html_repr = _fitted_attr_html_repr
45 changes: 45 additions & 0 deletions sklearn/utils/_repr_html/tests/test_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from sklearn import config_context
from sklearn.utils._repr_html.params import AttrsDict, _fitted_attr_html_repr


def test_fitted_attrs_dict_content():
fitted_attrs = AttrsDict({"a": int, "b": bool})
assert fitted_attrs["a"] == int
assert fitted_attrs["b"] == bool


def test_fitted_attrs_dict_repr_html_():
fitted_attrs = AttrsDict({"a": int, "b": bool})
out = fitted_attrs._repr_html_()
assert "<summary>Fitted attributes</summary>" in out
assert "<td><class 'int'></td>" in out

with config_context(display="text"):
msg = "_repr_html_ is only defined when"
with pytest.raises(AttributeError, match=msg):
fitted_attrs._repr_html_()


def test_fitted_attrs_dict_repr_mimebundle():
fitted_attrs = AttrsDict({"a": int, "b": float})
out = fitted_attrs._repr_mimebundle_()

assert "text/plain" in out
assert "text/html" in out
assert "<summary>Fitted attributes</summary>" in out["text/html"]
assert out["text/plain"] == "{'a': <class 'int'>, 'b': <class 'float'>}"

with config_context(display="text"):
out = fitted_attrs._repr_mimebundle_()
assert "text/plain" in out
assert "text/html" not in out


def test_fitted_attr_html_repr():
out = _fitted_attr_html_repr({"a": int, "b": float})
assert "<summary>Fitted attributes</summary>" in out
assert '<table class="body-table">' in out
assert '<tr class="default">' in out
assert "<class 'float'></td>" in out
Loading