Skip to content

Commit 24844a0

Browse files
authored
FIX make scorer.repr work with a partial score_func (scikit-learn#31891)
1 parent a9a7b7d commit 24844a0

File tree

3 files changed

+22
-1
lines changed

3 files changed

+22
-1
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- `repr` on a scorer which has been created with a `partial` `score_func` now correctly
2+
works and uses the `repr` of the given `partial` object.
3+
By `Adrin Jalali`_.

sklearn/metrics/_scorer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ def _cached_call(cache, estimator, response_method, *args, **kwargs):
102102
return result
103103

104104

105+
def _get_func_repr_or_name(func):
106+
"""Returns the name of the function or repr of a partial."""
107+
if isinstance(func, partial):
108+
return repr(func)
109+
110+
return func.__name__
111+
112+
105113
class _MultimetricScorer:
106114
"""Callable for multimetric scoring used to avoid repeated calls
107115
to `predict_proba`, `predict`, and `decision_function`.
@@ -262,7 +270,7 @@ def __repr__(self):
262270
kwargs_string = "".join([f", {k}={v}" for k, v in self._kwargs.items()])
263271

264272
return (
265-
f"make_scorer({self._score_func.__name__}{sign_string}"
273+
f"make_scorer({_get_func_repr_or_name(self._score_func)}{sign_string}"
266274
f"{response_method_string}{kwargs_string})"
267275
)
268276

sklearn/metrics/tests/test_score_objects.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numbers
22
import pickle
3+
import re
34
import warnings
45
from copy import deepcopy
56
from functools import partial
@@ -218,6 +219,15 @@ def test_all_scorers_repr():
218219
repr(get_scorer(name))
219220

220221

222+
def test_repr_partial():
223+
metric = partial(precision_score, pos_label=1)
224+
scorer = make_scorer(metric)
225+
pattern = (
226+
"functools\\.partial\\(<function\\ precision_score\\ at\\ .*>,\\ pos_label=1\\)"
227+
)
228+
assert re.search(pattern, repr(scorer))
229+
230+
221231
def check_scoring_validator_for_single_metric_usecases(scoring_validator):
222232
# Test all branches of single metric usecases
223233
estimator = EstimatorWithFitAndScore()

0 commit comments

Comments
 (0)