27
27
from ..preprocessing import LabelBinarizer , LabelEncoder
28
28
from ..svm ._base import _fit_liblinear
29
29
from ..utils import (
30
+ Bunch ,
30
31
check_array ,
31
32
check_consistent_length ,
32
33
check_random_state ,
33
34
compute_class_weight ,
34
35
)
35
36
from ..utils ._param_validation import Interval , StrOptions
36
37
from ..utils .extmath import row_norms , softmax
38
+ from ..utils .metadata_routing import (
39
+ MetadataRouter ,
40
+ MethodMapping ,
41
+ _routing_enabled ,
42
+ process_routing ,
43
+ )
37
44
from ..utils .multiclass import check_classification_targets
38
45
from ..utils .optimize import _check_optimize_result , _newton_cg
39
46
from ..utils .parallel import Parallel , delayed
40
- from ..utils .validation import _check_sample_weight , check_is_fitted
47
+ from ..utils .validation import (
48
+ _check_method_params ,
49
+ _check_sample_weight ,
50
+ check_is_fitted ,
51
+ )
41
52
from ._base import BaseEstimator , LinearClassifierMixin , SparseCoefMixin
42
53
from ._glm .glm import NewtonCholeskySolver
43
54
from ._linear_loss import LinearModelLoss
@@ -576,23 +587,25 @@ def _log_reg_scoring_path(
576
587
y ,
577
588
train ,
578
589
test ,
579
- pos_class = None ,
580
- Cs = 10 ,
581
- scoring = None ,
582
- fit_intercept = False ,
583
- max_iter = 100 ,
584
- tol = 1e-4 ,
585
- class_weight = None ,
586
- verbose = 0 ,
587
- solver = "lbfgs" ,
588
- penalty = "l2" ,
589
- dual = False ,
590
- intercept_scaling = 1.0 ,
591
- multi_class = "auto" ,
592
- random_state = None ,
593
- max_squared_sum = None ,
594
- sample_weight = None ,
595
- l1_ratio = None ,
590
+ * ,
591
+ pos_class ,
592
+ Cs ,
593
+ scoring ,
594
+ fit_intercept ,
595
+ max_iter ,
596
+ tol ,
597
+ class_weight ,
598
+ verbose ,
599
+ solver ,
600
+ penalty ,
601
+ dual ,
602
+ intercept_scaling ,
603
+ multi_class ,
604
+ random_state ,
605
+ max_squared_sum ,
606
+ sample_weight ,
607
+ l1_ratio ,
608
+ score_params ,
596
609
):
597
610
"""Computes scores across logistic_regression_path
598
611
@@ -704,6 +717,9 @@ def _log_reg_scoring_path(
704
717
to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
705
718
combination of L1 and L2.
706
719
720
+ score_params : dict
721
+ Parameters to pass to the `score` method of the underlying scorer.
722
+
707
723
Returns
708
724
-------
709
725
coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1)
@@ -784,7 +800,9 @@ def _log_reg_scoring_path(
784
800
if scoring is None :
785
801
scores .append (log_reg .score (X_test , y_test ))
786
802
else :
787
- scores .append (scoring (log_reg , X_test , y_test ))
803
+ score_params = score_params or {}
804
+ score_params = _check_method_params (X = X , params = score_params , indices = test )
805
+ scores .append (scoring (log_reg , X_test , y_test , ** score_params ))
788
806
789
807
return coefs , Cs , np .array (scores ), n_iter
790
808
@@ -1747,7 +1765,7 @@ def __init__(
1747
1765
self .l1_ratios = l1_ratios
1748
1766
1749
1767
@_fit_context (prefer_skip_nested_validation = True )
1750
- def fit (self , X , y , sample_weight = None ):
1768
+ def fit (self , X , y , sample_weight = None , ** params ):
1751
1769
"""Fit the model according to the given training data.
1752
1770
1753
1771
Parameters
@@ -1763,11 +1781,22 @@ def fit(self, X, y, sample_weight=None):
1763
1781
Array of weights that are assigned to individual samples.
1764
1782
If not provided, then each sample is given unit weight.
1765
1783
1784
+ **params : dict
1785
+ Parameters to pass to the underlying splitter and scorer.
1786
+
1787
+ .. versionadded:: 1.4
1788
+
1766
1789
Returns
1767
1790
-------
1768
1791
self : object
1769
1792
Fitted LogisticRegressionCV estimator.
1770
1793
"""
1794
+ if params and not _routing_enabled ():
1795
+ raise ValueError (
1796
+ "params is only supported if enable_metadata_routing=True."
1797
+ " See the User Guide for more information."
1798
+ )
1799
+
1771
1800
solver = _check_solver (self .solver , self .penalty , self .dual )
1772
1801
1773
1802
if self .penalty == "elasticnet" :
@@ -1829,9 +1858,23 @@ def fit(self, X, y, sample_weight=None):
1829
1858
else :
1830
1859
max_squared_sum = None
1831
1860
1861
+ if _routing_enabled ():
1862
+ routed_params = process_routing (
1863
+ obj = self ,
1864
+ method = "fit" ,
1865
+ sample_weight = sample_weight ,
1866
+ other_params = params ,
1867
+ )
1868
+ else :
1869
+ routed_params = Bunch ()
1870
+ routed_params .splitter = Bunch (split = {})
1871
+ routed_params .scorer = Bunch (score = params )
1872
+ if sample_weight is not None :
1873
+ routed_params .scorer .score ["sample_weight" ] = sample_weight
1874
+
1832
1875
# init cross-validation generator
1833
1876
cv = check_cv (self .cv , y , classifier = True )
1834
- folds = list (cv .split (X , y ))
1877
+ folds = list (cv .split (X , y , ** routed_params . splitter . split ))
1835
1878
1836
1879
# Use the label encoded classes
1837
1880
n_classes = len (encoded_labels )
@@ -1898,6 +1941,7 @@ def fit(self, X, y, sample_weight=None):
1898
1941
max_squared_sum = max_squared_sum ,
1899
1942
sample_weight = sample_weight ,
1900
1943
l1_ratio = l1_ratio ,
1944
+ score_params = routed_params .scorer .score ,
1901
1945
)
1902
1946
for label in iter_encoded_labels
1903
1947
for train , test in folds
@@ -2078,7 +2122,7 @@ def fit(self, X, y, sample_weight=None):
2078
2122
2079
2123
return self
2080
2124
2081
- def score (self , X , y , sample_weight = None ):
2125
+ def score (self , X , y , sample_weight = None , ** score_params ):
2082
2126
"""Score using the `scoring` option on the given test data and labels.
2083
2127
2084
2128
Parameters
@@ -2092,15 +2136,74 @@ def score(self, X, y, sample_weight=None):
2092
2136
sample_weight : array-like of shape (n_samples,), default=None
2093
2137
Sample weights.
2094
2138
2139
+ **score_params : dict
2140
+ Parameters to pass to the `score` method of the underlying scorer.
2141
+
2142
+ .. versionadded:: 1.4
2143
+
2095
2144
Returns
2096
2145
-------
2097
2146
score : float
2098
2147
Score of self.predict(X) w.r.t. y.
2099
2148
"""
2100
- scoring = self .scoring or "accuracy"
2101
- scoring = get_scorer (scoring )
2149
+ if score_params and not _routing_enabled ():
2150
+ raise ValueError (
2151
+ "score_params is only supported if enable_metadata_routing=True."
2152
+ " See the User Guide for more information."
2153
+ " https://scikit-learn.org/stable/metadata_routing.html"
2154
+ )
2155
+
2156
+ scoring = self ._get_scorer ()
2157
+ if _routing_enabled ():
2158
+ routed_params = process_routing (
2159
+ obj = self ,
2160
+ method = "score" ,
2161
+ sample_weight = sample_weight ,
2162
+ other_params = score_params ,
2163
+ )
2164
+ else :
2165
+ routed_params = Bunch ()
2166
+ routed_params .scorer = Bunch (score = {})
2167
+ if sample_weight is not None :
2168
+ routed_params .scorer .score ["sample_weight" ] = sample_weight
2169
+
2170
+ return scoring (
2171
+ self ,
2172
+ X ,
2173
+ y ,
2174
+ ** routed_params .scorer .score ,
2175
+ )
2102
2176
2103
- return scoring (self , X , y , sample_weight = sample_weight )
2177
+ def get_metadata_routing (self ):
2178
+ """Get metadata routing of this object.
2179
+
2180
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
2181
+ mechanism works.
2182
+
2183
+ .. versionadded:: 1.4
2184
+
2185
+ Returns
2186
+ -------
2187
+ routing : MetadataRouter
2188
+ A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
2189
+ routing information.
2190
+ """
2191
+
2192
+ router = (
2193
+ MetadataRouter (owner = self .__class__ .__name__ )
2194
+ .add_self_request (self )
2195
+ .add (
2196
+ splitter = self .cv ,
2197
+ method_mapping = MethodMapping ().add (callee = "split" , caller = "fit" ),
2198
+ )
2199
+ .add (
2200
+ scorer = self ._get_scorer (),
2201
+ method_mapping = MethodMapping ()
2202
+ .add (callee = "score" , caller = "score" )
2203
+ .add (callee = "score" , caller = "fit" ),
2204
+ )
2205
+ )
2206
+ return router
2104
2207
2105
2208
def _more_tags (self ):
2106
2209
return {
@@ -2110,3 +2213,10 @@ def _more_tags(self):
2110
2213
),
2111
2214
}
2112
2215
}
2216
+
2217
+ def _get_scorer (self ):
2218
+ """Get the scorer based on the scoring method specified.
2219
+ The default scoring method is `accuracy`.
2220
+ """
2221
+ scoring = self .scoring or "accuracy"
2222
+ return get_scorer (scoring )
0 commit comments