Skip to content

TunedThresholdClassifierCV does not accept func(y_true, y_pred, ...) as a valid scoring #31894

@adrinjalali

Description

@adrinjalali

Running the following code:

from sklearn.model_selection import TunedThresholdClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
import sklearn
import numpy as np

# Turn on metadata routing globally so that fit-time metadata such as
# sample_weight can be routed to sub-objects (here, the scoring function).
sklearn.set_config(enable_metadata_routing=True)

def my_metric(y_true, y_pred, sample_weight=None):
    """Return the mean prediction, failing if ``sample_weight`` was not passed.

    ``y_true`` is part of the metric signature but is not used. The assert
    makes the reproducer fail loudly when ``sample_weight`` is not forwarded
    to the metric.
    """
    weights_were_routed = sample_weight is not None
    assert weights_were_routed
    mean_prediction = np.mean(y_pred)
    return mean_prediction

X, y = make_classification(random_state=0)
# Seed the weights (like make_classification above) so the reproducer is
# deterministic. sample_weight is passed to fit() and is meant to reach
# my_metric through metadata routing.
rng = np.random.RandomState(0)
sample_weight = rng.rand(len(y))

est = TunedThresholdClassifierCV(LogisticRegression(), cv=2, scoring=my_metric)
# Raises AttributeError: 'function' object has no attribute '_score_func'
est.fit(X, y, sample_weight=sample_weight)

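raises the following error. `_CurveScorer.from_scorer` reads `scorer._score_func`, which a plain function such as `my_metric` does not have: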

Traceback (most recent call last):
  File "/tmp/2.py", line 17, in <module>
    est.fit(X, y, sample_weight=sample_weight)
  File "/path/to/scikit-learn/sklearn/base.py", line 1366, in wrapper
    return fit_method(estimator, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 129, in fit
    self._fit(X, y, **params)
  File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 742, in _fit
    routed_params = process_routing(self, "fit", **params)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/scikit-learn/sklearn/utils/_metadata_requests.py", line 1636, in process_routing
    request_routing = get_routing_for_object(_obj)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/scikit-learn/sklearn/utils/_metadata_requests.py", line 1197, in get_routing_for_object
    return deepcopy(obj.get_metadata_routing())
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 871, in get_metadata_routing
    scorer=self._get_curve_scorer(),
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 880, in _get_curve_scorer
    curve_scorer = _CurveScorer.from_scorer(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/scikit-learn/sklearn/metrics/_scorer.py", line 1108, in from_scorer
    score_func=scorer._score_func,
               ^^^^^^^^^^^^^^^^^^
AttributeError: 'function' object has no attribute '_score_func'

cc @glemaitre

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions