Skip to content

MAINT validate parameters in LogisticRegression and LogisiticRegressionCV #23565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d77f962
Added _parameter_constraints for Logistic Regression Class
Jitensid Jun 7, 2022
8a2c0fd
Removed code for simple param valiadation of Logistic Regression
Jitensid Jun 7, 2022
1b0db86
Removed unncessary import
Jitensid Jun 8, 2022
2c825f7
Merge branch 'main' of github.com:scikit-learn/scikit-learn into logi…
Jitensid Jun 8, 2022
f5655c3
Final Commit
Jitensid Jun 8, 2022
e6de713
Removed tests checking error msg by simple param validation
Jitensid Jun 8, 2022
1a80c96
Merge branch 'main' of github.com:scikit-learn/scikit-learn into logi…
Jitensid Jun 8, 2022
7d2aba0
Changed value from 1e4 to 10000 of max_iter in ensemble test of logis…
Jitensid Jun 9, 2022
a48b142
Added None in intercept_scaling parameter
Jitensid Jun 9, 2022
fdd64a4
Merge branch 'main' of github.com:scikit-learn/scikit-learn into logi…
Jitensid Jun 10, 2022
93888f0
Changed default value from None to 1.0 for intercept_scaling
Jitensid Jun 11, 2022
7444477
Removed unnecessary test for intercept_scaling parameter
Jitensid Jun 11, 2022
ffef5ca
changed param valdiations and used logistic regression params in logi…
Jitensid Jun 11, 2022
571226b
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 14, 2022
b233d9c
Merge branch 'main' of github.com:scikit-learn/scikit-learn into logi…
Jitensid Jun 14, 2022
9022ed4
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 14, 2022
32a089a
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 15, 2022
69f48de
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 15, 2022
031786d
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 16, 2022
355d09e
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 21, 2022
1f39bff
Merge branch 'main' into logistic_regression_validate_params
Jitensid Jun 24, 2022
ae753cc
Merge branch 'main' of github.com:scikit-learn/scikit-learn into logi…
Jitensid Jun 24, 2022
5004f4d
Modified bool constraint to boolean for validation_params
Jitensid Jun 24, 2022
a03c415
Merge branch 'main' of github.com:scikit-learn/scikit-learn into logi…
Jitensid Jun 24, 2022
c103101
Merge branch 'logistic_regression_validate_params' of https://github.…
Jitensid Jun 24, 2022
9f7a7c7
fixes
jeremiedbb Jun 24, 2022
1e5425f
lint
jeremiedbb Jun 24, 2022
444c34c
nitpick
glemaitre Jun 24, 2022
70293c7
Using verbose constraint for the verbose param for the logistic regre…
Jitensid Jun 24, 2022
fa52fc2
Using verbose constraint for the verbose param for the logistic regre…
Jitensid Jun 24, 2022
71e6c97
Removed unnecessary test of test_multinomial_validation
Jitensid Jun 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def test_stacking_classifier_stratify_default():
# check that we stratify the classes for the default CV
clf = StackingClassifier(
estimators=[
("lr", LogisticRegression(max_iter=1e4)),
("lr", LogisticRegression(max_iter=10_000)),
("svm", LinearSVC(max_iter=1e4)),
]
)
Expand Down
110 changes: 53 additions & 57 deletions sklearn/linear_model/_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
# Arthur Mensch <arthur.mensch@m4x.org

import numbers
from numbers import Integral, Real
import warnings

import numpy as np
from scipy import optimize
from joblib import Parallel, effective_n_jobs
from collections.abc import Iterable

from sklearn.model_selection import BaseCrossValidator
from sklearn.metrics import get_scorer_names

from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
from ._linear_loss import LinearModelLoss
Expand All @@ -31,6 +36,7 @@
from ..utils.validation import check_is_fitted, _check_sample_weight
from ..utils.multiclass import check_classification_targets
from ..utils.fixes import delayed
from ..utils._param_validation import StrOptions, Interval
from ..model_selection import check_cv
from ..metrics import get_scorer

Expand All @@ -43,19 +49,6 @@


def _check_solver(solver, penalty, dual):
all_solvers = ["liblinear", "newton-cg", "lbfgs", "sag", "saga"]
if solver not in all_solvers:
raise ValueError(
"Logistic Regression supports only solvers in %s, got %s."
% (all_solvers, solver)
)

all_penalties = ["l1", "l2", "elasticnet", "none"]
if penalty not in all_penalties:
raise ValueError(
"Logistic Regression supports only penalties in %s, got %s."
% (all_penalties, penalty)
)

if solver not in ["liblinear", "saga"] and penalty not in ("l2", "none"):
raise ValueError(
Expand Down Expand Up @@ -88,11 +81,6 @@ def _check_multi_class(multi_class, solver, n_classes):
multi_class = "multinomial"
else:
multi_class = "ovr"
if multi_class not in ("multinomial", "ovr"):
raise ValueError(
"multi_class should be 'multinomial', 'ovr' or 'auto'. Got %s."
% multi_class
)
if multi_class == "multinomial" and solver == "liblinear":
raise ValueError("Solver %s does not support a multinomial backend." % solver)
return multi_class
Expand Down Expand Up @@ -1023,6 +1011,24 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator):
0.97...
"""

_parameter_constraints = {
"penalty": [StrOptions({"l1", "l2", "elasticnet", "none"})],
"dual": ["boolean"],
"tol": [Interval(Real, 0, None, closed="left")],
"C": [Interval(Real, 0, None, closed="right")],
"fit_intercept": ["boolean"],
"intercept_scaling": [Interval(Real, 0, None, closed="neither")],
"class_weight": [dict, StrOptions({"balanced"}), None],
"random_state": ["random_state"],
"solver": [StrOptions({"newton-cg", "lbfgs", "liblinear", "sag", "saga"})],
"max_iter": [Interval(Integral, 0, None, closed="left")],
"multi_class": [StrOptions({"auto", "ovr", "multinomial"})],
"verbose": ["verbose"],
"warm_start": ["boolean"],
"n_jobs": [None, Integral],
"l1_ratio": [Interval(Real, 0, 1, closed="both"), None],
}

def __init__(
self,
penalty="l2",
Expand Down Expand Up @@ -1088,26 +1094,18 @@ def fit(self, X, y, sample_weight=None):
-----
The SAGA solver supports both float64 and float32 bit arrays.
"""

self._validate_params()

solver = _check_solver(self.solver, self.penalty, self.dual)

if not isinstance(self.C, numbers.Number) or self.C < 0:
raise ValueError("Penalty term must be positive; got (C=%r)" % self.C)
if self.penalty == "elasticnet":
if (
not isinstance(self.l1_ratio, numbers.Number)
or self.l1_ratio < 0
or self.l1_ratio > 1
):
raise ValueError(
"l1_ratio must be between 0 and 1; got (l1_ratio=%r)"
% self.l1_ratio
)
elif self.l1_ratio is not None:
if self.penalty != "elasticnet" and self.l1_ratio is not None:
warnings.warn(
"l1_ratio parameter is only used when penalty is "
"'elasticnet'. Got "
"(penalty={})".format(self.penalty)
)

if self.penalty == "none":
if self.C != 1.0: # default values
warnings.warn(
Expand All @@ -1119,16 +1117,6 @@ def fit(self, X, y, sample_weight=None):
else:
C_ = self.C
penalty = self.penalty
if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
raise ValueError(
"Maximum number of iteration must be positive; got (max_iter=%r)"
% self.max_iter
)
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
raise ValueError(
"Tolerance for stopping criteria must be positive; got (tol=%r)"
% self.tol
)

if solver == "lbfgs":
_dtype = np.float64
Expand Down Expand Up @@ -1609,6 +1597,27 @@ class LogisticRegressionCV(LogisticRegression, LinearClassifierMixin, BaseEstima
0.98...
"""

_parameter_constraints = {**LogisticRegression._parameter_constraints}

for param in ["C", "warm_start", "l1_ratio"]:
_parameter_constraints.pop(param)

_parameter_constraints.update(
{
"Cs": [Interval(Integral, 1, None, closed="left"), "array-like"],
"cv": [
Interval(Integral, 2, None, closed="left"),
Iterable,
BaseCrossValidator,
None,
],
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
"l1_ratios": ["array-like", None],
"refit": ["boolean"],
"penalty": [StrOptions({"l1", "l2", "elasticnet"})],
}
)

def __init__(
self,
*,
Expand Down Expand Up @@ -1669,18 +1678,11 @@ def fit(self, X, y, sample_weight=None):
self : object
Fitted LogisticRegressionCV estimator.
"""

self._validate_params()

solver = _check_solver(self.solver, self.penalty, self.dual)

if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:
raise ValueError(
"Maximum number of iteration must be positive; got (max_iter=%r)"
% self.max_iter
)
if not isinstance(self.tol, numbers.Number) or self.tol < 0:
raise ValueError(
"Tolerance for stopping criteria must be positive; got (tol=%r)"
% self.tol
)
if self.penalty == "elasticnet":
if (
self.l1_ratios is None
Expand Down Expand Up @@ -1709,12 +1711,6 @@ def fit(self, X, y, sample_weight=None):

l1_ratios_ = [None]

if self.penalty == "none":
raise ValueError(
"penalty='none' is not useful and not supported by "
"LogisticRegressionCV."
)

X, y = self._validate_data(
X,
y,
Expand Down
132 changes: 11 additions & 121 deletions sklearn/linear_model/tests/test_logistic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import itertools
import os
import re
import numpy as np
from numpy.testing import assert_allclose, assert_almost_equal
from numpy.testing import assert_array_almost_equal, assert_array_equal
Expand Down Expand Up @@ -70,38 +69,6 @@ def test_predict_2_classes():
check_predictions(LogisticRegression(fit_intercept=False, random_state=0), X_sp, Y1)


def test_error():
# Test for appropriate exception on errors
msg = "Penalty term must be positive"

with pytest.raises(ValueError, match=msg):
LogisticRegression(C=-1).fit(X, Y1)

with pytest.raises(ValueError, match=msg):
LogisticRegression(C="test").fit(X, Y1)

msg = "is not a valid scoring value"
with pytest.raises(ValueError, match=msg):
LogisticRegressionCV(scoring="bad-scorer", cv=2).fit(X, Y1)

for LR in [LogisticRegression, LogisticRegressionCV]:
msg = "Tolerance for stopping criteria must be positive"

with pytest.raises(ValueError, match=msg):
LR(tol=-1).fit(X, Y1)

with pytest.raises(ValueError, match=msg):
LR(tol="test").fit(X, Y1)

msg = "Maximum number of iteration must be positive"

with pytest.raises(ValueError, match=msg):
LR(max_iter=-1).fit(X, Y1)

with pytest.raises(ValueError, match=msg):
LR(max_iter="test").fit(X, Y1)


def test_logistic_cv_mock_scorer():
class MockScorer:
def __init__(self):
Expand Down Expand Up @@ -193,31 +160,10 @@ def test_predict_iris():
assert np.mean(pred == target) > 0.95


@pytest.mark.parametrize("solver", ["lbfgs", "newton-cg", "sag", "saga"])
def test_multinomial_validation(solver):
lr = LogisticRegression(C=-1, solver=solver, multi_class="multinomial")

with pytest.raises(ValueError):
lr.fit([[0, 1], [1, 0]], [0, 1])


@pytest.mark.parametrize("LR", [LogisticRegression, LogisticRegressionCV])
def test_check_solver_option(LR):
X, y = iris.data, iris.target

msg = (
r"Logistic Regression supports only solvers in \['liblinear', "
r"'newton-cg', 'lbfgs', 'sag', 'saga'\], got wrong_name."
)
lr = LR(solver="wrong_name", multi_class="ovr")
with pytest.raises(ValueError, match=msg):
lr.fit(X, y)

msg = "multi_class should be 'multinomial', 'ovr' or 'auto'. Got wrong_name"
lr = LR(solver="newton-cg", multi_class="wrong_name")
with pytest.raises(ValueError, match=msg):
lr.fit(X, y)

# only 'liblinear' solver
msg = "Solver liblinear does not support a multinomial backend."
lr = LR(solver="liblinear", multi_class="multinomial")
Expand Down Expand Up @@ -248,10 +194,12 @@ def test_check_solver_option(LR):
lr.fit(X, y)

# liblinear does not support penalty='none'
msg = "penalty='none' is not supported for the liblinear solver"
lr = LR(penalty="none", solver="liblinear")
with pytest.raises(ValueError, match=msg):
lr.fit(X, y)
# (LogisticRegressionCV does not supports penalty='none' at all)
if LR is LogisticRegression:
msg = "penalty='none' is not supported for the liblinear solver"
lr = LR(penalty="none", solver="liblinear")
with pytest.raises(ValueError, match=msg):
lr.fit(X, y)


@pytest.mark.parametrize("solver", ["lbfgs", "newton-cg", "sag", "saga"])
Expand Down Expand Up @@ -1010,23 +958,6 @@ def test_saga_sparse():
clf.fit(sparse.csr_matrix(X), y)


def test_logreg_intercept_scaling():
# Test that the right error message is thrown when intercept_scaling <= 0

for i in [-1, 0]:
clf = LogisticRegression(
intercept_scaling=i, solver="liblinear", multi_class="ovr"
)
msg = (
"Intercept scaling is %r but needs to be greater than 0."
" To disable fitting an intercept,"
" set fit_intercept=False."
% clf.intercept_scaling
)
with pytest.raises(ValueError, match=msg):
clf.fit(X, Y1)


def test_logreg_intercept_scaling_zero():
# Test that intercept_scaling is ignored when fit_intercept is False

Expand Down Expand Up @@ -1706,49 +1637,13 @@ def test_LogisticRegressionCV_elasticnet_attribute_shapes():
assert lrcv.n_iter_.shape == (n_classes, n_folds, Cs.size, l1_ratios.size)


@pytest.mark.parametrize("l1_ratio", (-1, 2, None, "something_wrong"))
def test_l1_ratio_param(l1_ratio):

msg = r"l1_ratio must be between 0 and 1; got \(l1_ratio=%r\)" % l1_ratio
with pytest.raises(ValueError, match=msg):
LogisticRegression(penalty="elasticnet", solver="saga", l1_ratio=l1_ratio).fit(
X, Y1
)

if l1_ratio is not None:
msg = (
r"l1_ratio parameter is only used when penalty is"
r" 'elasticnet'\. Got \(penalty=l1\)"
)
with pytest.warns(UserWarning, match=msg):
LogisticRegression(penalty="l1", solver="saga", l1_ratio=l1_ratio).fit(
X, Y1
)


@pytest.mark.parametrize("l1_ratios", ([], [0.5, 2], None, "something_wrong"))
def test_l1_ratios_param(l1_ratios):

def test_l1_ratio_non_elasticnet():
msg = (
"l1_ratios must be a list of numbers between 0 and 1; got (l1_ratios=%r)"
% l1_ratios
r"l1_ratio parameter is only used when penalty is"
r" 'elasticnet'\. Got \(penalty=l1\)"
)

with pytest.raises(ValueError, match=re.escape(msg)):
LogisticRegressionCV(
penalty="elasticnet", solver="saga", l1_ratios=l1_ratios, cv=2
).fit(X, Y1)

if l1_ratios is not None:
msg = (
r"l1_ratios parameter is only used when penalty"
r" is 'elasticnet'. Got \(penalty=l1\)"
)
function = LogisticRegressionCV(
penalty="l1", solver="saga", l1_ratios=l1_ratios, cv=2
).fit
with pytest.warns(UserWarning, match=msg):
function(X, Y1)
with pytest.warns(UserWarning, match=msg):
LogisticRegression(penalty="l1", solver="saga", l1_ratio=0.5).fit(X, Y1)


@pytest.mark.parametrize("C", np.logspace(-3, 2, 4))
Expand Down Expand Up @@ -1896,11 +1791,6 @@ def test_penalty_none(solver):
pred_l2_C_inf = lr_l2_C_inf.fit(X, y).predict(X)
assert_array_equal(pred_none, pred_l2_C_inf)

lr = LogisticRegressionCV(penalty="none")
err_msg = "penalty='none' is not useful and not supported by LogisticRegressionCV"
with pytest.raises(ValueError, match=err_msg):
lr.fit(X, y)


@pytest.mark.parametrize(
"params",
Expand Down
2 changes: 1 addition & 1 deletion sklearn/svm/tests/test_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_l1_min_c_l2_loss():
l1_min_c(dense_X, Y1, loss="l2")


def check_l1_min_c(X, y, loss, fit_intercept=True, intercept_scaling=None):
def check_l1_min_c(X, y, loss, fit_intercept=True, intercept_scaling=1.0):
min_c = l1_min_c(
X,
y,
Expand Down
Loading