diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 3d9924638b69b..2489eaf55bac7 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1569,6 +1569,7 @@ Plotting utils.deprecated utils.estimator_checks.check_estimator utils.estimator_checks.parametrize_with_checks + utils.estimator_html_repr utils.extmath.safe_sparse_dot utils.extmath.randomized_range_finder utils.extmath.randomized_svd diff --git a/doc/modules/compose.rst b/doc/modules/compose.rst index cd29b14b1f081..e7dac0dadc630 100644 --- a/doc/modules/compose.rst +++ b/doc/modules/compose.rst @@ -528,6 +528,31 @@ above example would be:: ('countvectorizer', CountVectorizer(), 'title')]) +.. _visualizing_composite_estimators: + +Visualizing Composite Estimators +================================ + +Estimators can be displayed with an HTML representation when shown in a +jupyter notebook. This can be useful to diagnose or visualize a Pipeline with +many estimators. This visualization is activated by setting the +`display` option in :func:`sklearn.set_config`:: + + >>> from sklearn import set_config + >>> set_config(display='diagram') # doctest: +SKIP + >>> # displays HTML representation in a jupyter context + >>> column_trans # doctest: +SKIP + +An example of the HTML output can be seen in the +**HTML representation of Pipeline** section of +:ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`. +As an alternative, the HTML can be written to a file using +:func:`~sklearn.utils.estimator_html_repr`:: + + >>> from sklearn.utils import estimator_html_repr + >>> with open('my_estimator.html', 'w') as f: # doctest: +SKIP + ... f.write(estimator_html_repr(clf)) + ..
topic:: Examples: * :ref:`sphx_glr_auto_examples_compose_plot_column_transformer.py` diff --git a/doc/whats_new/v0.23.rst b/doc/whats_new/v0.23.rst index 0e149ed03a9fa..1ac63ca473faf 100644 --- a/doc/whats_new/v0.23.rst +++ b/doc/whats_new/v0.23.rst @@ -567,6 +567,9 @@ Changelog :mod:`sklearn.utils` .................... +- |Feature| Adds :func:`utils.estimator_html_repr` for returning an + HTML representation of an estimator. :pr:`14180` by `Thomas Fan`_. + - |Enhancement| improve error message in :func:`utils.validation.column_or_1d`. :pr:`15926` by :user:`Loïc Estève <lesteve>`. @@ -605,6 +608,11 @@ Changelog Miscellaneous ............. +- |MajorFeature| Adds an HTML representation of estimators to be shown in + a jupyter notebook or lab. This visualization is activated by setting the + `display` option in :func:`sklearn.set_config`. :pr:`14180` by + `Thomas Fan`_. + - |Enhancement| ``scikit-learn`` now works with ``mypy`` without errors. :pr:`16726` by `Roman Yurchak`_. diff --git a/examples/compose/plot_column_transformer_mixed_types.py b/examples/compose/plot_column_transformer_mixed_types.py index 1c79c4bb1d607..24fc4d69e35d0 100644 --- a/examples/compose/plot_column_transformer_mixed_types.py +++ b/examples/compose/plot_column_transformer_mixed_types.py @@ -87,6 +87,15 @@ clf.fit(X_train, y_train) print("model score: %.3f" % clf.score(X_test, y_test)) +############################################################################## +# HTML representation of ``Pipeline`` +############################################################################### +# When the ``Pipeline`` is printed out in a jupyter notebook an HTML +# representation of the estimator is displayed as follows: +from sklearn import set_config +set_config(display='diagram') +clf + ############################################################################### # Use ``ColumnTransformer`` by selecting column by data types ############################################################################### diff
--git a/sklearn/_config.py b/sklearn/_config.py index 44eaae1d59012..f183203e13228 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -7,6 +7,7 @@ 'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)), 'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)), 'print_changed_only': True, + 'display': 'text', } @@ -27,7 +28,7 @@ def get_config(): def set_config(assume_finite=None, working_memory=None, - print_changed_only=None): + print_changed_only=None, display=None): """Set global scikit-learn configuration .. versionadded:: 0.19 @@ -59,6 +60,13 @@ def set_config(assume_finite=None, working_memory=None, .. versionadded:: 0.21 + display : {'text', 'diagram'}, optional + If 'diagram', estimators will be displayed as a diagram in a jupyter lab + or notebook context. If 'text', estimators will be displayed as + text. Default is 'text'. + + .. versionadded:: 0.23 + See Also -------- config_context: Context manager for global scikit-learn configuration @@ -70,6 +78,8 @@ def set_config(assume_finite=None, working_memory=None, _global_config['working_memory'] = working_memory if print_changed_only is not None: _global_config['print_changed_only'] = print_changed_only + if display is not None: + _global_config['display'] = display @contextmanager @@ -100,6 +110,13 @@ def config_context(**new_config): .. versionchanged:: 0.23 Default changed from False to True. + display : {'text', 'diagram'}, optional + If 'diagram', estimators will be displayed as a diagram in a jupyter lab + or notebook context. If 'text', estimators will be displayed as + text. Default is 'text'. + + .. versionadded:: 0.23 + Notes ----- All settings, not just those presently modified, will be returned to diff --git a/sklearn/base.py b/sklearn/base.py index bf5ee370aa8f1..666574b491594 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -17,9 +17,11 @@ import numpy as np from .
import __version__ +from ._config import get_config from .utils import _IS_32BIT from .utils.validation import check_X_y from .utils.validation import check_array +from .utils._estimator_html_repr import estimator_html_repr from .utils.validation import _deprecate_positional_args _DEFAULT_TAGS = { @@ -435,6 +437,17 @@ def _validate_data(self, X, y=None, reset=True, return out + def _repr_html_(self): + """HTML representation of estimator""" + return estimator_html_repr(self) + + def _repr_mimebundle_(self, **kwargs): + """Mime bundle used by jupyter kernels to display estimator""" + output = {"text/plain": repr(self)} + if get_config()["display"] == 'diagram': + output["text/html"] = estimator_html_repr(self) + return output + class ClassifierMixin: """Mixin class for all classifiers in scikit-learn.""" diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 2ef8876b0c4e7..f148633021a97 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -15,6 +15,7 @@ from joblib import Parallel, delayed from ..base import clone, TransformerMixin +from ..utils._estimator_html_repr import _VisualBlock from ..pipeline import _fit_transform_one, _transform_one, _name_estimators from ..preprocessing import FunctionTransformer from ..utils import Bunch @@ -637,6 +638,11 @@ def _hstack(self, Xs): Xs = [f.toarray() if sparse.issparse(f) else f for f in Xs] return np.hstack(Xs) + def _sk_visual_block_(self): + names, transformers, name_details = zip(*self.transformers) + return _VisualBlock('parallel', transformers, + names=names, name_details=name_details) + def _check_X(X): """Use check_array only on lists and other non-array-likes / sparse""" diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index a75e9236f1612..73aa55c0575a7 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -13,6 +13,7 @@ from ..base import clone from ..base import 
ClassifierMixin, RegressorMixin, TransformerMixin from ..base import is_classifier, is_regressor +from ..utils._estimator_html_repr import _VisualBlock from ._base import _fit_single_estimator from ._base import _BaseHeterogeneousEnsemble @@ -233,6 +234,14 @@ def predict(self, X, **predict_params): self.transform(X), **predict_params ) + def _sk_visual_block_(self, final_estimator): + names, estimators = zip(*self.estimators) + parallel = _VisualBlock('parallel', estimators, names=names, + dash_wrapped=False) + serial = _VisualBlock('serial', (parallel, final_estimator), + dash_wrapped=False) + return _VisualBlock('serial', [serial]) + class StackingClassifier(ClassifierMixin, _BaseStacking): """Stack of estimators with a final classifier. @@ -496,6 +505,15 @@ def transform(self, X): """ return self._transform(X) + def _sk_visual_block_(self): + # If final_estimator's default changes then this should be + # updated. + if self.final_estimator is None: + final_estimator = LogisticRegression() + else: + final_estimator = self.final_estimator + return super()._sk_visual_block_(final_estimator) + class StackingRegressor(RegressorMixin, _BaseStacking): """Stack of estimators with a final regressor. @@ -665,3 +683,12 @@ def transform(self, X): Prediction outputs for each estimator. """ return self._transform(X) + + def _sk_visual_block_(self): + # If final_estimator's default changes then this should be + # updated. 
+ if self.final_estimator is None: + final_estimator = RidgeCV() + else: + final_estimator = self.final_estimator + return super()._sk_visual_block_(final_estimator) diff --git a/sklearn/ensemble/_voting.py b/sklearn/ensemble/_voting.py index 0ac42407f5998..6a2b5736d8b4e 100644 --- a/sklearn/ensemble/_voting.py +++ b/sklearn/ensemble/_voting.py @@ -32,6 +32,7 @@ from ..utils.validation import column_or_1d from ..utils.validation import _deprecate_positional_args from ..exceptions import NotFittedError +from ..utils._estimator_html_repr import _VisualBlock class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble): @@ -104,6 +105,10 @@ def n_features_in_(self): return self.estimators_[0].n_features_in_ + def _sk_visual_block_(self): + names, estimators = zip(*self.estimators) + return _VisualBlock('parallel', estimators, names=names) + class VotingClassifier(ClassifierMixin, _BaseVoting): """Soft Voting/Majority Rule classifier for unfitted estimators. diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 8e2a539786557..6f02cb565e15c 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -18,6 +18,7 @@ from joblib import Parallel, delayed from .base import clone, TransformerMixin +from .utils._estimator_html_repr import _VisualBlock from .utils.metaestimators import if_delegate_has_method from .utils import Bunch, _print_elapsed_time from .utils.validation import check_memory @@ -623,6 +624,21 @@ def n_features_in_(self): # delegate to first step (which will call _check_is_fitted) return self.steps[0][1].n_features_in_ + def _sk_visual_block_(self): + _, estimators = zip(*self.steps) + + def _get_name(name, est): + if est is None or est == 'passthrough': + return f'{name}: passthrough' + # Is an estimator + return f'{name}: {est.__class__.__name__}' + names = [_get_name(name, est) for name, est in self.steps] + name_details = [str(est) for est in estimators] + return _VisualBlock('serial', estimators, + names=names, + name_details=name_details, + 
dash_wrapped=False) + def _name_estimators(estimators): """Generate names for estimators.""" @@ -1004,6 +1020,10 @@ def n_features_in_(self): # X is passed to all transformers so we just delegate to the first one return self.transformer_list[0][1].n_features_in_ + def _sk_visual_block_(self): + names, transformers = zip(*self.transformer_list) + return _VisualBlock('parallel', transformers, names=names) + def make_union(*transformers, **kwargs): """ diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index 52f2e60b4af70..e20fa440d1933 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -23,6 +23,7 @@ from sklearn.base import TransformerMixin from sklearn.utils._mocking import MockDataFrame +from sklearn import config_context import pickle @@ -511,3 +512,16 @@ def fit(self, X, y=None): params = est.get_params() assert params['param'] is None + + +def test_repr_mimebundle_(): + # Checks the display configuration flag controls the json output + tree = DecisionTreeClassifier() + output = tree._repr_mimebundle_() + assert "text/plain" in output + assert "text/html" not in output + + with config_context(display='diagram'): + output = tree._repr_mimebundle_() + assert "text/plain" in output + assert "text/html" in output diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index ae13c61838694..eec349861258c 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -4,7 +4,8 @@ def test_config_context(): assert get_config() == {'assume_finite': False, 'working_memory': 1024, - 'print_changed_only': True} + 'print_changed_only': True, + 'display': 'text'} # Not using as a context manager affects nothing config_context(assume_finite=True) @@ -12,7 +13,8 @@ def test_config_context(): with config_context(assume_finite=True): assert get_config() == {'assume_finite': True, 'working_memory': 1024, - 'print_changed_only': True} + 'print_changed_only': True, + 'display': 'text'} assert 
get_config()['assume_finite'] is False with config_context(assume_finite=True): @@ -37,7 +39,8 @@ def test_config_context(): assert get_config()['assume_finite'] is True assert get_config() == {'assume_finite': False, 'working_memory': 1024, - 'print_changed_only': True} + 'print_changed_only': True, + 'display': 'text'} # No positional arguments assert_raises(TypeError, config_context, True) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index afde7614070fd..f814ea11c12c1 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -25,6 +25,7 @@ from ..exceptions import DataConversionWarning from .deprecation import deprecated from .fixes import np_version +from ._estimator_html_repr import estimator_html_repr from .validation import (as_float_array, assert_all_finite, check_random_state, column_or_1d, check_array, @@ -52,7 +53,7 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning" + "DataConversionWarning", "estimator_html_repr" ] IS_PYPY = platform.python_implementation() == 'PyPy' diff --git a/sklearn/utils/_estimator_html_repr.py b/sklearn/utils/_estimator_html_repr.py new file mode 100644 index 0000000000000..9b2e45790fd2b --- /dev/null +++ b/sklearn/utils/_estimator_html_repr.py @@ -0,0 +1,311 @@ +from contextlib import closing +from contextlib import suppress +from io import StringIO +import uuid +import html + +from sklearn import config_context + + +class _VisualBlock: + """HTML Representation of Estimator + + Parameters + ---------- + kind : {'serial', 'parallel', 'single'} + kind of HTML block + + estimators : list of estimators or `_VisualBlock`s or a single estimator + If kind != 'single', then `estimators` is a list of + estimators. + If kind == 'single', then `estimators` is a single estimator. 
+ + names : list of str + If kind != 'single', then `names` corresponds to estimators. + If kind == 'single', then `names` is a single string corresponding to + the single estimator. + + name_details : list of str, str, or None, default=None + If kind != 'single', then `name_details` corresponds to `names`. + If kind == 'single', then `name_details` is a single string + corresponding to the single estimator. + + dash_wrapped : bool, default=True + If true, wrapped HTML element will be wrapped with a dashed border. + Only active when kind != 'single'. + """ + def __init__(self, kind, estimators, *, names=None, name_details=None, + dash_wrapped=True): + self.kind = kind + self.estimators = estimators + self.dash_wrapped = dash_wrapped + + if self.kind in ('parallel', 'serial'): + if names is None: + names = (None, ) * len(estimators) + if name_details is None: + name_details = (None, ) * len(estimators) + + self.names = names + self.name_details = name_details + + def _sk_visual_block_(self): + return self + + +def _write_label_html(out, name, name_details, + outer_class="sk-label-container", + inner_class="sk-label", + checked=False): + """Write labeled html with or without a dropdown with named details""" + out.write(f'
' + f'
') + name = html.escape(name) + + if name_details is not None: + checked_str = 'checked' if checked else '' + est_id = uuid.uuid4() + out.write(f'' + f'' + f'
{name_details}'
+                  f'
') + else: + out.write(f'') + out.write('
') # outer_class inner_class + + +def _get_visual_block(estimator): + """Generate information about how to display an estimator. + """ + with suppress(AttributeError): + return estimator._sk_visual_block_() + + if isinstance(estimator, str): + return _VisualBlock('single', estimator, + names=estimator, name_details=estimator) + elif estimator is None: + return _VisualBlock('single', estimator, + names='None', name_details='None') + + # check if estimator looks like a meta estimator wraps estimators + if hasattr(estimator, 'get_params'): + estimators = [] + for key, value in estimator.get_params().items(): + # Only look at the estimators in the first layer + if '__' not in key and hasattr(value, 'get_params'): + estimators.append(value) + if len(estimators): + return _VisualBlock('parallel', estimators, names=None) + + return _VisualBlock('single', estimator, + names=estimator.__class__.__name__, + name_details=str(estimator)) + + +def _write_estimator_html(out, estimator, estimator_label, + estimator_label_details, first_call=False): + """Write estimator to html in serial, parallel, or by itself (single). + """ + if first_call: + est_block = _get_visual_block(estimator) + else: + with config_context(print_changed_only=True): + est_block = _get_visual_block(estimator) + + if est_block.kind in ('serial', 'parallel'): + dashed_wrapped = first_call or est_block.dash_wrapped + dash_cls = " sk-dashed-wrapped" if dashed_wrapped else "" + out.write(f'
') + + if estimator_label: + _write_label_html(out, estimator_label, estimator_label_details) + + kind = est_block.kind + out.write(f'
') + est_infos = zip(est_block.estimators, est_block.names, + est_block.name_details) + + for est, name, name_details in est_infos: + if kind == 'serial': + _write_estimator_html(out, est, name, name_details) + else: # parallel + out.write('
') + # wrap element in a serial visualblock + serial_block = _VisualBlock('serial', [est], + dash_wrapped=False) + _write_estimator_html(out, serial_block, name, name_details) + out.write('
') # sk-parallel-item + + out.write('
') + elif est_block.kind == 'single': + _write_label_html(out, est_block.names, est_block.name_details, + outer_class="sk-item", inner_class="sk-estimator", + checked=first_call) + + +_STYLE = """ +div.sk-top-container { + color: black; + background-color: white; +} +div.sk-toggleable { + background-color: white; +} +label.sk-toggleable__label { + cursor: pointer; + display: block; + width: 100%; + margin-bottom: 0; + padding: 0.2em 0.3em; + box-sizing: border-box; + text-align: center; +} +div.sk-toggleable__content { + max-height: 0; + max-width: 0; + overflow: hidden; + text-align: left; + background-color: #f0f8ff; +} +div.sk-toggleable__content pre { + margin: 0.2em; + color: black; + border-radius: 0.25em; + background-color: #f0f8ff; +} +input.sk-toggleable__control:checked~div.sk-toggleable__content { + max-height: 200px; + max-width: 100%; + overflow: auto; +} +div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label { + background-color: #d4ebff; +} +input.sk-hidden--visually { + border: 0; + clip: rect(1px 1px 1px 1px); + clip: rect(1px, 1px, 1px, 1px); + height: 1px; + margin: -1px; + overflow: hidden; + padding: 0; + position: absolute; + width: 1px; +} +div.sk-estimator { + font-family: monospace; + background-color: #f0f8ff; + margin: 0.25em 0.25em; + border: 1px dotted black; + border-radius: 0.25em; + box-sizing: border-box; +} +div.sk-estimator:hover { + background-color: #d4ebff; +} +div.sk-parallel-item::after { + content: ""; + width: 100%; + border-bottom: 1px solid gray; + flex-grow: 1; +} +div.sk-label:hover label.sk-toggleable__label { + background-color: #d4ebff; +} +div.sk-serial::before { + content: ""; + position: absolute; + border-left: 1px solid gray; + box-sizing: border-box; + top: 2em; + bottom: 0; + left: 50%; +} +div.sk-serial { + display: flex; + flex-direction: column; + align-items: center; + 
background-color: white; +} +div.sk-item { + z-index: 1; +} +div.sk-parallel { + display: flex; + align-items: stretch; + justify-content: center; + background-color: white; +} +div.sk-parallel-item { + display: flex; + flex-direction: column; + position: relative; + background-color: white; +} +div.sk-parallel-item:first-child::after { + align-self: flex-end; + width: 50%; +} +div.sk-parallel-item:last-child::after { + align-self: flex-start; + width: 50%; +} +div.sk-parallel-item:only-child::after { + width: 0; +} +div.sk-dashed-wrapped { + border: 1px dashed gray; + margin: 0.2em; + box-sizing: border-box; + padding-bottom: 0.1em; + background-color: white; + position: relative; +} +div.sk-label label { + font-family: monospace; + font-weight: bold; + background-color: white; + display: inline-block; + line-height: 1.2em; +} +div.sk-label-container { + position: relative; + z-index: 2; + text-align: center; +} +div.sk-container { + display: inline-block; + position: relative; +} +""".replace(' ', '').replace('\n', '') # noqa + + +def estimator_html_repr(estimator): + """Build a HTML representation of an estimator. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator object + The estimator to visualize. + + Returns + ------- + html: str + HTML representation of estimator. + """ + with closing(StringIO()) as out: + out.write(f'' + f'
') + _write_estimator_html(out, estimator, estimator.__class__.__name__, + str(estimator), first_call=True) + out.write('
') + + html_output = out.getvalue() + return html_output diff --git a/sklearn/utils/tests/test_estimator_html_repr.py b/sklearn/utils/tests/test_estimator_html_repr.py new file mode 100644 index 0000000000000..47d33051bd9a7 --- /dev/null +++ b/sklearn/utils/tests/test_estimator_html_repr.py @@ -0,0 +1,267 @@ +from contextlib import closing +from io import StringIO + +import pytest + +from sklearn import config_context +from sklearn.linear_model import LogisticRegression +from sklearn.neural_network import MLPClassifier +from sklearn.impute import SimpleImputer +from sklearn.decomposition import PCA +from sklearn.decomposition import TruncatedSVD +from sklearn.pipeline import Pipeline +from sklearn.pipeline import FeatureUnion +from sklearn.compose import ColumnTransformer +from sklearn.ensemble import VotingClassifier +from sklearn.feature_selection import SelectPercentile +from sklearn.cluster import Birch +from sklearn.cluster import AgglomerativeClustering +from sklearn.preprocessing import OneHotEncoder +from sklearn.svm import LinearSVC +from sklearn.svm import LinearSVR +from sklearn.tree import DecisionTreeClassifier +from sklearn.multiclass import OneVsOneClassifier +from sklearn.ensemble import StackingClassifier +from sklearn.ensemble import StackingRegressor +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import RationalQuadratic +from sklearn.utils._estimator_html_repr import _write_label_html +from sklearn.utils._estimator_html_repr import _get_visual_block +from sklearn.utils._estimator_html_repr import estimator_html_repr + + +@pytest.mark.parametrize("checked", [True, False]) +def test_write_label_html(checked): + # Test checking logic and labeling + name = "LogisticRegression" + tool_tip = "hello-world" + + with closing(StringIO()) as out: + _write_label_html(out, name, tool_tip, checked=checked) + html_label = out.getvalue() + assert 'LogisticRegression' in html_label + assert 
html_label.startswith('
') + assert '
hello-world
' in html_label + if checked: + assert 'checked>' in html_label + + +@pytest.mark.parametrize('est', ['passthrough', 'drop', None]) +def test_get_visual_block_single_str_none(est): + # Test estimators that are represented by strings + est_html_info = _get_visual_block(est) + assert est_html_info.kind == 'single' + assert est_html_info.estimators == est + assert est_html_info.names == str(est) + assert est_html_info.name_details == str(est) + + +def test_get_visual_block_single_estimator(): + est = LogisticRegression(C=10.0) + est_html_info = _get_visual_block(est) + assert est_html_info.kind == 'single' + assert est_html_info.estimators == est + assert est_html_info.names == est.__class__.__name__ + assert est_html_info.name_details == str(est) + + +def test_get_visual_block_pipeline(): + pipe = Pipeline([ + ('imputer', SimpleImputer()), + ('do_nothing', 'passthrough'), + ('do_nothing_more', None), + ('classifier', LogisticRegression()) + ]) + est_html_info = _get_visual_block(pipe) + assert est_html_info.kind == 'serial' + assert est_html_info.estimators == tuple(step[1] for step in pipe.steps) + assert est_html_info.names == ['imputer: SimpleImputer', + 'do_nothing: passthrough', + 'do_nothing_more: passthrough', + 'classifier: LogisticRegression'] + assert est_html_info.name_details == [str(est) for _, est in pipe.steps] + + +def test_get_visual_block_feature_union(): + f_union = FeatureUnion([ + ('pca', PCA()), ('svd', TruncatedSVD()) + ]) + est_html_info = _get_visual_block(f_union) + assert est_html_info.kind == 'parallel' + assert est_html_info.names == ('pca', 'svd') + assert est_html_info.estimators == tuple( + trans[1] for trans in f_union.transformer_list) + assert est_html_info.name_details == (None, None) + + +def test_get_visual_block_voting(): + clf = VotingClassifier([ + ('log_reg', LogisticRegression()), + ('mlp', MLPClassifier()) + ]) + est_html_info = _get_visual_block(clf) + assert est_html_info.kind == 'parallel' + assert
est_html_info.estimators == tuple(trans[1] + for trans in clf.estimators) + assert est_html_info.names == ('log_reg', 'mlp') + assert est_html_info.name_details == (None, None) + + +def test_get_visual_block_column_transformer(): + ct = ColumnTransformer([ + ('pca', PCA(), ['num1', 'num2']), + ('svd', TruncatedSVD, [0, 3]) + ]) + est_html_info = _get_visual_block(ct) + assert est_html_info.kind == 'parallel' + assert est_html_info.estimators == tuple( + trans[1] for trans in ct.transformers) + assert est_html_info.names == ('pca', 'svd') + assert est_html_info.name_details == (['num1', 'num2'], [0, 3]) + + +def test_estimator_html_repr_pipeline(): + num_trans = Pipeline(steps=[ + ('pass', 'passthrough'), + ('imputer', SimpleImputer(strategy='median')) + ]) + + cat_trans = Pipeline(steps=[ + ('imputer', SimpleImputer(strategy='constant', + missing_values='empty')), + ('one-hot', OneHotEncoder(drop='first')) + ]) + + preprocess = ColumnTransformer([ + ('num', num_trans, ['a', 'b', 'c', 'd', 'e']), + ('cat', cat_trans, [0, 1, 2, 3]) + ]) + + feat_u = FeatureUnion([ + ('pca', PCA(n_components=1)), + ('tsvd', Pipeline([('first', TruncatedSVD(n_components=3)), + ('select', SelectPercentile())])) + ]) + + clf = VotingClassifier([ + ('lr', LogisticRegression(solver='lbfgs', random_state=1)), + ('mlp', MLPClassifier(alpha=0.001)) + ]) + + pipe = Pipeline([ + ('preprocessor', preprocess), ('feat_u', feat_u), ('classifier', clf) + ]) + html_output = estimator_html_repr(pipe) + + # top level estimators show estimator with changes + assert str(pipe) in html_output + for _, est in pipe.steps: + assert (f"
" + f"
{str(est)}") in html_output
+
+    # low level estimators do not show changes
+    with config_context(print_changed_only=True):
+        assert str(num_trans['pass']) in html_output
+        assert 'passthrough' in html_output
+        assert str(num_trans['imputer']) in html_output
+
+        for _, _, cols in preprocess.transformers:
+            assert f"
{cols}
" in html_output + + # feature union + for name, _ in feat_u.transformer_list: + assert f"" in html_output + + pca = feat_u.transformer_list[0][1] + assert f"
{str(pca)}
" in html_output + + tsvd = feat_u.transformer_list[1][1] + first = tsvd['first'] + select = tsvd['select'] + assert f"
{str(first)}
" in html_output + assert f"
{str(select)}
" in html_output + + # voting classifer + for name, est in clf.estimators: + assert f"" in html_output + assert f"
{str(est)}
" in html_output + + +@pytest.mark.parametrize("final_estimator", [None, LinearSVC()]) +def test_stacking_classsifer(final_estimator): + estimators = [('mlp', MLPClassifier(alpha=0.001)), + ('tree', DecisionTreeClassifier())] + clf = StackingClassifier( + estimators=estimators, final_estimator=final_estimator) + + html_output = estimator_html_repr(clf) + + assert str(clf) in html_output + # If final_estimator's default changes from LogisticRegression + # this should be updated + if final_estimator is None: + assert "LogisticRegression(" in html_output + else: + assert final_estimator.__class__.__name__ in html_output + + +@pytest.mark.parametrize("final_estimator", [None, LinearSVR()]) +def test_stacking_regressor(final_estimator): + reg = StackingRegressor( + estimators=[('svr', LinearSVR())], final_estimator=final_estimator) + html_output = estimator_html_repr(reg) + + assert str(reg.estimators[0][0]) in html_output + assert "LinearSVR" in html_output + if final_estimator is None: + assert "RidgeCV" in html_output + else: + assert final_estimator.__class__.__name__ in html_output + + +def test_birch_duck_typing_meta(): + # Test duck typing meta estimators with Birch + birch = Birch(n_clusters=AgglomerativeClustering(n_clusters=3)) + html_output = estimator_html_repr(birch) + + # inner estimators do not show changes + with config_context(print_changed_only=True): + assert f"
{str(birch.n_clusters)}" in html_output
+        assert "AgglomerativeClustering" in html_output
+
+    # outer estimator contains all changes
+    assert f"
{str(birch)}" in html_output
+
+
+def test_ovo_classifier_duck_typing_meta():
+    # Test duck typing metaestimators with OVO
+    ovo = OneVsOneClassifier(LinearSVC(penalty='l1'))
+    html_output = estimator_html_repr(ovo)
+
+    # inner estimators do not show changes
+    with config_context(print_changed_only=True):
+        assert f"
{str(ovo.estimator)}" in html_output
+        assert "LinearSVC" in html_output
+
+    # outer estimator
+    assert f"
{str(ovo)}" in html_output
+
+
+def test_duck_typing_nested_estimator():
+    # Test duck typing metaestimators with GP
+    kernel = RationalQuadratic(length_scale=1.0, alpha=0.1)
+    gp = GaussianProcessRegressor(kernel=kernel)
+    html_output = estimator_html_repr(gp)
+
+    assert f"
{str(kernel)}" in html_output
+    assert f"
{str(gp)}" in html_output
+
+
+@pytest.mark.parametrize('print_changed_only', [True, False])
+def test_one_estimator_print_change_only(print_changed_only):
+    pca = PCA(n_components=10)
+
+    with config_context(print_changed_only=print_changed_only):
+        pca_repr = str(pca)
+        html_output = estimator_html_repr(pca)
+        assert pca_repr in html_output