Unexpected behavior for subclassing Pipeline #30748

Open
schroedk opened this issue Feb 2, 2025 · 2 comments

Comments

schroedk commented Feb 2, 2025

Describe the issue linked to the documentation

Hey, I don't know whether to call this a bug, but at least to me it was unexpected behavior. I tried to subclass Pipeline to implement a customization: a simplified configuration that is used to build a sequence of transformations.

It raises an AttributeError because the instance has no attribute with the same name as a positional argument of the subclass's `__init__` (the same is true for a keyword argument). A minimal example is below.

Is this expected behavior? Setting instance attributes with the same names does no harm, but it is surprising that this is required, and the requirement is very implicit. Also, the error does not show up when the object is instantiated, only when a method is called on it.
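Both effects follow from how parameters are read back, as the traceback further down shows: `BaseEstimator.get_params` takes the parameter names from the `__init__` signature and then calls `getattr(self, name)` for each one. The pure-Python sketch below imitates that lookup in simplified form; `MiniBase`, `StoresArgs`, and `DropsArgs` are hypothetical names for illustration, not scikit-learn's actual code.

```python
import inspect


class MiniBase:
    """Simplified stand-in for BaseEstimator.get_params (not scikit-learn's code)."""

    @classmethod
    def _param_names(cls):
        # Parameter names come from the __init__ signature; `self`, *args and
        # **kwargs are skipped.
        params = inspect.signature(cls.__init__).parameters.values()
        return sorted(
            p.name
            for p in params
            if p.name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        )

    def get_params(self):
        # Each value is read back from the attribute of the same name, so a
        # missing attribute only surfaces here, not at instantiation.
        return {name: getattr(self, name) for name in self._param_names()}


class StoresArgs(MiniBase):
    def __init__(self, column, encode=True):
        # Stores every argument under its own parameter name.
        self.column = column
        self.encode = encode


class DropsArgs(MiniBase):
    def __init__(self, column, encode=True):
        # Uses the arguments but never stores them under their own names.
        self.steps = [column] if encode else []


print(StoresArgs("column").get_params())  # {'column': 'column', 'encode': True}

broken = DropsArgs("column")  # instantiation succeeds
try:
    broken.get_params()
except AttributeError as exc:
    print("AttributeError:", exc)
```

Nothing in `__init__` checks the attributes, which is why the failure is deferred until something asks for the parameters.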

If this is strictly necessary, it may need to be documented.

In addition, I tried to skip parameter validation globally, and that did not help in this situation. Might that be a real bug?

Thanks for your help and for your good work :)

A simple example:

```python
import sklearn
sklearn.set_config(
    skip_parameter_validation=True,  # disable validation
)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.base import BaseEstimator, TransformerMixin
import pandas as pd


class TakeColumn(BaseEstimator, TransformerMixin):
    def __init__(self, column: str):
        self.column = column

    def __str__(self):
        return self.__class__.__name__ + f"[{self.column}]"

    def fit(self, X, y=None):
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        return X[[self.column]]


class CategoricalFeature(Pipeline):
    def __init__(self, column: str, encode=True):

        take_column = TakeColumn(column)
        steps = [(str(take_column), take_column)]

        if encode:
            encoder = OneHotEncoder()
            steps.append((str(encoder), encoder))

        # setting instance attributes with the same names removes the exception
        #self.column = column
        #self.encode = encode

        super().__init__(steps)


df = pd.DataFrame([["a"], ["b"], ["c"]], columns=["column"])

column_feature = CategoricalFeature("column")
some_other_feature = CategoricalFeature("other_column", encode=False)

# this fails if the instance attributes are not set with the same names as
# the corresponding parameters of __init__
result = column_feature.fit_transform(df)
```

Output:

```
Traceback (most recent call last):
  File "<input>", line 1, in <module>
  File "/Users/kristof/Library/Application Support/JetBrains/IntelliJIdea2024.3/plugins/python-ce/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kristof/Library/Application Support/JetBrains/IntelliJIdea2024.3/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/kristof/Projects/pipeline-issue/default_console.py", line 45, in <module>
    result = column_feature.fit_transform(df)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kristof/Projects/pipeline-issue/.venv/lib/python3.12/site-packages/sklearn/base.py", line 1382, in wrapper
    estimator._validate_params()
  File "/Users/kristof/Projects/pipeline-issue/.venv/lib/python3.12/site-packages/sklearn/base.py", line 438, in _validate_params
    self.get_params(deep=False),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kristof/Projects/pipeline-issue/.venv/lib/python3.12/site-packages/sklearn/pipeline.py", line 299, in get_params
    return self._get_params("steps", deep=deep)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kristof/Projects/pipeline-issue/.venv/lib/python3.12/site-packages/sklearn/utils/metaestimators.py", line 30, in _get_params
    out = super().get_params(deep=deep)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/kristof/Projects/pipeline-issue/.venv/lib/python3.12/site-packages/sklearn/base.py", line 248, in get_params
    value = getattr(self, key)
            ^^^^^^^^^^^^^^^^^^
AttributeError: 'CategoricalFeature' object has no attribute 'column'
```
```
System:
    python: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 15:57:01) [Clang 17.0.6 ]
executable: /Users/kristof/Projects/pipeline-issue/.venv/bin/python
   machine: macOS-15.3-arm64-arm-64bit

Python dependencies:
      sklearn: 1.6.1
          pip: None
   setuptools: None
        numpy: 2.2.2
        scipy: 1.15.1
       Cython: None
       pandas: 2.2.3
   matplotlib: None
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
    num_threads: 12
         prefix: libomp
       filepath: /Users/kristof/Projects/pipeline-issue/.venv/lib/python3.12/site-packages/sklearn/.dylibs/libomp.dylib
        version: None
```

Suggest a potential alternative/fix

No response

schroedk added the Documentation and Needs Triage labels Feb 2, 2025
betatim (Member) commented Feb 3, 2025

You are running into what is documented in the "instantiation" section of the guide to developing your own estimator: https://scikit-learn.org/dev/developers/develop.html#instantiation

TL;DR: in an estimator's __init__ you should not perform any work, just store all the (keyword) arguments as instance attributes. All other work, including validation, should happen in fit.
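Applied to the example in the issue, the rule could look like the following sketch. It is an assumption for illustration, not a maintainer-provided fix: instead of subclassing `Pipeline`, this reworked `CategoricalFeature` derives from `TransformerMixin` and `BaseEstimator`, stores `column` and `encode` unchanged in `__init__`, and creates the `OneHotEncoder` only in `fit`, keeping fitted state in an attribute with a trailing underscore as the developer guide describes.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.preprocessing import OneHotEncoder


class CategoricalFeature(TransformerMixin, BaseEstimator):
    def __init__(self, column, encode=True):
        # __init__ only stores the arguments under the parameter names.
        self.column = column
        self.encode = encode

    def fit(self, X, y=None):
        # All real work, including building sub-estimators, happens in fit.
        self.encoder_ = OneHotEncoder() if self.encode else None
        if self.encoder_ is not None:
            self.encoder_.fit(X[[self.column]])
        return self

    def transform(self, X):
        selected = X[[self.column]]
        if self.encoder_ is None:
            return selected
        return self.encoder_.transform(selected)


df = pd.DataFrame([["a"], ["b"], ["c"]], columns=["column"])
feature = CategoricalFeature("column")
print(feature.get_params())  # {'column': 'column', 'encode': True}
print(feature.fit_transform(df).shape)  # (3, 3)
print(clone(feature).set_params(encode=False).fit_transform(df).shape)  # (3, 1)
```

Because the stored attributes match the constructor signature, `get_params`, `set_params`, and `clone` all work, and a change to `encode` made through `set_params` takes effect at the next `fit`.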

betatim removed the Needs Triage label Feb 3, 2025
schroedk (Author) commented Feb 4, 2025

Thanks a lot for the quick answer and the pointer to the documentation. So would you say the best advice is not to inherit from Pipeline at all?
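No reply to this question appears in the thread. One alternative, sketched here as an assumption rather than the maintainers' answer, is a factory function in the spirit of scikit-learn's `make_pipeline`: it builds the steps and returns a plain `Pipeline`, so `get_params` and `clone` only ever see `Pipeline`'s own constructor parameters. `make_categorical_feature` is a hypothetical name; `TakeColumn` mirrors the selector from the issue, with the mixin listed before `BaseEstimator`.

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


class TakeColumn(TransformerMixin, BaseEstimator):
    # Column selector from the issue; it stores its only argument unchanged.
    def __init__(self, column):
        self.column = column

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X[[self.column]]


def make_categorical_feature(column, encode=True):
    # Hypothetical factory: builds a plain Pipeline instead of subclassing it.
    steps = [("take_column", TakeColumn(column))]
    if encode:
        steps.append(("one_hot", OneHotEncoder()))
    return Pipeline(steps)


df = pd.DataFrame([["a"], ["b"], ["c"]], columns=["column"])
feature = make_categorical_feature("column")
print(feature.fit_transform(df).shape)  # (3, 3)
print(clone(feature).get_params()["take_column__column"])  # column
```

The trade-off is that `column` and `encode` are no longer top-level parameters of the result; nested ones such as `take_column__column` can still be changed through `set_params`, but toggling the encoder means calling the factory again.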
