-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH Adds typing support to LogisticRegression #17799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
075d7e3
6b35ee1
e158440
4a925e1
6cb1e79
2930743
87f6b07
4c47776
7578b63
d8c85d0
2e67bc4
cf9ff0c
2535e9f
4af558a
db3dd08
f337c4c
d861af4
292c6ba
5484e42
751a578
fd94e96
ddb421a
5999a87
0cff327
7b6ad2c
e4ea4d8
cadb711
5330c53
e790ad8
74c58e1
8e0ee08
667518e
6ead295
c2b9a37
b6a5200
f2e73b4
e723ccc
c491c98
923ac16
8d2aa0e
2378769
67c4574
f05e47c
9f258ad
301eee1
e2334ec
7bbcec7
459c0bf
39d3fea
b43dd9c
63c0591
09cad30
64079ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -9,9 +9,16 @@ | |||||||
# Lars Buitinck | ||||||||
# Simon Wu <s8wu@uwaterloo.ca> | ||||||||
# Arthur Mensch <arthur.mensch@m4x.org | ||||||||
from __future__ import annotations | ||||||||
import typing | ||||||||
|
||||||||
if typing.TYPE_CHECKING: | ||||||||
from typing_extensions import Literal | ||||||||
|
||||||||
import numbers | ||||||||
import warnings | ||||||||
from typing import Union | ||||||||
from typing import Optional | ||||||||
Comment on lines
+20
to
+21
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
import numpy as np | ||||||||
from scipy import optimize | ||||||||
|
@@ -1028,22 +1035,22 @@ class LogisticRegression(LinearClassifierMixin, SparseCoefMixin, BaseEstimator): | |||||||
|
||||||||
def __init__( | ||||||||
self, | ||||||||
penalty="l2", | ||||||||
penalty: Literal["l1", "l2", "elasticnet", "none"] = "l2", | ||||||||
*, | ||||||||
dual=False, | ||||||||
tol=1e-4, | ||||||||
C=1.0, | ||||||||
fit_intercept=True, | ||||||||
intercept_scaling=1, | ||||||||
class_weight=None, | ||||||||
random_state=None, | ||||||||
solver="lbfgs", | ||||||||
max_iter=100, | ||||||||
multi_class="auto", | ||||||||
verbose=0, | ||||||||
warm_start=False, | ||||||||
n_jobs=None, | ||||||||
l1_ratio=None, | ||||||||
dual: bool = False, | ||||||||
tol: float = 1e-4, | ||||||||
C: float = 1.0, | ||||||||
fit_intercept: bool = True, | ||||||||
intercept_scaling: float = 1, | ||||||||
class_weight: Union[dict, Literal["balanced"]] = None, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As annotations are not evaluated with Also, don't use the In addition, the allowed value
Suggested change
|
||||||||
random_state: Union[int, np.random.RandomState, None] = None, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would probably be useful to define a
Suggested change
|
||||||||
solver: Literal["newton-cg", "lbfgs", "liblinear", "sag", "saga"] = "lbfgs", | ||||||||
max_iter: int = 100, | ||||||||
multi_class: Literal["auto", "ovr", "multinomial"] = "auto", | ||||||||
verbose: int = 0, | ||||||||
warm_start: bool = False, | ||||||||
n_jobs: Optional[int] = None, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
l1_ratio: Optional[float] = None, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
): | ||||||||
|
||||||||
self.penalty = penalty | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason why these are all on a separate line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a personal style choice so diffs are easier to look at. I also like:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True as well.
Unrelated: Does sklearn uses black?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are slowly working our way on deciding this: #11336
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment is relevant here: #11336 (comment)
If we go with adding typing information, then black is reasonable for its style choices.