-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
Closed as not planned
Closed as not planned
Copy link
Labels
Description
This code
from sklearn.model_selection import TunedThresholdClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification
import sklearn
import numpy as np
# Enable metadata routing through scikit-learn's configuration before any
# estimator is built or fitted.
sklearn.set_config(enable_metadata_routing=True)
def my_metric(y_true, y_pred, sample_weight=None):
    """Return the mean of ``y_pred``; used to probe ``sample_weight`` routing.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels. Accepted for signature compatibility; not used.
    y_pred : array-like
        Predicted values whose mean is returned.
    sample_weight : array-like, default=None
        Per-sample weights. They are not used in the computation; the
        function only checks that they were actually passed in.

    Returns
    -------
    float
        ``np.mean(y_pred)``.

    Raises
    ------
    AssertionError
        If ``sample_weight`` is ``None``, i.e. the weights did not reach
        the metric.
    """
    # Deliberate probe: fail loudly when the weights were not forwarded.
    assert sample_weight is not None
    return np.mean(y_pred)
# Generate a classification dataset with a fixed random_state, plus one
# random weight per sample.
X, y = make_classification(random_state=0)
sample_weight = np.random.rand(len(y))
# The plain function `my_metric` is passed directly as `scoring`.
est = TunedThresholdClassifierCV(LogisticRegression(), cv=2, scoring=my_metric)
# This is the call that raises the AttributeError shown in the traceback below.
est.fit(X, y, sample_weight=sample_weight)
gives this:
Traceback (most recent call last):
File "/tmp/2.py", line 17, in <module>
est.fit(X, y, sample_weight=sample_weight)
File "/path/to/scikit-learn/sklearn/base.py", line 1366, in wrapper
return fit_method(estimator, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 129, in fit
self._fit(X, y, **params)
File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 742, in _fit
routed_params = process_routing(self, "fit", **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/scikit-learn/sklearn/utils/_metadata_requests.py", line 1636, in process_routing
request_routing = get_routing_for_object(_obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/scikit-learn/sklearn/utils/_metadata_requests.py", line 1197, in get_routing_for_object
return deepcopy(obj.get_metadata_routing())
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 871, in get_metadata_routing
scorer=self._get_curve_scorer(),
^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/scikit-learn/sklearn/model_selection/_classification_threshold.py", line 880, in _get_curve_scorer
curve_scorer = _CurveScorer.from_scorer(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/scikit-learn/sklearn/metrics/_scorer.py", line 1108, in from_scorer
score_func=scorer._score_func,
^^^^^^^^^^^^^^^^^^
AttributeError: 'function' object has no attribute '_score_func'
cc @glemaitre