Closed
Description
Describe the bug
When combining MultiOutputClassifier with an estimator whose fit method does not accept sample_weight as metadata, such as LinearDiscriminantAnalysis, fitting fails when metadata routing is enabled, even though no sample_weight is passed.
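To confirm which side of the combination lacks the parameter, scikit-learn's has_fit_parameter helper can inspect each estimator's fit signature. This is a quick sketch checked against the 1.4 API; LogisticRegression is not part of the original report and is included only as a contrasting estimator whose fit does accept sample_weight.

```python
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import has_fit_parameter

# has_fit_parameter checks whether estimator.fit's signature has the named parameter.
# LogisticRegression.fit(X, y, sample_weight=None) accepts per-sample weights.
print(has_fit_parameter(LogisticRegression(), "sample_weight"))  # True
# LinearDiscriminantAnalysis.fit only takes X and y.
print(has_fit_parameter(LinearDiscriminantAnalysis(), "sample_weight"))  # False
```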
Steps/Code to Reproduce
import sklearn
from sklearn.multioutput import MultiOutputClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.datasets import make_multilabel_classification
# Metadata routing must be enabled to reproduce the error.
sklearn.set_config(enable_metadata_routing=True)
X, y = make_multilabel_classification(n_samples=100, n_features=2, n_classes=2)
# No sample_weight is passed here, yet fit raises a TypeError.
MultiOutputClassifier(LinearDiscriminantAnalysis()).fit(X, y)
Expected Results
No error thrown.
Actual Results
Traceback (most recent call last):
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\IPython\core\interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-135dae0c8613>", line 10, in <module>
    MultiOutputClassifier(LinearDiscriminantAnalysis()).fit(X, y)
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\multioutput.py", line 535, in fit
    super().fit(X, Y, sample_weight=sample_weight, **fit_params)
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\base.py", line 1351, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\multioutput.py", line 251, in fit
    routed_params = process_routing(
                    ^^^^^^^^^^^^^^^^
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\utils\_metadata_requests.py", line 1556, in process_routing
    request_routing.validate_metadata(params=kwargs, method=_method)
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\utils\_metadata_requests.py", line 1060, in validate_metadata
    raise TypeError(
TypeError: MultiOutputClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not requested metadata in any object.
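The traceback shows MultiOutputClassifier.fit forwarding sample_weight=sample_weight to the parent fit even when it is None, after which validate_metadata rejects the key because no sub-estimator requested it. The toy model below is a minimal sketch of that interaction, assuming the validator treats a present-but-None key as passed. The functions toy_validate_metadata and toy_multioutput_fit, and the drop_none "fix", are hypothetical simplifications for illustration, not scikit-learn's actual code.

```python
def toy_validate_metadata(params, requested):
    """Hypothetical: reject any passed metadata key that no sub-estimator requested."""
    unexpected = set(params) - set(requested)
    if unexpected:
        raise TypeError(
            f"fit got unexpected argument(s) {unexpected}, "
            "which are not requested metadata in any object."
        )


def toy_multioutput_fit(
    sample_weight=None, requested=(), drop_none=False, **fit_params
):
    """Hypothetical: always forward sample_weight, as seen in the traceback."""
    params = dict(fit_params, sample_weight=sample_weight)
    if drop_none:
        # Hypothetical fix: forward only metadata the caller actually provided.
        params = {key: value for key, value in params.items() if value is not None}
    toy_validate_metadata(params, requested)
    return sorted(params)


# LinearDiscriminantAnalysis requests no metadata, so `requested` is empty.
try:
    toy_multioutput_fit()
except TypeError as exc:
    print(exc)  # the None-valued sample_weight key is still rejected
print(toy_multioutput_fit(drop_none=True))  # []
print(toy_multioutput_fit(sample_weight=[1.0], requested=("sample_weight",)))  # ['sample_weight']
```

In this toy model, dropping None-valued entries before validation lets an estimator that requests nothing be fitted, while still rejecting weights the caller explicitly passed but no estimator requested.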
Versions
System:
    python: 3.11.2 (tags/v3.11.2:878ead1, Feb 7 2023, 16:38:35) [MSC v.1934 64 bit (AMD64)]
executable: C:\Users\Mavs\Documents\Python\ATOM\venv311\Scripts\python.exe
   machine: Windows-10-10.0.19045-SP0

Python dependencies:
      sklearn: 1.4.0
          pip: 23.3.2
   setuptools: 68.2.2
        numpy: 1.26.3
        scipy: 1.11.4
       Cython: 3.0.5
       pandas: 2.1.4
   matplotlib: 3.8.2
       joblib: 1.3.2
threadpoolctl: 3.2.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
    num_threads: 16
         prefix: vcomp
       filepath: C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\.libs\vcomp140.dll
        version: None

       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libopenblas
       filepath: C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\numpy.libs\libopenblas64__v0.3.23-293-gc2f4bdbb-gcc_10_3_0-2bde3a66a51006b2b53eb373ff767a3f.dll
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Zen

       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libopenblas
       filepath: C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\scipy.libs\libopenblas_v0.3.20-571-g3dec11c6-gcc_10_3_0-c2315440d6b6cef5037bad648efc8c59.dll
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: Zen