Error when returning embedded transformers in Jupyter notebook #24545

@gbrookshire

Describe the bug

When a custom transformer stores an estimator class as a parameter (for example VarianceThreshold rather than an instance such as VarianceThreshold()), a TypeError is raised if the object is the last expression in a Jupyter cell. The object itself is created without problems, and nothing fails in a terminal. The error occurs only when Jupyter tries to display the object and scikit-learn builds its HTML representation.

Steps/Code to Reproduce

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_selection import VarianceThreshold

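class EmbeddedTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, transformer):
        self.transformer = transformer

    def fit(self, X=None, y=None):
        return self

    def transform(self, X=None):
        return X

# Run each line below as the last line of its own cell, so Jupyter displays it.
EmbeddedTransformer(VarianceThreshold())  # No error: instance as parameter
EmbeddedTransformer('hello')  # No error: non-estimator parameter
t = EmbeddedTransformer(VarianceThreshold)  # No error: assignment, nothing is displayed
EmbeddedTransformer(VarianceThreshold)  # ERROR: class as parameter, raised while rendering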

Expected Results

No error is thrown, and the HTML representation of the transformer is shown in the Jupyter cell.

Actual Results

Here's the full error traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/IPython/core/formatters.py:973, in MimeBundleFormatter.__call__(self, obj, include, exclude)
    970     method = get_real_method(obj, self.print_method)
    972     if method is not None:
--> 973         return method(include=include, exclude=exclude)
    974     return None
    975 else:

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/base.py:631, in BaseEstimator._repr_mimebundle_(self, **kwargs)
    629 output = {"text/plain": repr(self)}
    630 if get_config()["display"] == "diagram":
--> 631     output["text/html"] = estimator_html_repr(self)
    632 return output

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:410, in estimator_html_repr(estimator)
    396 fallback_msg = (
    397     "In a Jupyter environment, please rerun this cell to show the HTML"
    398     " representation or trust the notebook. <br />On GitHub, the"
    399     " HTML representation is unable to render, please try loading this page"
    400     " with nbviewer.org."
    401 )
    402 out.write(
    403     f"<style>{style_with_id}</style>"
    404     f'<div id="{container_id}" class="sk-top-container">'
   (...)
    408     '<div class="sk-container" hidden>'
    409 )
--> 410 _write_estimator_html(
    411     out,
    412     estimator,
    413     estimator.__class__.__name__,
    414     estimator_str,
    415     first_call=True,
    416 )
    417 out.write("</div></div>")
    419 html_output = out.getvalue()

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:168, in _write_estimator_html(out, estimator, estimator_label, estimator_label_details, first_call)
    166         # wrap element in a serial visualblock
    167         serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
--> 168         _write_estimator_html(out, serial_block, name, name_details)
    169         out.write("</div>")  # sk-parallel-item
    171 out.write("</div></div>")

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:163, in _write_estimator_html(out, estimator, estimator_label, estimator_label_details, first_call)
    161 for est, name, name_details in est_infos:
    162     if kind == "serial":
--> 163         _write_estimator_html(out, est, name, name_details)
    164     else:  # parallel
    165         out.write('<div class="sk-parallel-item">')

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:147, in _write_estimator_html(out, estimator, estimator_label, estimator_label_details, first_call)
    145 else:
    146     with config_context(print_changed_only=True):
--> 147         est_block = _get_visual_block(estimator)
    149 if est_block.kind in ("serial", "parallel"):
    150     dashed_wrapped = first_call or est_block.dash_wrapped

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:120, in _get_visual_block(estimator)
    116 # check if estimator looks like a meta estimator wraps estimators
    117 if hasattr(estimator, "get_params"):
    118     estimators = [
    119         (key, est)
--> 120         for key, est in estimator.get_params(deep=False).items()
    121         if hasattr(est, "get_params") and hasattr(est, "fit")
    122     ]
    123     if estimators:
    124         return _VisualBlock(
    125             "parallel",
    126             [est for _, est in estimators],
    127             names=[f"{key}: {est.__class__.__name__}" for key, est in estimators],
    128             name_details=[str(est) for _, est in estimators],
    129         )

TypeError: get_params() missing 1 required positional argument: 'self'
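
Jupyter then raises a second traceback, this time from the HTML formatter (BaseEstimator._repr_html_inner). From estimator_html_repr onward, its frames are identical to the first:
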
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/IPython/core/formatters.py:343, in BaseFormatter.__call__(self, obj)
    341     method = get_real_method(obj, self.print_method)
    342     if method is not None:
--> 343         return method()
    344     return None
    345 else:

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/base.py:625, in BaseEstimator._repr_html_inner(self)
    620 def _repr_html_inner(self):
    621     """This function is returned by the @property `_repr_html_` to make
    622     `hasattr(estimator, "_repr_html_") return `True` or `False` depending
    623     on `get_config()["display"]`.
    624     """
--> 625     return estimator_html_repr(self)

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:410, in estimator_html_repr(estimator)
    396 fallback_msg = (
    397     "In a Jupyter environment, please rerun this cell to show the HTML"
    398     " representation or trust the notebook. <br />On GitHub, the"
    399     " HTML representation is unable to render, please try loading this page"
    400     " with nbviewer.org."
    401 )
    402 out.write(
    403     f"<style>{style_with_id}</style>"
    404     f'<div id="{container_id}" class="sk-top-container">'
   (...)
    408     '<div class="sk-container" hidden>'
    409 )
--> 410 _write_estimator_html(
    411     out,
    412     estimator,
    413     estimator.__class__.__name__,
    414     estimator_str,
    415     first_call=True,
    416 )
    417 out.write("</div></div>")
    419 html_output = out.getvalue()

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:168, in _write_estimator_html(out, estimator, estimator_label, estimator_label_details, first_call)
    166         # wrap element in a serial visualblock
    167         serial_block = _VisualBlock("serial", [est], dash_wrapped=False)
--> 168         _write_estimator_html(out, serial_block, name, name_details)
    169         out.write("</div>")  # sk-parallel-item
    171 out.write("</div></div>")

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:163, in _write_estimator_html(out, estimator, estimator_label, estimator_label_details, first_call)
    161 for est, name, name_details in est_infos:
    162     if kind == "serial":
--> 163         _write_estimator_html(out, est, name, name_details)
    164     else:  # parallel
    165         out.write('<div class="sk-parallel-item">')

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:147, in _write_estimator_html(out, estimator, estimator_label, estimator_label_details, first_call)
    145 else:
    146     with config_context(print_changed_only=True):
--> 147         est_block = _get_visual_block(estimator)
    149 if est_block.kind in ("serial", "parallel"):
    150     dashed_wrapped = first_call or est_block.dash_wrapped

File ~/anaconda3/envs/prodenv/lib/python3.8/site-packages/sklearn/utils/_estimator_html_repr.py:120, in _get_visual_block(estimator)
    116 # check if estimator looks like a meta estimator wraps estimators
    117 if hasattr(estimator, "get_params"):
    118     estimators = [
    119         (key, est)
--> 120         for key, est in estimator.get_params(deep=False).items()
    121         if hasattr(est, "get_params") and hasattr(est, "fit")
    122     ]
    123     if estimators:
    124         return _VisualBlock(
    125             "parallel",
    126             [est for _, est in estimators],
    127             names=[f"{key}: {est.__class__.__name__}" for key, est in estimators],
    128             name_details=[str(est) for _, est in estimators],
    129         )

TypeError: get_params() missing 1 required positional argument: 'self'

After the tracebacks, Jupyter falls back to the plain-text representation of the object:
EmbeddedTransformer(transformer=<class 'sklearn.feature_selection._variance_threshold.VarianceThreshold'>)
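The last frame of both tracebacks points at the likely cause. _get_visual_block treats any parameter that has both get_params and fit attributes as a nested estimator. A class such as VarianceThreshold passes that hasattr check, so the renderer recurses into it and calls get_params(deep=False) on the class itself, where self is unbound. The sketch below reproduces this pattern without scikit-learn. FakeEstimator, walk, children_naive and children_guarded are hypothetical stand-ins, not scikit-learn code. The isinstance(est, type) guard is one possible fix, not necessarily the one scikit-learn adopted.

```python
# Hypothetical stand-ins (not scikit-learn classes) that mimic the duck-typing
# check from _get_visual_block shown in the traceback.


class FakeEstimator:
    """Minimal estimator-like class exposing get_params and fit."""

    def __init__(self, transformer=None):
        self.transformer = transformer

    def get_params(self, deep=True):
        return {"transformer": self.transformer}

    def fit(self, X=None, y=None):
        return self


def children_naive(estimator):
    # Mirrors the filter in the traceback: anything with get_params and fit
    # counts as a nested estimator, including a class object.
    return [
        (key, est)
        for key, est in estimator.get_params(deep=False).items()
        if hasattr(est, "get_params") and hasattr(est, "fit")
    ]


def children_guarded(estimator):
    # Assumed guard: skip classes so only instances are recursed into.
    return [
        (key, est)
        for key, est in estimator.get_params(deep=False).items()
        if hasattr(est, "get_params")
        and hasattr(est, "fit")
        and not isinstance(est, type)
    ]


def walk(estimator, children):
    """Recursively collect class names of nested estimators."""
    names = [type(estimator).__name__]
    for _, child in children(estimator):
        names.extend(walk(child, children))
    return names


try:
    walk(FakeEstimator(FakeEstimator), children_naive)
except TypeError as exc:
    # The class is recursed into, and FakeEstimator.get_params(deep=False)
    # is called without an instance, so `self` is missing.
    print("naive:", exc)

print("guarded:", walk(FakeEstimator(FakeEstimator), children_guarded))
# → guarded: ['FakeEstimator']
```

In this sketch, instances such as FakeEstimator() still pass the guard and are recursed into. A class-valued parameter is no longer treated as a child estimator.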

Versions

System:
    python: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:18)  [GCC 10.3.0]
executable: /home/ec2-user/anaconda3/envs/prodenv/bin/python
   machine: Linux-4.14.287-148.504.amzn1.x86_64-x86_64-with-glibc2.10

Python dependencies:
      sklearn: 1.1.0
          pip: 22.1.2
   setuptools: 63.2.0
        numpy: 1.22.0
        scipy: 1.8.0
       Cython: 0.29.30
       pandas: 1.4.2
   matplotlib: 3.5.2
       joblib: 1.0.1
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/ec2-user/SageMaker/kernels/prodenv/lib/libopenblasp-r0.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 8

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /home/ec2-user/SageMaker/kernels/prodenv/lib/python3.8/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
    num_threads: 8

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/ec2-user/SageMaker/kernels/prodenv/lib/python3.8/site-packages/scipy.libs/libopenblasp-r0-8b9e111f.3.17.so
        version: 0.3.17
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 8
