ENH Replaced RandomState with Generator compatible calls #22271

Merged · 4 commits · Jan 28, 2022
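
This change set moves scikit-learn's internal random draws onto method names that exist on both the legacy `numpy.random.RandomState` and the newer `numpy.random.Generator` API: `randn` becomes `standard_normal` and `random_sample` becomes `uniform`. A minimal sketch of that compatibility, written here for illustration rather than taken from the PR:

```python
import numpy as np

# Both the legacy RandomState and the new Generator expose these method names,
# so the same call sites work with either object.
for rng in (np.random.RandomState(0), np.random.default_rng(0)):
    gauss = rng.standard_normal(size=(2, 3))  # replaces rng.randn(2, 3)
    unif = rng.uniform(size=5)                # replaces rng.random_sample(5)
    print(type(rng).__name__, gauss.shape, unif.shape)
```

The legacy spellings (`randn`, `rand`, `random_sample`) exist only on `RandomState`, which is why they are the ones being replaced below.
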
2 changes: 1 addition & 1 deletion sklearn/cluster/_affinity_propagation.py
@@ -171,7 +171,7 @@ def affinity_propagation(
# Remove degeneracies
S += (
np.finfo(S.dtype).eps * S + np.finfo(S.dtype).tiny * 100
- ) * random_state.randn(n_samples, n_samples)
+ ) * random_state.standard_normal(size=(n_samples, n_samples))

# Execute parallel affinity propagation updates
e = np.zeros((n_samples, convergence_iter))
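
The main behavioural question with a swap like this is whether the drawn values change. For `randn`, NumPy documents it as a convenience wrapper around `standard_normal`, so a seeded `RandomState` produces identical output either way; the only difference is that `randn` takes dimensions as separate arguments while `standard_normal` takes a `size` tuple. A quick illustrative check (not from the scikit-learn test suite):

```python
import numpy as np

n_samples = 4
a = np.random.RandomState(42).randn(n_samples, n_samples)
b = np.random.RandomState(42).standard_normal(size=(n_samples, n_samples))
assert np.array_equal(a, b)  # same stream, only the call signature differs
```
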
2 changes: 1 addition & 1 deletion sklearn/cluster/_kmeans.py
@@ -209,7 +209,7 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
for c in range(1, n_clusters):
# Choose center candidates by sampling with probability proportional
# to the squared distance to the closest existing center
- rand_vals = random_state.random_sample(n_local_trials) * current_pot
+ rand_vals = random_state.uniform(size=n_local_trials) * current_pot
candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq), rand_vals)
# XXX: numerical imprecision can result in a candidate_id out of range
np.clip(candidate_ids, None, closest_dist_sq.size - 1, out=candidate_ids)
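
Here `random_sample(n_local_trials)` and `uniform(size=n_local_trials)` both draw from U[0, 1), since `uniform` defaults to `low=0.0, high=1.0`. The surrounding k-means++ code scales those draws by the current potential and feeds them to `searchsorted` over the cumulative squared distances, the usual inverse-CDF trick for picking candidates with probability proportional to their squared distance to the closest centre. A toy version of that sampling step, with made-up distances:

```python
import numpy as np

rng = np.random.default_rng(0)
closest_dist_sq = np.array([0.5, 2.0, 0.1, 1.4])  # hypothetical squared distances
current_pot = closest_dist_sq.sum()

# Uniform draws in [0, current_pot), inverted through the cumulative distribution.
rand_vals = rng.uniform(size=3) * current_pot
candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq), rand_vals)
print(candidate_ids)  # indices sampled roughly proportionally to closest_dist_sq
```
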
30 changes: 18 additions & 12 deletions sklearn/datasets/_samples_generator.py
@@ -230,7 +230,7 @@ def make_classification(
centroids *= generator.rand(1, n_informative)

# Initially draw informative features from the standard normal
- X[:, :n_informative] = generator.randn(n_samples, n_informative)
+ X[:, :n_informative] = generator.standard_normal(size=(n_samples, n_informative))

# Create each cluster; a variant of make_blobs
stop = 0
@@ -259,7 +259,7 @@ def make_classification(

# Fill useless features
if n_useless > 0:
- X[:, -n_useless:] = generator.randn(n_samples, n_useless)
+ X[:, -n_useless:] = generator.standard_normal(size=(n_samples, n_useless))

# Randomly replace labels
if flip_y >= 0.0:
@@ -595,7 +595,7 @@ def make_regression(

if effective_rank is None:
# Randomly generate a well conditioned input set
- X = generator.randn(n_samples, n_features)
+ X = generator.standard_normal(size=(n_samples, n_features))

else:
# Randomly generate a low rank, fat tail input set
@@ -1022,7 +1022,7 @@ def make_friedman1(n_samples=100, n_features=10, *, noise=0.0, random_state=None
+ 20 * (X[:, 2] - 0.5) ** 2
+ 10 * X[:, 3]
+ 5 * X[:, 4]
- + noise * generator.randn(n_samples)
+ + noise * generator.standard_normal(size=(n_samples))
)

return X, y
@@ -1088,7 +1088,7 @@ def make_friedman2(n_samples=100, *, noise=0.0, random_state=None):

y = (
X[:, 0] ** 2 + (X[:, 1] * X[:, 2] - 1 / (X[:, 1] * X[:, 3])) ** 2
- ) ** 0.5 + noise * generator.randn(n_samples)
+ ) ** 0.5 + noise * generator.standard_normal(size=(n_samples))

return X, y

@@ -1153,7 +1153,7 @@ def make_friedman3(n_samples=100, *, noise=0.0, random_state=None):

y = np.arctan(
(X[:, 1] * X[:, 2] - 1 / (X[:, 1] * X[:, 3])) / X[:, 0]
- ) + noise * generator.randn(n_samples)
+ ) + noise * generator.standard_normal(size=(n_samples))

return X, y

@@ -1218,9 +1218,15 @@ def make_low_rank_matrix(
n = min(n_samples, n_features)

# Random (ortho normal) vectors
- u, _ = linalg.qr(generator.randn(n_samples, n), mode="economic", check_finite=False)
+ u, _ = linalg.qr(
+     generator.standard_normal(size=(n_samples, n)),
+     mode="economic",
+     check_finite=False,
+ )
v, _ = linalg.qr(
- generator.randn(n_features, n), mode="economic", check_finite=False
+     generator.standard_normal(size=(n_features, n)),
+     mode="economic",
+     check_finite=False,
)

# Index of the singular values
@@ -1280,7 +1286,7 @@ def make_sparse_coded_signal(
generator = check_random_state(random_state)

# generate dictionary
- D = generator.randn(n_features, n_components)
+ D = generator.standard_normal(size=(n_features, n_components))
D /= np.sqrt(np.sum((D ** 2), axis=0))

# generate code
@@ -1289,7 +1295,7 @@
idx = np.arange(n_components)
generator.shuffle(idx)
idx = idx[:n_nonzero_coefs]
- X[idx, i] = generator.randn(n_nonzero_coefs)
+ X[idx, i] = generator.standard_normal(size=n_nonzero_coefs)

# encode signal
Y = np.dot(D, X)
@@ -1519,7 +1525,7 @@ def make_swiss_roll(n_samples=100, *, noise=0.0, random_state=None, hole=False):
z = t * np.sin(t)

X = np.vstack((x, y, z))
- X += noise * generator.randn(3, n_samples)
+ X += noise * generator.standard_normal(size=(3, n_samples))
X = X.T
t = np.squeeze(t)

@@ -1561,7 +1567,7 @@ def make_s_curve(n_samples=100, *, noise=0.0, random_state=None):
z = np.sign(t) * (np.cos(t) - 1)

X = np.concatenate((x, y, z))
- X += noise * generator.randn(3, n_samples)
+ X += noise * generator.standard_normal(size=(3, n_samples))
X = X.T
t = np.squeeze(t)

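
Most of these generators only need Gaussian noise of a given scale, so the rewrite is mechanical. The `make_low_rank_matrix` hunk is the one with some structure behind it: the QR factorisation of a standard-normal matrix gives orthonormal columns, and combining two such bases with a decaying singular-value profile yields an approximately low-rank matrix. A rough, self-contained sketch of that idea (with a simplified exponential spectrum, not the bell-curve profile the real function uses):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
n_samples, n_features, effective_rank = 6, 5, 2
n = min(n_samples, n_features)

# Orthonormal factors from QR of standard-normal matrices.
u, _ = linalg.qr(rng.standard_normal(size=(n_samples, n)), mode="economic")
v, _ = linalg.qr(rng.standard_normal(size=(n_features, n)), mode="economic")

# Decaying singular values -> approximately low-rank product.
s = np.exp(-np.arange(n) / effective_rank)
X = (u * s) @ v.T
print(np.linalg.svd(X, compute_uv=False))  # ≈ s, a decaying spectrum
```
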
12 changes: 8 additions & 4 deletions sklearn/decomposition/_nmf.py
@@ -314,8 +314,12 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6, random_state=None):
if init == "random":
avg = np.sqrt(X.mean() / n_components)
rng = check_random_state(random_state)
- H = avg * rng.randn(n_components, n_features).astype(X.dtype, copy=False)
- W = avg * rng.randn(n_samples, n_components).astype(X.dtype, copy=False)
+ H = avg * rng.standard_normal(size=(n_components, n_features)).astype(
+     X.dtype, copy=False
+ )
+ W = avg * rng.standard_normal(size=(n_samples, n_components)).astype(
+     X.dtype, copy=False
+ )
np.abs(H, out=H)
np.abs(W, out=W)
return W, H
@@ -369,8 +373,8 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6, random_state=None):
elif init == "nndsvdar":
rng = check_random_state(random_state)
avg = X.mean()
- W[W == 0] = abs(avg * rng.randn(len(W[W == 0])) / 100)
- H[H == 0] = abs(avg * rng.randn(len(H[H == 0])) / 100)
+ W[W == 0] = abs(avg * rng.standard_normal(size=len(W[W == 0])) / 100)
+ H[H == 0] = abs(avg * rng.standard_normal(size=len(H[H == 0])) / 100)
else:
raise ValueError(
"Invalid init parameter: got %r instead of one of %r"
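
For `init="random"`, NMF seeds `W` and `H` with the absolute values of Gaussian draws scaled by `sqrt(X.mean() / n_components)`, so that `W @ H` starts out on the same scale as `X`; the rewrite only changes the draw call and keeps the chained `.astype` and the in-place `np.abs`. A toy version of that initialisation, assuming a small dense non-negative `X`:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(size=(20, 8))   # hypothetical non-negative data
n_components = 3

avg = np.sqrt(X.mean() / n_components)
H = np.abs(avg * rng.standard_normal(size=(n_components, X.shape[1])))
W = np.abs(avg * rng.standard_normal(size=(X.shape[0], n_components)))
print(X.mean(), (W @ H).mean())  # the two means are of the same order
```
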
10 changes: 8 additions & 2 deletions sklearn/feature_selection/_mutual_info.py
@@ -290,12 +290,18 @@ def _estimate_mi(
X = X.astype(float, **_astype_copy_false(X))
means = np.maximum(1, np.mean(np.abs(X[:, continuous_mask]), axis=0))
X[:, continuous_mask] += (
- 1e-10 * means * rng.randn(n_samples, np.sum(continuous_mask))
+     1e-10
+     * means
+     * rng.standard_normal(size=(n_samples, np.sum(continuous_mask)))
)

if not discrete_target:
y = scale(y, with_mean=False)
- y += 1e-10 * np.maximum(1, np.mean(np.abs(y))) * rng.randn(n_samples)
+ y += (
+     1e-10
+     * np.maximum(1, np.mean(np.abs(y)))
+     * rng.standard_normal(size=n_samples)
+ )

mi = [
_compute_mi(x, y, discrete_feature, discrete_target, n_neighbors)
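
The noise drawn here is deliberately tiny: `_estimate_mi` adds `1e-10 * means * N(0, 1)` to the continuous columns so the k-nearest-neighbour MI estimator never sees exactly repeated values (ties distort the neighbour counts), while leaving the data numerically unchanged for practical purposes. A condensed illustration of that jitter step on hypothetical data:

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.array([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])  # first column is constant -> ties

means = np.maximum(1, np.mean(np.abs(X), axis=0))
X_jittered = X + 1e-10 * means * rng.standard_normal(size=X.shape)

print(np.unique(X[:, 0]).size, np.unique(X_jittered[:, 0]).size)  # 1 vs 3
```
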
6 changes: 4 additions & 2 deletions sklearn/manifold/_spectral_embedding.py
@@ -336,7 +336,7 @@ def spectral_embedding(

M = ml.aspreconditioner()
# Create initial approximation X to eigenvectors
- X = random_state.randn(laplacian.shape[0], n_components + 1)
+ X = random_state.standard_normal(size=(laplacian.shape[0], n_components + 1))
X[:, 0] = dd.ravel()
X = X.astype(laplacian.dtype)
_, diffusion_map = lobpcg(laplacian, X, M=M, tol=1.0e-5, largest=False)
@@ -367,7 +367,9 @@ def spectral_embedding(
# We increase the number of eigenvectors requested, as lobpcg
# doesn't behave well in low dimension and create initial
# approximation X to eigenvectors
- X = random_state.randn(laplacian.shape[0], n_components + 1)
+ X = random_state.standard_normal(
+     size=(laplacian.shape[0], n_components + 1)
+ )
X[:, 0] = dd.ravel()
X = X.astype(laplacian.dtype)
_, diffusion_map = lobpcg(
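
In both branches the random matrix is only the starting block handed to LOBPCG: the solver wants `n_components + 1` linearly independent columns, and the first column is then overwritten with the degree vector `dd`. A minimal sketch of that pattern on a small symmetric matrix (a toy diagonal matrix here, not scikit-learn's graph Laplacian):

```python
import numpy as np
from scipy.sparse.linalg import lobpcg

rng = np.random.default_rng(0)
n, n_components = 50, 3

A = np.diag(np.arange(1.0, n + 1))  # simple symmetric positive-definite test matrix

# Random initial block of n_components + 1 approximate eigenvectors.
X = rng.standard_normal(size=(n, n_components + 1))
eigvals, eigvecs = lobpcg(A, X, tol=1e-5, largest=False, maxiter=200)
print(eigvals)  # ≈ the smallest eigenvalues 1, 2, 3, 4
```
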
6 changes: 3 additions & 3 deletions sklearn/manifold/_t_sne.py
@@ -977,9 +977,9 @@ def _fit(self, X, skip_num_points=0):
elif self._init == "random":
# The embedding is initialized with iid samples from Gaussians with
# standard deviation 1e-4.
- X_embedded = 1e-4 * random_state.randn(n_samples, self.n_components).astype(
-     np.float32
- )
+ X_embedded = 1e-4 * random_state.standard_normal(
+     size=(n_samples, self.n_components)
+ ).astype(np.float32)
else:
raise ValueError("'init' must be 'pca', 'random', or a numpy array")

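
The random initialisation of the t-SNE embedding is intentionally tiny, iid Gaussians with standard deviation 1e-4 as the in-code comment says, and the rewrite keeps the `float32` cast the optimiser expects. A sketch with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_components = 100, 2
X_embedded = 1e-4 * rng.standard_normal(size=(n_samples, n_components)).astype(np.float32)
print(X_embedded.dtype, X_embedded.std())  # float32, std roughly 1e-4
```
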
4 changes: 3 additions & 1 deletion sklearn/mixture/_base.py
@@ -456,7 +456,9 @@ def sample(self, n_samples=1):
else:
X = np.vstack(
[
- mean + rng.randn(sample, n_features) * np.sqrt(covariance)
+ mean
+ + rng.standard_normal(size=(sample, n_features))
+ * np.sqrt(covariance)
for (mean, covariance, sample) in zip(
self.means_, self.covariances_, n_samples_comp
)
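
This branch covers the diagonal and spherical covariance types, where sampling from each Gaussian component reduces to `mean + standard_normal * sqrt(covariance)` elementwise, with no Cholesky factor needed. A stripped-down sketch with hypothetical parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
mean = np.array([1.0, -2.0])
covariance = np.array([0.5, 4.0])  # diagonal covariance entries
n_samples = 10_000

X = mean + rng.standard_normal(size=(n_samples, 2)) * np.sqrt(covariance)
print(X.mean(axis=0), X.var(axis=0))  # ≈ [1, -2] and ≈ [0.5, 4.0]
```
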
2 changes: 1 addition & 1 deletion sklearn/multiclass.py
@@ -1055,7 +1055,7 @@ def fit(self, X, y):

# FIXME: there are more elaborate methods than generating the codebook
# randomly.
- self.code_book_ = random_state.random_sample((n_classes, code_size_))
+ self.code_book_ = random_state.uniform(size=(n_classes, code_size_))
self.code_book_[self.code_book_ > 0.5] = 1

if hasattr(self.estimator, "decision_function"):
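
The error-correcting output-code codebook is built by thresholding uniform draws at 0.5, giving each class a random binary code word of length `code_size_`; `uniform(size=...)` with its default bounds draws from [0, 1), so the behaviour is unchanged. A compact illustration with hypothetical sizes (the real estimator then maps the remaining entries to 0 or -1 depending on the base estimator):

```python
import numpy as np

rng = np.random.default_rng(0)
n_classes, code_size = 4, 6

code_book = rng.uniform(size=(n_classes, code_size))
code_book[code_book > 0.5] = 1
code_book[code_book != 1] = 0
print(code_book)  # one random binary code word per class
```
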
4 changes: 3 additions & 1 deletion sklearn/neighbors/_nca.py
@@ -451,7 +451,9 @@ def _initialize(self, X, y, init):
if init == "identity":
transformation = np.eye(n_components, X.shape[1])
elif init == "random":
- transformation = self.random_state_.randn(n_components, X.shape[1])
+ transformation = self.random_state_.standard_normal(
+     size=(n_components, X.shape[1])
+ )
elif init in {"pca", "lda"}:
init_time = time.time()
if init == "pca":
4 changes: 2 additions & 2 deletions sklearn/neural_network/_rbm.py
@@ -196,7 +196,7 @@ def _sample_hiddens(self, v, rng):
Values of the hidden layer.
"""
p = self._mean_hiddens(v)
- return rng.random_sample(size=p.shape) < p
+ return rng.uniform(size=p.shape) < p

def _sample_visibles(self, h, rng):
"""Sample from the distribution P(v|h).
@@ -217,7 +217,7 @@ def _sample_visibles(self, h, rng):
p = np.dot(h, self.components_)
p += self.intercept_visible_
expit(p, out=p)
- return rng.random_sample(size=p.shape) < p
+ return rng.uniform(size=p.shape) < p

def _free_energy(self, v):
"""Computes the free energy F(v) = - log sum_h exp(-E(v,h)).
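
Both `_sample_hiddens` and `_sample_visibles` draw Bernoulli samples by comparing uniform noise against the activation probabilities: `uniform(size=p.shape) < p` is True with probability `p`, elementwise. A small demonstration of that idiom:

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.1, 0.5, 0.9])  # hypothetical unit activation probabilities

samples = rng.uniform(size=(100_000, 3)) < p  # Bernoulli(p), one row per draw
print(samples.mean(axis=0))  # ≈ [0.1, 0.5, 0.9]
```
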
2 changes: 1 addition & 1 deletion sklearn/utils/estimator_checks.py
@@ -168,7 +168,7 @@ def check_supervised_y_no_nan(name, estimator_orig):
# Checks that the Estimator targets are not NaN.
estimator = clone(estimator_orig)
rng = np.random.RandomState(888)
- X = rng.randn(10, 5)
+ X = rng.standard_normal(size=(10, 5))

for value in [np.nan, np.inf]:
y = np.full(10, value)