Skip to content

Commit 285883c

Browse files
FIX make sure _PassthroughScorer works with meta-estimators (scikit-learn#31898)
Co-authored-by: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com>
1 parent b5c5130 commit 285883c

File tree

4 files changed

+62
-64
lines changed

4 files changed

+62
-64
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Fixed an issue where passing `sample_weight` to a :class:`Pipeline` inside a
2+
:class:`GridSearchCV` would raise an error with metadata routing enabled.
3+
By `Adrin Jalali`_.

sklearn/linear_model/_ridge.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
_rescale_data,
3131
)
3232
from sklearn.linear_model._sag import sag_solver
33-
from sklearn.metrics import check_scoring, get_scorer_names
33+
from sklearn.metrics import check_scoring, get_scorer, get_scorer_names
3434
from sklearn.model_selection import GridSearchCV
3535
from sklearn.preprocessing import LabelBinarizer
3636
from sklearn.utils import (
@@ -1359,6 +1359,12 @@ def __sklearn_tags__(self):
13591359
tags.classifier_tags.multi_label = True
13601360
return tags
13611361

1362+
def _get_scorer_instance(self):
1363+
"""Return a scorer which corresponds to what's defined in ClassiferMixin
1364+
parent class. This is used for routing `sample_weight`.
1365+
"""
1366+
return get_scorer("accuracy")
1367+
13621368

13631369
class RidgeClassifier(_RidgeClassifierMixin, _BaseRidge):
13641370
"""Classifier using Ridge regression.
@@ -2499,7 +2505,7 @@ def get_metadata_routing(self):
24992505
MetadataRouter(owner=self.__class__.__name__)
25002506
.add_self_request(self)
25012507
.add(
2502-
scorer=self.scoring,
2508+
scorer=self._get_scorer(),
25032509
method_mapping=MethodMapping().add(caller="fit", callee="score"),
25042510
)
25052511
.add(
@@ -2510,14 +2516,20 @@ def get_metadata_routing(self):
25102516
return router
25112517

25122518
def _get_scorer(self):
2513-
scorer = check_scoring(estimator=self, scoring=self.scoring, allow_none=True)
2519+
"""Make sure the scorer is weighted if necessary.
2520+
2521+
This uses `self._get_scorer_instance()` implemented in child objects to get the
2522+
raw scorer instance of the estimator, which will be ignored if `self.scoring` is
2523+
not None.
2524+
"""
25142525
if _routing_enabled() and self.scoring is None:
25152526
# This estimator passes an array of 1s as sample_weight even if
25162527
# sample_weight is not provided by the user. Therefore we need to
25172528
# always request it. But we don't set it if it's passed explicitly
25182529
# by the user.
2519-
scorer.set_score_request(sample_weight=True)
2520-
return scorer
2530+
return self._get_scorer_instance().set_score_request(sample_weight=True)
2531+
2532+
return check_scoring(estimator=self, scoring=self.scoring, allow_none=True)
25212533

25222534
def __sklearn_tags__(self):
25232535
tags = super().__sklearn_tags__()
@@ -2707,6 +2719,12 @@ def fit(self, X, y, sample_weight=None, **params):
27072719
super().fit(X, y, sample_weight=sample_weight, **params)
27082720
return self
27092721

2722+
def _get_scorer_instance(self):
2723+
"""Return a scorer which corresponds to what's defined in RegressorMixin
2724+
parent class. This is used for routing `sample_weight`.
2725+
"""
2726+
return get_scorer("r2")
2727+
27102728

27112729
class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
27122730
"""Ridge classifier with built-in cross-validation.

sklearn/metrics/_scorer.py

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -489,17 +489,6 @@ class _PassthroughScorer(_MetadataRequester):
489489
def __init__(self, estimator):
490490
self._estimator = estimator
491491

492-
requests = MetadataRequest(owner=self.__class__.__name__)
493-
try:
494-
requests.score = copy.deepcopy(estimator._metadata_request.score)
495-
except AttributeError:
496-
try:
497-
requests.score = copy.deepcopy(estimator._get_default_requests().score)
498-
except AttributeError:
499-
pass
500-
501-
self._metadata_request = requests
502-
503492
def __call__(self, estimator, *args, **kwargs):
504493
"""Method that wraps estimator.score"""
505494
return estimator.score(*args, **kwargs)
@@ -525,32 +514,7 @@ def get_metadata_routing(self):
525514
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
526515
routing information.
527516
"""
528-
return get_routing_for_object(self._metadata_request)
529-
530-
def set_score_request(self, **kwargs):
531-
"""Set requested parameters by the scorer.
532-
533-
Please see :ref:`User Guide <metadata_routing>` on how the routing
534-
mechanism works.
535-
536-
.. versionadded:: 1.5
537-
538-
Parameters
539-
----------
540-
kwargs : dict
541-
Arguments should be of the form ``param_name=alias``, and `alias`
542-
can be one of ``{True, False, None, str}``.
543-
"""
544-
if not _routing_enabled():
545-
raise RuntimeError(
546-
"This method is only available when metadata routing is enabled."
547-
" You can enable it using"
548-
" sklearn.set_config(enable_metadata_routing=True)."
549-
)
550-
551-
for param, alias in kwargs.items():
552-
self._metadata_request.score.add_request(param=param, alias=alias)
553-
return self
517+
return get_routing_for_object(self._estimator)
554518

555519

556520
def _check_multimetric_scoring(estimator, scoring):

sklearn/metrics/tests/test_score_objects.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
5353
from sklearn.multiclass import OneVsRestClassifier
5454
from sklearn.neighbors import KNeighborsClassifier
55-
from sklearn.pipeline import make_pipeline
55+
from sklearn.pipeline import Pipeline, make_pipeline
5656
from sklearn.svm import LinearSVC
5757
from sklearn.tests.metadata_routing_common import (
5858
assert_request_is_empty,
@@ -1301,37 +1301,27 @@ def test_metadata_kwarg_conflict():
13011301

13021302
@config_context(enable_metadata_routing=True)
13031303
def test_PassthroughScorer_set_score_request():
1304-
"""Test that _PassthroughScorer.set_score_request adds the correct metadata request
1305-
on itself and doesn't change its estimator's routing."""
1304+
"""Test that _PassthroughScorer.set_score_request raises when routing enabled."""
13061305
est = LogisticRegression().set_score_request(sample_weight="estimator_weights")
13071306
# make a `_PassthroughScorer` with `check_scoring`:
13081307
scorer = check_scoring(est, None)
1309-
assert (
1310-
scorer.get_metadata_routing().score.requests["sample_weight"]
1311-
== "estimator_weights"
1312-
)
1313-
1314-
scorer.set_score_request(sample_weight="scorer_weights")
1315-
assert (
1316-
scorer.get_metadata_routing().score.requests["sample_weight"]
1317-
== "scorer_weights"
1318-
)
1319-
1320-
# making sure changing the passthrough object doesn't affect the estimator.
1321-
assert (
1322-
est.get_metadata_routing().score.requests["sample_weight"]
1323-
== "estimator_weights"
1324-
)
1308+
with pytest.raises(
1309+
AttributeError,
1310+
match="'_PassthroughScorer' object has no attribute 'set_score_request'",
1311+
):
1312+
scorer.set_score_request(sample_weight=True)
13251313

13261314

13271315
def test_PassthroughScorer_set_score_request_raises_without_routing_enabled():
13281316
"""Test that _PassthroughScorer.set_score_request raises if metadata routing is
13291317
disabled."""
13301318
scorer = check_scoring(LogisticRegression(), None)
1331-
msg = "This method is only available when metadata routing is enabled."
13321319

1333-
with pytest.raises(RuntimeError, match=msg):
1334-
scorer.set_score_request(sample_weight="my_weights")
1320+
with pytest.raises(
1321+
AttributeError,
1322+
match="'_PassthroughScorer' object has no attribute 'set_score_request'",
1323+
):
1324+
scorer.set_score_request(sample_weight=True)
13351325

13361326

13371327
@config_context(enable_metadata_routing=True)
@@ -1673,3 +1663,26 @@ def test_make_scorer_reponse_method_default_warning():
16731663
with warnings.catch_warnings():
16741664
warnings.simplefilter("error", FutureWarning)
16751665
make_scorer(accuracy_score)
1666+
1667+
1668+
@config_context(enable_metadata_routing=True)
1669+
def test_Pipeline_in_PassthroughScorer():
1670+
"""Non-regression test for
1671+
https://github.com/scikit-learn/scikit-learn/issues/30937
1672+
1673+
Make sure pipeline inside a gridsearchcv works with sample_weight passed!
1674+
"""
1675+
X, y = make_classification(10, 4)
1676+
sample_weight = np.ones_like(y)
1677+
pipe = Pipeline(
1678+
[
1679+
(
1680+
"logistic",
1681+
LogisticRegression()
1682+
.set_fit_request(sample_weight=True)
1683+
.set_score_request(sample_weight=True),
1684+
)
1685+
]
1686+
)
1687+
search = GridSearchCV(pipe, {"logistic__C": [0.1, 1]}, n_jobs=1, cv=3)
1688+
search.fit(X, y, sample_weight=sample_weight)

0 commit comments

Comments
 (0)