API make a few estimators' init args kw-only #16474

Merged · 1 commit · Feb 20, 2020
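
This PR makes most estimator `__init__` parameters keyword-only in a handful of modules (`calibration`, `discriminant_analysis`, `dummy`, `isotonic`, `kernel_approximation`) by inserting a bare `*` into each signature and applying the `_deprecate_positional_args` decorator, so existing positional calls keep working for now but emit a `FutureWarning`. It also bumps the version named in that warning from 0.24 to 0.25.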
sklearn/calibration.py (5 additions, 2 deletions)
@@ -27,6 +27,7 @@
 from .isotonic import IsotonicRegression
 from .svm import LinearSVC
 from .model_selection import check_cv
+from .utils.validation import _deprecate_positional_args


 class CalibratedClassifierCV(BaseEstimator, ClassifierMixin,
@@ -98,7 +99,8 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin,
     .. [4] Predicting Good Probabilities with Supervised Learning,
            A. Niculescu-Mizil & R. Caruana, ICML 2005
     """
-    def __init__(self, base_estimator=None, method='sigmoid', cv=None):
+    @_deprecate_positional_args
+    def __init__(self, base_estimator=None, *, method='sigmoid', cv=None):
         self.base_estimator = base_estimator
         self.method = method
         self.cv = cv
@@ -275,7 +277,8 @@ class _CalibratedClassifier:
     .. [4] Predicting Good Probabilities with Supervised Learning,
            A. Niculescu-Mizil & R. Caruana, ICML 2005
     """
-    def __init__(self, base_estimator, method='sigmoid', classes=None):
+    @_deprecate_positional_args
+    def __init__(self, base_estimator, *, method='sigmoid', classes=None):
         self.base_estimator = base_estimator
         self.method = method
         self.classes = classes
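With the decorator in place, positional calls keep working for one deprecation cycle but emit a `FutureWarning`. A minimal sketch of the new behaviour for `CalibratedClassifierCV` (argument values are illustrative; the warning text comes from `sklearn/utils/validation.py` below):

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import LinearSVC

# Unchanged: `base_estimator` stays positional, everything else by keyword.
clf = CalibratedClassifierCV(LinearSVC(), method='isotonic', cv=3)

# Deprecated: `method` is now keyword-only, so this call still works but
# warns "Pass method=isotonic as keyword args. From version 0.25 passing
# these as positional arguments will result in an error".
clf = CalibratedClassifierCV(LinearSVC(), 'isotonic')
```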
sklearn/discriminant_analysis.py (5 additions, 5 deletions)
@@ -11,7 +11,6 @@

 import warnings
 import numpy as np
-from .exceptions import ChangedBehaviorWarning
 from scipy import linalg
 from scipy.special import expit

@@ -24,6 +23,7 @@
 from .utils.multiclass import check_classification_targets
 from .utils.extmath import softmax
 from .preprocessing import StandardScaler
+from .utils.validation import _deprecate_positional_args


 __all__ = ['LinearDiscriminantAnalysis', 'QuadraticDiscriminantAnalysis']
@@ -246,8 +246,8 @@ class LinearDiscriminantAnalysis(BaseEstimator, LinearClassifierMixin,
     >>> print(clf.predict([[-0.8, -1]]))
     [1]
     """
-
-    def __init__(self, solver='svd', shrinkage=None, priors=None,
+    @_deprecate_positional_args
+    def __init__(self, *, solver='svd', shrinkage=None, priors=None,
                  n_components=None, store_covariance=False, tol=1e-4):
         self.solver = solver
         self.shrinkage = shrinkage
@@ -618,8 +618,8 @@ class QuadraticDiscriminantAnalysis(ClassifierMixin, BaseEstimator):
     sklearn.discriminant_analysis.LinearDiscriminantAnalysis: Linear
         Discriminant Analysis
     """
-
-    def __init__(self, priors=None, reg_param=0., store_covariance=False,
+    @_deprecate_positional_args
+    def __init__(self, *, priors=None, reg_param=0., store_covariance=False,
                  tol=1.0e-4):
         self.priors = np.asarray(priors) if priors is not None else None
         self.reg_param = reg_param
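In these two estimators the `*` sits right after `self`, so every constructor parameter becomes keyword-only. A short sketch of what that means for callers (values are illustrative):

```python
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Fine: all parameters passed by keyword.
lda = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')

# Deprecated: positional arguments now emit a FutureWarning and will
# raise a TypeError once the deprecation period ends.
lda = LinearDiscriminantAnalysis('lsqr', 'auto')
```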
sklearn/dummy.py (5 additions, 5 deletions)
@@ -18,7 +18,7 @@
 from .utils.stats import _weighted_percentile
 from .utils.multiclass import class_distribution
 from .utils import deprecated
-
+from .utils.validation import _deprecate_positional_args

 class DummyClassifier(MultiOutputMixin, ClassifierMixin, BaseEstimator):
     """
@@ -98,8 +98,8 @@ class DummyClassifier(MultiOutputMixin, ClassifierMixin, BaseEstimator):
     >>> dummy_clf.score(X, y)
     0.75
     """
-
-    def __init__(self, strategy="warn", random_state=None,
+    @_deprecate_positional_args
+    def __init__(self, *, strategy="warn", random_state=None,
                  constant=None):
         self.strategy = strategy
         self.random_state = random_state
@@ -453,8 +453,8 @@ class DummyRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
     >>> dummy_regr.score(X, y)
     0.0
     """
-
-    def __init__(self, strategy="mean", constant=None, quantile=None):
+    @_deprecate_positional_args
+    def __init__(self, *, strategy="mean", constant=None, quantile=None):
         self.strategy = strategy
         self.constant = constant
         self.quantile = quantile
sklearn/isotonic.py (6 additions, 4 deletions)
@@ -6,12 +6,13 @@
 import numpy as np
 from scipy import interpolate
 from scipy.stats import spearmanr
+import warnings
+import math
+
 from .base import BaseEstimator, TransformerMixin, RegressorMixin
 from .utils import check_array, check_consistent_length
-from .utils.validation import _check_sample_weight
+from .utils.validation import _check_sample_weight, _deprecate_positional_args
 from ._isotonic import _inplace_contiguous_isotonic_regression, _make_unique
-import warnings
-import math


 __all__ = ['check_increasing', 'isotonic_regression',
@@ -198,7 +199,8 @@ class IsotonicRegression(RegressorMixin, TransformerMixin, BaseEstimator):
     >>> iso_reg.predict([.1, .2])
     array([1.8628..., 3.7256...])
     """
-    def __init__(self, y_min=None, y_max=None, increasing=True,
+    @_deprecate_positional_args
+    def __init__(self, *, y_min=None, y_max=None, increasing=True,
                  out_of_bounds='nan'):
         self.y_min = y_min
         self.y_max = y_max
sklearn/kernel_approximation.py (9 additions, 9 deletions)
@@ -19,7 +19,7 @@
 from .utils.extmath import safe_sparse_dot
 from .utils.validation import check_is_fitted
 from .metrics.pairwise import pairwise_kernels, KERNEL_PARAMS
-from .utils.validation import check_non_negative
+from .utils.validation import check_non_negative, _deprecate_positional_args


 class RBFSampler(TransformerMixin, BaseEstimator):
@@ -81,8 +81,8 @@ class RBFSampler(TransformerMixin, BaseEstimator):
     Benjamin Recht.
     (https://people.eecs.berkeley.edu/~brecht/papers/08.rah.rec.nips.pdf)
     """
-
-    def __init__(self, gamma=1., n_components=100, random_state=None):
+    @_deprecate_positional_args
+    def __init__(self, *, gamma=1., n_components=100, random_state=None):
         self.gamma = gamma
         self.n_components = n_components
         self.random_state = random_state
@@ -187,8 +187,8 @@ class SkewedChi2Sampler(TransformerMixin, BaseEstimator):

     sklearn.metrics.pairwise.chi2_kernel : The exact chi squared kernel.
     """
-
-    def __init__(self, skewedness=1., n_components=100, random_state=None):
+    @_deprecate_positional_args
+    def __init__(self, *, skewedness=1., n_components=100, random_state=None):
         self.skewedness = skewedness
         self.n_components = n_components
         self.random_state = random_state
@@ -318,8 +318,8 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
     A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence,
     2011
     """
-
-    def __init__(self, sample_steps=2, sample_interval=None):
+    @_deprecate_positional_args
+    def __init__(self, *, sample_steps=2, sample_interval=None):
         self.sample_steps = sample_steps
         self.sample_interval = sample_interval

@@ -534,8 +534,8 @@ class Nystroem(TransformerMixin, BaseEstimator):

     sklearn.metrics.pairwise.kernel_metrics : List of built-in kernels.
     """
-
-    def __init__(self, kernel="rbf", gamma=None, coef0=None, degree=None,
+    @_deprecate_positional_args
+    def __init__(self, kernel="rbf", *, gamma=None, coef0=None, degree=None,
                  kernel_params=None, n_components=100, random_state=None):
         self.kernel = kernel
         self.gamma = gamma
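Note that `Nystroem` keeps its first argument positional: the `*` is placed after `kernel`, so the common pattern of passing just the kernel name is untouched. A sketch (parameter values are illustrative):

```python
from sklearn.kernel_approximation import Nystroem

# Still allowed: `kernel` is declared before the `*`.
feature_map = Nystroem('poly', degree=3, n_components=50)

# Deprecated: `gamma` is now keyword-only, so passing it
# positionally emits a FutureWarning.
feature_map = Nystroem('rbf', 0.2)
```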
sklearn/utils/validation.py (1 addition, 1 deletion)
@@ -1290,7 +1290,7 @@ def inner_f(*args, **kwargs):
         args_msg = ['{}={}'.format(name, arg)
                     for name, arg in zip(kwonly_args[:extra_args],
                                          args[-extra_args:])]
-        warnings.warn("Pass {} as keyword args. From version 0.24 "
+        warnings.warn("Pass {} as keyword args. From version 0.25 "
                       "passing these as positional arguments will "
                       "result in an error".format(", ".join(args_msg)),
                       FutureWarning)
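For reference, `_deprecate_positional_args` works roughly as sketched below: it inspects the decorated signature, and when a call supplies more positional arguments than the signature allows, it maps the extras onto the keyword-only names and warns instead of failing. This is a simplified reconstruction built around the fragment shown in the diff above, not the full implementation:

```python
import warnings
from functools import wraps
from inspect import signature, Parameter


def _deprecate_positional_args(f):
    """Warn (instead of erroring) when keyword-only args are passed positionally."""
    sig = signature(f)
    # Parameters before the `*` may be passed positionally; the rest may not.
    positional = [p.name for p in sig.parameters.values()
                  if p.kind in (Parameter.POSITIONAL_ONLY,
                                Parameter.POSITIONAL_OR_KEYWORD)]
    kwonly_args = [p.name for p in sig.parameters.values()
                   if p.kind == Parameter.KEYWORD_ONLY]

    @wraps(f)
    def inner_f(*args, **kwargs):
        extra_args = len(args) - len(positional)
        if extra_args > 0:
            # Fragment from the diff above: name the offending arguments.
            args_msg = ['{}={}'.format(name, arg)
                        for name, arg in zip(kwonly_args[:extra_args],
                                             args[-extra_args:])]
            warnings.warn("Pass {} as keyword args. From version 0.25 "
                          "passing these as positional arguments will "
                          "result in an error".format(", ".join(args_msg)),
                          FutureWarning)
            # Re-route the extra positional values to their keyword names.
            kwargs.update(zip(kwonly_args, args[-extra_args:]))
            args = args[:-extra_args]
        return f(*args, **kwargs)
    return inner_f
```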