Closed
Description
Describe the bug
In PR #26952, @thomasjpfan hinted that transformers in a pipeline are allowed to be stateless. They are, but only if no other step of the pipeline implements `fit()`.
However, if one of the steps implements `fit()`, the preceding steps are expected to be stateful transformers as well. This is because `_validate_steps()`, which checks the methods of all steps in the pipeline, only runs when `.fit()` is called.
Should `_validate_steps()` be modified to be less strict about stateless transformers? And should it then also be run on `transform()` and all the other possible methods?
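To make the question concrete, here is a minimal pure-Python sketch contrasting the rule quoted in the `TypeError` below ("All intermediate steps should be transformers and implement fit and transform or be the string 'passthrough'") with a hypothetical relaxed rule that only requires `transform()`. The functions `strict_check` and `relaxed_check` are illustrative names, not scikit-learn's actual `_validate_steps()` code, and the strict version is only an approximation built from the error message.

```python
# Hypothetical sketch: approximates the rule quoted in the TypeError.
# This is NOT scikit-learn's implementation of Pipeline._validate_steps().


class DoubleIt:
    """Stateless transformer from the report: it has transform() but no fit()."""

    def transform(self, X, y=None):
        # Pure-Python stand-in for ``2*X`` so the sketch needs no dependencies.
        return [[2 * v for v in row] for row in X]


def _is_passthrough(step):
    # The string "passthrough" is accepted as a no-op intermediate step.
    return isinstance(step, str) and step == "passthrough"


def strict_check(steps):
    """Approximated current rule: intermediate steps need fit and transform."""
    for name, step in steps[:-1]:  # the final step is validated separately
        if _is_passthrough(step):
            continue
        if not (hasattr(step, "fit") and hasattr(step, "transform")):
            raise TypeError(f"step {name!r} must implement fit and transform")


def relaxed_check(steps):
    """Hypothetical relaxed rule: transform alone is enough for a stateless step."""
    for name, step in steps[:-1]:
        if _is_passthrough(step):
            continue
        if not hasattr(step, "transform"):
            raise TypeError(f"step {name!r} must implement transform")


steps = [("double1", DoubleIt()), ("double2", DoubleIt()), ("final", None)]
relaxed_check(steps)  # passes: both intermediate steps implement transform()
try:
    strict_check(steps)
except TypeError as exc:
    print(exc)  # step 'double1' must implement fit and transform
```

Under the relaxed rule, the pipeline's own `fit()` would also have to skip calling `fit` on steps that do not define it and use their `transform()` directly, which is part of what the questions above would need to settle.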
Steps/Code to Reproduce
```python
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LinearRegression
import numpy as np

class DoubleIt:
    def transform(self, X, y=None):
        return 2*X

X = np.array([[1, 2, 3], [4, 5, 6]])
p = Pipeline([
    ('double1', DoubleIt()),
    ('double2', DoubleIt()),
    ('linreg', LinearRegression())  # same results with ('linreg', None), but not with "passthrough"
])
p.fit(X)
```
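Until the validation changes, one possible workaround is to give the stateless transformer a no-op `fit()` that returns `self`, so it satisfies the "implement fit and transform" requirement from the error message. The sketch below assumes scikit-learn and NumPy are installed; the target `y` is made up for illustration and is passed because `LinearRegression.fit` needs a target. scikit-learn's `FunctionTransformer` (in `sklearn.preprocessing`) is another way to wrap a stateless function such as `lambda X: 2 * X`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline


class DoubleIt:
    """Stateless transformer with a no-op fit so Pipeline accepts it."""

    def fit(self, X, y=None):
        # Nothing to learn; returning self keeps the usual fit().transform() chaining.
        return self

    def transform(self, X, y=None):
        return 2 * X


X = np.array([[1, 2, 3], [4, 5, 6]])
y = np.array([1.0, 2.0])  # illustrative target values, not from the report

p = Pipeline([
    ("double1", DoubleIt()),
    ("double2", DoubleIt()),
    ("linreg", LinearRegression()),
])
p.fit(X, y)  # no TypeError: both intermediate steps now implement fit and transform
```

This only sidesteps the check; it does not answer whether `_validate_steps()` itself should accept transformers without `fit()`.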
Expected Results
No error is raised.
Actual Results
```
Traceback (most recent call last):
  File "/home/stefanie/Python/scikit-learn_dev/scikit-learn/::::::::::::::::::::::::::::.py", line 126, in <module>
    p.fit(X)
  File "/home/stefanie/Python/scikit-learn_dev/scikit-learn/sklearn/base.py", line 1215, in wrapper
    return fit_method(estimator, *args, **kwargs)
  File "/home/stefanie/Python/scikit-learn_dev/scikit-learn/sklearn/pipeline.py", line 456, in fit
    Xt = self._fit(X, y, routed_params)
  File "/home/stefanie/Python/scikit-learn_dev/scikit-learn/sklearn/pipeline.py", line 372, in _fit
    self._validate_steps()
  File "/home/stefanie/Python/scikit-learn_dev/scikit-learn/sklearn/pipeline.py", line 242, in _validate_steps
    raise TypeError(
TypeError: All intermediate steps should be transformers and implement fit and transform or be the string 'passthrough' '<__main__.DoubleIt object at 0x7fd4b028f310>' (type <class '__main__.DoubleIt'>) doesn't
```
Versions
```
System:
    python: 3.10.6 (main, Oct 10 2022, 12:43:33) [GCC 9.4.0]
executable: /home/stefanie/.pyenv/versions/3.10.6/envs/scikit-learn_dev/bin/python
   machine: Linux-5.15.0-78-generic-x86_64-with-glibc2.31

Python dependencies:
      sklearn: 1.4.dev0
          pip: 23.2.1
   setuptools: 63.2.0
        numpy: 1.24.3
        scipy: 1.10.1
       Cython: 0.29.34
       pandas: 2.0.1
   matplotlib: 3.7.1
       joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/stefanie/.pyenv/versions/3.10.6/envs/scikit-learn_dev/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-15028c96.3.21.so
        version: 0.3.21
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 8

       user_api: blas
   internal_api: openblas
         prefix: libopenblas
       filepath: /home/stefanie/.pyenv/versions/3.10.6/envs/scikit-learn_dev/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
        version: 0.3.18
threading_layer: pthreads
   architecture: SkylakeX
    num_threads: 8

       user_api: openmp
   internal_api: openmp
         prefix: libgomp
       filepath: /usr/lib/x86_64-linux-gnu/libgomp.so.1.0.0
        version: None
    num_threads: 8
```