Metadata routing breaks MultiOutputClassifier with an estimator that doesn't support sample_weight in fit #28239

@tvdboom

Describe the bug

With metadata routing enabled, MultiOutputClassifier fails to fit when it wraps an estimator whose fit method does not accept sample_weight, such as LinearDiscriminantAnalysis. The call raises a TypeError even though no sample_weight is passed.

Steps/Code to Reproduce

import sklearn
from sklearn.multioutput import MultiOutputClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.datasets import make_multilabel_classification

sklearn.set_config(enable_metadata_routing=True)

X, y = make_multilabel_classification(n_samples=100, n_features=2, n_classes=2)

MultiOutputClassifier(LinearDiscriminantAnalysis()).fit(X, y)

Expected Results

No error thrown.

Actual Results

Traceback (most recent call last):
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\IPython\core\interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-135dae0c8613>", line 10, in <module>
    MultiOutputClassifier(LinearDiscriminantAnalysis()).fit(X, y)
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\multioutput.py", line 535, in fit
    super().fit(X, Y, sample_weight=sample_weight, **fit_params)
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\base.py", line 1351, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\multioutput.py", line 251, in fit
    routed_params = process_routing(
                    ^^^^^^^^^^^^^^^^
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\utils\_metadata_requests.py", line 1556, in process_routing
    request_routing.validate_metadata(params=kwargs, method=_method)
  File "C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\utils\_metadata_requests.py", line 1060, in validate_metadata
    raise TypeError(
TypeError: MultiOutputClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not requested metadata in any object.

Versions

System:
    python: 3.11.2 (tags/v3.11.2:878ead1, Feb  7 2023, 16:38:35) [MSC v.1934 64 bit (AMD64)]
executable: C:\Users\Mavs\Documents\Python\ATOM\venv311\Scripts\python.exe
   machine: Windows-10-10.0.19045-SP0
Python dependencies:
      sklearn: 1.4.0
          pip: 23.3.2
   setuptools: 68.2.2
        numpy: 1.26.3
        scipy: 1.11.4
       Cython: 3.0.5
       pandas: 2.1.4
   matplotlib: 3.8.2
       joblib: 1.3.2
threadpoolctl: 3.2.0
Built with OpenMP: True
threadpoolctl info:
       user_api: openmp
   internal_api: openmp
    num_threads: 16
         prefix: vcomp
       filepath: C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\sklearn\.libs\vcomp140.dll
        version: None
       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libopenblas
       filepath: C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\numpy.libs\libopenblas64__v0.3.23-293-gc2f4bdbb-gcc_10_3_0-2bde3a66a51006b2b53eb373ff767a3f.dll
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Zen
       user_api: blas
   internal_api: openblas
    num_threads: 16
         prefix: libopenblas
       filepath: C:\Users\Mavs\Documents\Python\ATOM\venv311\Lib\site-packages\scipy.libs\libopenblas_v0.3.20-571-g3dec11c6-gcc_10_3_0-c2315440d6b6cef5037bad648efc8c59.dll
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: Zen
