Skip to content

Commit 02d20c1

Browse files
OmarManzoor, Omar Salman, adrinjalali, and glemaitre
authored
ENH Add routing to LogisticRegressionCV (scikit-learn#26525)
Co-authored-by: Omar Salman <omar.salman@arbisoft>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent b004f3a commit 02d20c1

File tree

11 files changed

+390
-60
lines changed

11 files changed

+390
-60
lines changed

doc/whats_new/v1.4.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ Changelog
9494
- |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the
9595
result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao <Charlie-XIAO>`.
9696

97+
:mod:`sklearn.linear_model`
98+
...........................
99+
100+
- |Enhancement| :class:`linear_model.LogisticRegressionCV` now supports
101+
metadata routing. :meth:`linear_model.LogisticRegressionCV.fit` now
102+
accepts ``**params`` which are passed to the underlying splitter and
103+
scorer. :meth:`linear_model.LogisticRegressionCV.score` now accepts
104+
``**score_params`` which are passed to the underlying scorer.
105+
:pr:`26525` by :user:`Omar Salman <OmarManzoor>`.
106+
97107
:mod:`sklearn.pipeline`
98108
.......................
99109

sklearn/calibration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from .utils.multiclass import check_classification_targets
5454
from .utils.parallel import Parallel, delayed
5555
from .utils.validation import (
56-
_check_fit_params,
56+
_check_method_params,
5757
_check_pos_label_consistency,
5858
_check_sample_weight,
5959
_num_samples,
@@ -612,7 +612,7 @@ def _fit_classifier_calibrator_pair(
612612
-------
613613
calibrated_classifier : _CalibratedClassifier instance
614614
"""
615-
fit_params_train = _check_fit_params(X, fit_params, train)
615+
fit_params_train = _check_method_params(X, params=fit_params, indices=train)
616616
X_train, y_train = _safe_indexing(X, train), _safe_indexing(y, train)
617617
X_test, y_test = _safe_indexing(X, test), _safe_indexing(y, test)
618618

sklearn/linear_model/_logistic.py

Lines changed: 135 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,28 @@
2727
from ..preprocessing import LabelBinarizer, LabelEncoder
2828
from ..svm._base import _fit_liblinear
2929
from ..utils import (
30+
Bunch,
3031
check_array,
3132
check_consistent_length,
3233
check_random_state,
3334
compute_class_weight,
3435
)
3536
from ..utils._param_validation import Interval, StrOptions
3637
from ..utils.extmath import row_norms, softmax
38+
from ..utils.metadata_routing import (
39+
MetadataRouter,
40+
MethodMapping,
41+
_routing_enabled,
42+
process_routing,
43+
)
3744
from ..utils.multiclass import check_classification_targets
3845
from ..utils.optimize import _check_optimize_result, _newton_cg
3946
from ..utils.parallel import Parallel, delayed
40-
from ..utils.validation import _check_sample_weight, check_is_fitted
47+
from ..utils.validation import (
48+
_check_method_params,
49+
_check_sample_weight,
50+
check_is_fitted,
51+
)
4152
from ._base import BaseEstimator, LinearClassifierMixin, SparseCoefMixin
4253
from ._glm.glm import NewtonCholeskySolver
4354
from ._linear_loss import LinearModelLoss
@@ -576,23 +587,25 @@ def _log_reg_scoring_path(
576587
y,
577588
train,
578589
test,
579-
pos_class=None,
580-
Cs=10,
581-
scoring=None,
582-
fit_intercept=False,
583-
max_iter=100,
584-
tol=1e-4,
585-
class_weight=None,
586-
verbose=0,
587-
solver="lbfgs",
588-
penalty="l2",
589-
dual=False,
590-
intercept_scaling=1.0,
591-
multi_class="auto",
592-
random_state=None,
593-
max_squared_sum=None,
594-
sample_weight=None,
595-
l1_ratio=None,
590+
*,
591+
pos_class,
592+
Cs,
593+
scoring,
594+
fit_intercept,
595+
max_iter,
596+
tol,
597+
class_weight,
598+
verbose,
599+
solver,
600+
penalty,
601+
dual,
602+
intercept_scaling,
603+
multi_class,
604+
random_state,
605+
max_squared_sum,
606+
sample_weight,
607+
l1_ratio,
608+
score_params,
596609
):
597610
"""Computes scores across logistic_regression_path
598611
@@ -704,6 +717,9 @@ def _log_reg_scoring_path(
704717
to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
705718
combination of L1 and L2.
706719
720+
score_params : dict
721+
Parameters to pass to the `score` method of the underlying scorer.
722+
707723
Returns
708724
-------
709725
coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1)
@@ -784,7 +800,9 @@ def _log_reg_scoring_path(
784800
if scoring is None:
785801
scores.append(log_reg.score(X_test, y_test))
786802
else:
787-
scores.append(scoring(log_reg, X_test, y_test))
803+
score_params = score_params or {}
804+
score_params = _check_method_params(X=X, params=score_params, indices=test)
805+
scores.append(scoring(log_reg, X_test, y_test, **score_params))
788806

789807
return coefs, Cs, np.array(scores), n_iter
790808

@@ -1747,7 +1765,7 @@ def __init__(
17471765
self.l1_ratios = l1_ratios
17481766

17491767
@_fit_context(prefer_skip_nested_validation=True)
1750-
def fit(self, X, y, sample_weight=None):
1768+
def fit(self, X, y, sample_weight=None, **params):
17511769
"""Fit the model according to the given training data.
17521770
17531771
Parameters
@@ -1763,11 +1781,22 @@ def fit(self, X, y, sample_weight=None):
17631781
Array of weights that are assigned to individual samples.
17641782
If not provided, then each sample is given unit weight.
17651783
1784+
**params : dict
1785+
Parameters to pass to the underlying splitter and scorer.
1786+
1787+
.. versionadded:: 1.4
1788+
17661789
Returns
17671790
-------
17681791
self : object
17691792
Fitted LogisticRegressionCV estimator.
17701793
"""
1794+
if params and not _routing_enabled():
1795+
raise ValueError(
1796+
"params is only supported if enable_metadata_routing=True."
1797+
" See the User Guide for more information."
1798+
)
1799+
17711800
solver = _check_solver(self.solver, self.penalty, self.dual)
17721801

17731802
if self.penalty == "elasticnet":
@@ -1829,9 +1858,23 @@ def fit(self, X, y, sample_weight=None):
18291858
else:
18301859
max_squared_sum = None
18311860

1861+
if _routing_enabled():
1862+
routed_params = process_routing(
1863+
obj=self,
1864+
method="fit",
1865+
sample_weight=sample_weight,
1866+
other_params=params,
1867+
)
1868+
else:
1869+
routed_params = Bunch()
1870+
routed_params.splitter = Bunch(split={})
1871+
routed_params.scorer = Bunch(score=params)
1872+
if sample_weight is not None:
1873+
routed_params.scorer.score["sample_weight"] = sample_weight
1874+
18321875
# init cross-validation generator
18331876
cv = check_cv(self.cv, y, classifier=True)
1834-
folds = list(cv.split(X, y))
1877+
folds = list(cv.split(X, y, **routed_params.splitter.split))
18351878

18361879
# Use the label encoded classes
18371880
n_classes = len(encoded_labels)
@@ -1898,6 +1941,7 @@ def fit(self, X, y, sample_weight=None):
18981941
max_squared_sum=max_squared_sum,
18991942
sample_weight=sample_weight,
19001943
l1_ratio=l1_ratio,
1944+
score_params=routed_params.scorer.score,
19011945
)
19021946
for label in iter_encoded_labels
19031947
for train, test in folds
@@ -2078,7 +2122,7 @@ def fit(self, X, y, sample_weight=None):
20782122

20792123
return self
20802124

2081-
def score(self, X, y, sample_weight=None):
2125+
def score(self, X, y, sample_weight=None, **score_params):
20822126
"""Score using the `scoring` option on the given test data and labels.
20832127
20842128
Parameters
@@ -2092,15 +2136,74 @@ def score(self, X, y, sample_weight=None):
20922136
sample_weight : array-like of shape (n_samples,), default=None
20932137
Sample weights.
20942138
2139+
**score_params : dict
2140+
Parameters to pass to the `score` method of the underlying scorer.
2141+
2142+
.. versionadded:: 1.4
2143+
20952144
Returns
20962145
-------
20972146
score : float
20982147
Score of self.predict(X) w.r.t. y.
20992148
"""
2100-
scoring = self.scoring or "accuracy"
2101-
scoring = get_scorer(scoring)
2149+
if score_params and not _routing_enabled():
2150+
raise ValueError(
2151+
"score_params is only supported if enable_metadata_routing=True."
2152+
" See the User Guide for more information."
2153+
" https://scikit-learn.org/stable/metadata_routing.html"
2154+
)
2155+
2156+
scoring = self._get_scorer()
2157+
if _routing_enabled():
2158+
routed_params = process_routing(
2159+
obj=self,
2160+
method="score",
2161+
sample_weight=sample_weight,
2162+
other_params=score_params,
2163+
)
2164+
else:
2165+
routed_params = Bunch()
2166+
routed_params.scorer = Bunch(score={})
2167+
if sample_weight is not None:
2168+
routed_params.scorer.score["sample_weight"] = sample_weight
2169+
2170+
return scoring(
2171+
self,
2172+
X,
2173+
y,
2174+
**routed_params.scorer.score,
2175+
)
21022176

2103-
return scoring(self, X, y, sample_weight=sample_weight)
2177+
def get_metadata_routing(self):
2178+
"""Get metadata routing of this object.
2179+
2180+
Please check :ref:`User Guide <metadata_routing>` on how the routing
2181+
mechanism works.
2182+
2183+
.. versionadded:: 1.4
2184+
2185+
Returns
2186+
-------
2187+
routing : MetadataRouter
2188+
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
2189+
routing information.
2190+
"""
2191+
2192+
router = (
2193+
MetadataRouter(owner=self.__class__.__name__)
2194+
.add_self_request(self)
2195+
.add(
2196+
splitter=self.cv,
2197+
method_mapping=MethodMapping().add(callee="split", caller="fit"),
2198+
)
2199+
.add(
2200+
scorer=self._get_scorer(),
2201+
method_mapping=MethodMapping()
2202+
.add(callee="score", caller="score")
2203+
.add(callee="score", caller="fit"),
2204+
)
2205+
)
2206+
return router
21042207

21052208
def _more_tags(self):
21062209
return {
@@ -2110,3 +2213,10 @@ def _more_tags(self):
21102213
),
21112214
}
21122215
}
2216+
2217+
def _get_scorer(self):
2218+
"""Get the scorer based on the scoring method specified.
2219+
The default scoring method is `accuracy`.
2220+
"""
2221+
scoring = self.scoring or "accuracy"
2222+
return get_scorer(scoring)

0 commit comments

Comments
 (0)