
[WIP] replaced n_iter by max_iter and added deprecation #7761

Closed. Wants to merge 2 commits.
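For context, the change follows the usual scikit-learn pattern for renaming a constructor parameter: keep the old name around for a deprecation period, warn when it is passed, and forward its value to the new name. A minimal sketch of that pattern (illustrative class, not part of this diff):

import warnings


class IllustrativeEstimator(object):
    """Sketch of the n_iter -> max_iter rename with a deprecation period."""

    def __init__(self, n_iter=None, max_iter=300):
        self.n_iter = n_iter
        if n_iter is not None:
            warnings.warn("'n_iter' was renamed to 'max_iter' and will be "
                          "removed in a future release.", DeprecationWarning)
            self.max_iter = n_iter
        else:
            self.max_iter = max_iter


# Old spelling is still honoured; the new spelling is the documented one.
assert IllustrativeEstimator(n_iter=10).max_iter == 10
assert IllustrativeEstimator(max_iter=500).max_iter == 500
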
41 changes: 25 additions & 16 deletions sklearn/linear_model/bayes.py
@@ -9,11 +9,13 @@
from math import log
import numpy as np
from scipy import linalg
import warnings

from .base import LinearModel
from ..base import RegressorMixin
from ..utils.extmath import fast_logdet, pinvh
from ..utils import check_X_y
from ..utils import deprecated


###############################################################################
@@ -29,7 +31,7 @@ class BayesianRidge(LinearModel, RegressorMixin):

Parameters
----------
n_iter : int, optional
max_iter : int, optional
Maximum number of iterations. Default is 300.

tol : float, optional
@@ -102,7 +104,7 @@ class BayesianRidge(LinearModel, RegressorMixin):
... # doctest: +NORMALIZE_WHITESPACE
BayesianRidge(alpha_1=1e-06, alpha_2=1e-06, compute_score=False,
copy_X=True, fit_intercept=True, lambda_1=1e-06, lambda_2=1e-06,
n_iter=300, normalize=False, tol=0.001, verbose=False)
max_iter=300, normalize=False, tol=0.001, verbose=False)
>>> clf.predict([[1, 1]])
array([ 1.])

@@ -111,11 +113,14 @@ class BayesianRidge(LinearModel, RegressorMixin):
See examples/linear_model/plot_bayesian_ridge.py for an example.
"""

def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6,
lambda_1=1.e-6, lambda_2=1.e-6, compute_score=False,
fit_intercept=True, normalize=False, copy_X=True,
verbose=False):
self.n_iter = n_iter
@deprecated("Attribute n_iter was deprecated. Use 'max_iter' instead")
def __init__(self, n_iter=None, max_iter=300, tol=1.e-3, alpha_1=1.e-6,
alpha_2=1.e-6, lambda_1=1.e-6, lambda_2=1.e-6,
compute_score=False, fit_intercept=True, normalize=False,
copy_X=True, verbose=False):
if n_iter is not None:
warnings.warn("'n_iter' was deprecated. Use 'max_iter' instead.")
self.max_iter = max_iter
self.tol = tol
self.alpha_1 = alpha_1
self.alpha_2 = alpha_2
@@ -164,7 +169,7 @@ def fit(self, X, y):
eigen_vals_ = S ** 2

# Convergence loop of the bayesian ridge regression
for iter_ in range(self.n_iter):
for iter_ in range(self.max_iter):

# Compute mu and sigma
# sigma_ = lambda_ / alpha_ * np.eye(n_features) + np.dot(X.T, X)
@@ -238,7 +243,7 @@ class ARDRegression(LinearModel, RegressorMixin):

Parameters
----------
n_iter : int, optional
max_iter : int, optional
Maximum number of iterations. Default is 300

tol : float, optional
@@ -315,7 +320,7 @@ class ARDRegression(LinearModel, RegressorMixin):
... # doctest: +NORMALIZE_WHITESPACE
ARDRegression(alpha_1=1e-06, alpha_2=1e-06, compute_score=False,
copy_X=True, fit_intercept=True, lambda_1=1e-06, lambda_2=1e-06,
n_iter=300, normalize=False, threshold_lambda=10000.0, tol=0.001,
max_iter=300, normalize=False, threshold_lambda=10000.0, tol=0.001,
verbose=False)
>>> clf.predict([[1, 1]])
array([ 1.])
@@ -325,11 +330,15 @@ class ARDRegression(LinearModel, RegressorMixin):
See examples/linear_model/plot_ard.py for an example.
"""

def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6,
lambda_1=1.e-6, lambda_2=1.e-6, compute_score=False,
threshold_lambda=1.e+4, fit_intercept=True, normalize=False,
copy_X=True, verbose=False):
self.n_iter = n_iter
@deprecated("Attribute n_iter was deprecated. Use 'max_iter' instead")
Review comment (Member):

you need to remove this line (same in other classes)
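For context, sklearn.utils.deprecated is designed to wrap an entire class or function and warn whenever the wrapped object is used, so leaving it on __init__ would warn on every construction rather than only when n_iter is passed. An illustrative (hypothetical) use of the decorator as intended:

from sklearn.utils import deprecated


# Emits a DeprecationWarning ("Class OldEstimator is deprecated; use
# NewEstimator instead") every time OldEstimator is instantiated.
@deprecated("use NewEstimator instead")
class OldEstimator(object):
    pass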

def __init__(self, n_iter=None, max_iter=300, tol=1.e-3, alpha_1=1.e-6,
alpha_2=1.e-6, lambda_1=1.e-6, lambda_2=1.e-6,
compute_score=False, threshold_lambda=1.e+4,
fit_intercept=True, normalize=False, copy_X=True,
verbose=False):
if n_iter is not None:
warnings.warn("'n_iter' was deprecated. Use 'max_iter' instead.")
Review comment from @TomDLT (Member), Oct 28, 2016:

Please raise a DeprecationWarning
You also need to use this value if provided, so that the result of `ARDRegression(n_iter=10).fit(...)` is not modified by this PR:

self.n_iter = n_iter
if n_iter is not None:
    warnings.warn(...)
    self.max_iter = n_iter
else:
    self.max_iter = max_iter

(same in other classes)
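
Once that suggestion is applied, the deprecation path could be exercised with something like the sketch below (a hypothetical test: it assumes n_iter is forwarded to max_iter and a DeprecationWarning is raised, and uses assert_warns from sklearn.utils.testing):

import numpy as np

from sklearn.linear_model import ARDRegression
from sklearn.utils.testing import assert_warns

X = np.array([[1.0, 1.0], [2.0, 3.0], [3.0, 5.0]])
y = np.array([1.0, 2.0, 3.0])

# Constructing with the old name should warn and honour the old value...
clf = assert_warns(DeprecationWarning, ARDRegression, n_iter=10)
assert clf.max_iter == 10

# ...while existing calling code keeps working unchanged.
clf.fit(X, y)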

self.max_iter = max_iter
self.tol = tol
self.fit_intercept = fit_intercept
self.normalize = normalize
@@ -385,7 +394,7 @@ def fit(self, X, y):
coef_old_ = None

# Iterative procedure of ARDRegression
for iter_ in range(self.n_iter):
for iter_ in range(self.max_iter):
# Compute mu and sigma (using Woodbury matrix identity)
sigma_ = pinvh(np.eye(n_samples) / alpha_ +
np.dot(X[:, keep_lambda] *
29 changes: 17 additions & 12 deletions sklearn/manifold/t_sne.py
@@ -13,10 +13,12 @@
import scipy.sparse as sp
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
import warnings
from ..neighbors import BallTree
from ..base import BaseEstimator
from ..utils import check_array
from ..utils import check_random_state
from ..utils import deprecated
from ..utils.extmath import _ravel
from ..decomposition import PCA
from ..metrics.pairwise import pairwise_distances
@@ -295,7 +297,7 @@ def _kl_divergence_bh(params, P, neighbors, degrees_of_freedom, n_samples,
return error, grad


def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
def _gradient_descent(objective, p0, it, max_iter, objective_error=None,
n_iter_check=1, n_iter_without_progress=50,
momentum=0.5, learning_rate=1000.0, min_gain=0.01,
min_grad_norm=1e-7, min_error_diff=1e-7, verbose=0,
@@ -317,7 +319,7 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
Current number of iterations (this function will be called more than
once during the optimization).

n_iter : int
max_iter : int
Maximum number of gradient descent iterations.

n_iter_check : int
@@ -383,7 +385,7 @@ def _gradient_descent(objective, p0, it, n_iter, objective_error=None,
best_error = np.finfo(np.float).max
best_iter = 0

for i in range(it, n_iter):
for i in range(it, max_iter):
new_error, grad = objective(p, *args, **kwargs)
grad_norm = linalg.norm(grad)

@@ -541,7 +543,7 @@ class TSNE(BaseEstimator):
might be too high. If the cost function gets stuck in a bad local
minimum increasing the learning rate helps sometimes.

n_iter : int, optional (default: 1000)
max_iter : int, optional (default: 1000)
Maximum number of iterations for the optimization. Should be at
least 200.

@@ -644,21 +646,24 @@ class TSNE(BaseEstimator):
http://lvdmaaten.github.io/publications/papers/JMLR_2014.pdf
"""

@deprecated("Attribute n_iter was deprecated. Use 'max_iter' instead")
def __init__(self, n_components=2, perplexity=30.0,
early_exaggeration=4.0, learning_rate=1000.0, n_iter=1000,
n_iter_without_progress=30, min_grad_norm=1e-7,
early_exaggeration=4.0, learning_rate=1000.0, n_iter=None,
max_iter=1000, n_iter_without_progress=30, min_grad_norm=1e-7,
metric="euclidean", init="random", verbose=0,
random_state=None, method='barnes_hut', angle=0.5):
if not ((isinstance(init, string_types) and
init in ["pca", "random"]) or
isinstance(init, np.ndarray)):
msg = "'init' must be 'pca', 'random', or a numpy array"
raise ValueError(msg)
if n_iter is not None:
warnings.warn("'n_iter' was deprecated. Use 'max_iter' instead.")
self.n_components = n_components
self.perplexity = perplexity
self.early_exaggeration = early_exaggeration
self.learning_rate = learning_rate
self.n_iter = n_iter
self.max_iter = max_iter
self.n_iter_without_progress = n_iter_without_progress
self.min_grad_norm = min_grad_norm
self.metric = metric
@@ -711,8 +716,8 @@ def _fit(self, X, skip_num_points=0):
raise ValueError("early_exaggeration must be at least 1, but is "
"%f" % self.early_exaggeration)

if self.n_iter < 200:
raise ValueError("n_iter should be at least 200")
if self.max_iter < 200:
raise ValueError("max_iter should be at least 200")

if self.metric == "precomputed":
if isinstance(self.init, string_types) and self.init == 'pca':
@@ -806,7 +811,7 @@ def _tsne(self, P, degrees_of_freedom, n_samples, random_state,
self.n_components)
params = X_embedded.ravel()

opt_args = {"n_iter": 50, "momentum": 0.5, "it": 0,
opt_args = {"max_iter": 50, "momentum": 0.5, "it": 0,
"learning_rate": self.learning_rate,
"n_iter_without_progress": self.n_iter_without_progress,
"verbose": self.verbose, "n_iter_check": 25,
@@ -840,7 +845,7 @@ def _tsne(self, P, degrees_of_freedom, n_samples, random_state,

params, kl_divergence, it = _gradient_descent(obj_func, params,
**opt_args)
opt_args['n_iter'] = 100
opt_args['max_iter'] = 100
opt_args['momentum'] = 0.8
opt_args['it'] = it + 1
params, kl_divergence, it = _gradient_descent(obj_func, params,
@@ -853,7 +858,7 @@ def _tsne(self, P, degrees_of_freedom, n_samples, random_state,

# Final optimization
P /= self.early_exaggeration
opt_args['n_iter'] = self.n_iter
opt_args['max_iter'] = self.max_iter
opt_args['it'] = it + 1
params, error, it = _gradient_descent(obj_func, params, **opt_args)

18 changes: 9 additions & 9 deletions sklearn/manifold/tests/test_t_sne.py
@@ -47,7 +47,7 @@ def flat_function(_):
sys.stdout = StringIO()
try:
_, error, it = _gradient_descent(
ObjectiveSmallGradient(), np.zeros(1), 0, n_iter=100,
ObjectiveSmallGradient(), np.zeros(1), 0, max_iter=100,
n_iter_without_progress=100, momentum=0.0, learning_rate=0.0,
min_gain=0.0, min_grad_norm=1e-5, min_error_diff=0.0, verbose=2)
finally:
@@ -63,7 +63,7 @@ def flat_function(_):
sys.stdout = StringIO()
try:
_, error, it = _gradient_descent(
ObjectiveSmallGradient(), np.zeros(1), 0, n_iter=100,
ObjectiveSmallGradient(), np.zeros(1), 0, max_iter=100,
n_iter_without_progress=100, momentum=0.0, learning_rate=0.0,
min_gain=0.0, min_grad_norm=0.0, min_error_diff=0.2, verbose=2)
finally:
@@ -79,7 +79,7 @@ def flat_function(_):
sys.stdout = StringIO()
try:
_, error, it = _gradient_descent(
flat_function, np.zeros(1), 0, n_iter=100,
flat_function, np.zeros(1), 0, max_iter=100,
n_iter_without_progress=10, momentum=0.0, learning_rate=0.0,
min_gain=0.0, min_grad_norm=0.0, min_error_diff=-1.0, verbose=2)
finally:
@@ -95,7 +95,7 @@ def flat_function(_):
sys.stdout = StringIO()
try:
_, error, it = _gradient_descent(
ObjectiveSmallGradient(), np.zeros(1), 0, n_iter=11,
ObjectiveSmallGradient(), np.zeros(1), 0, max_iter=11,
n_iter_without_progress=100, momentum=0.0, learning_rate=0.0,
min_gain=0.0, min_grad_norm=0.0, min_error_diff=0.0, verbose=2)
finally:
@@ -254,9 +254,9 @@ def test_optimization_minimizes_kl_divergence():
random_state = check_random_state(0)
X, _ = make_blobs(n_features=3, random_state=random_state)
kl_divergences = []
for n_iter in [200, 250, 300]:
for max_iter in [200, 250, 300]:
tsne = TSNE(n_components=2, perplexity=10, learning_rate=100.0,
n_iter=n_iter, random_state=0)
max_iter=max_iter, random_state=0)
tsne.fit_transform(X)
kl_divergences.append(tsne.kl_divergence_)
assert_less_equal(kl_divergences[1], kl_divergences[0])
@@ -297,8 +297,8 @@ def test_early_exaggeration_too_small():

def test_too_few_iterations():
# Number of gradient descent iterations must be at least 200.
tsne = TSNE(n_iter=199)
assert_raises_regexp(ValueError, "n_iter .*", tsne.fit_transform,
tsne = TSNE(max_iter=199)
assert_raises_regexp(ValueError, "max_iter .*", tsne.fit_transform,
np.array([[0.0]]))


@@ -469,7 +469,7 @@ def test_no_sparse_on_barnes_hut():
X = random_state.randn(100, 2)
X[(np.random.randint(0, 100, 50), np.random.randint(0, 2, 50))] = 0.0
X_csr = sp.csr_matrix(X)
tsne = TSNE(n_iter=199, method='barnes_hut')
tsne = TSNE(max_iter=199, method='barnes_hut')
assert_raises_regexp(TypeError, "A sparse matrix was.*",
tsne.fit_transform, X_csr)

16 changes: 11 additions & 5 deletions sklearn/neural_network/rbm.py
@@ -11,12 +11,14 @@

import numpy as np
import scipy.sparse as sp
import warnings

from ..base import BaseEstimator
from ..base import TransformerMixin
from ..externals.six.moves import xrange
from ..utils import check_array
from ..utils import check_random_state
from ..utils import deprecated
from ..utils import gen_even_slices
from ..utils import issparse
from ..utils.extmath import safe_sparse_dot
@@ -51,7 +53,7 @@ class BernoulliRBM(BaseEstimator, TransformerMixin):
batch_size : int, optional
Number of examples per minibatch.

n_iter : int, optional
max_iter : int, optional
Number of iterations/sweeps over the training dataset to perform
during training.

@@ -83,7 +85,7 @@ class BernoulliRBM(BaseEstimator, TransformerMixin):
>>> X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
>>> model = BernoulliRBM(n_components=2)
>>> model.fit(X)
BernoulliRBM(batch_size=10, learning_rate=0.1, n_components=2, n_iter=10,
BernoulliRBM(batch_size=10, learning_rate=0.1, n_components=2, max_iter=10,
random_state=None, verbose=0)

References
@@ -97,12 +99,16 @@ class BernoulliRBM(BaseEstimator, TransformerMixin):
Approximations to the Likelihood Gradient. International Conference
on Machine Learning (ICML) 2008
"""

@deprecated("Attribute n_iter was deprecated. Use 'max_iter' instead")
def __init__(self, n_components=256, learning_rate=0.1, batch_size=10,
n_iter=10, verbose=0, random_state=None):
n_iter=None, max_iter=10, verbose=0, random_state=None):
if n_iter is not None:
warnings.warn("'n_iter' was deprecated. Use 'max_iter' instead.")
self.n_components = n_components
self.learning_rate = learning_rate
self.batch_size = batch_size
self.n_iter = n_iter
self.max_iter = max_iter
self.verbose = verbose
self.random_state = random_state

@@ -350,7 +356,7 @@ def fit(self, X, y=None):
n_batches, n_samples))
verbose = self.verbose
begin = time.time()
for iteration in xrange(1, self.n_iter + 1):
for iteration in xrange(1, self.max_iter + 1):
for batch_slice in batch_slices:
self._fit(X[batch_slice], rng)
