Skip to content

[WIP] Make random_state accept np.random.Generator #23962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

BenjaminBossan
Copy link
Contributor

@BenjaminBossan BenjaminBossan commented Jul 20, 2022

Reference Issues/PRs

Fixes #16988

What does this implement/fix? Explain your changes.

The random_state argument accepts numpy.random.Generator.

Any other comments?

Context

Update: Please see this comment.

This is WIP and I discussed with @thomasjpfan that it would make sense to share the current progress to evaluate if the scope is sufficiently small for a single PR or if we need to split it.

Done
  • Added tests for estimators
  • Made tests for estimators pass (reverted)
Missing
  • tests for splitters
  • tests for other components, e.g. for creating random datasets (this will be difficult because those components need to be called, which is not possible to do in a generic way, unlike for estimators)
  • documentation
  • docstrings
  • SeedSequence use for n_jobs>1 is probably out of scope
Implementation

One difficulty is that Generator has a slightly different API than the existing RandomState class, namely that creating integers now happens through the integers method, not randint. We (Thomas and I) discussed 3 different approaches to support Generators:

  1. Use an adapter with the API of RandomState

If check_random_state sees a Generator, it returns an adapter that supports the randint method with the old signature. This would be backwards compatible with all existing code but locks sklearn into the "old way". Also, the appearance of this new class could be surprising to users.

  1. Use an adapter with the API of Generator

If check_random_state sees a RandomState, it returns an adapter that supports the integers method with the old signature. This would be forwards compatible with the "new way". However, it requires changing all existing calls to randint and the appearance of this new class could be surprising to users.

  1. Using a utility function that knows how to deal with both objects

This is the way that scipy approached the problem. It also requires to change all the calls to randint but it's more transparent than solution 2. One disadvantage is that all other sampling functions are method calls on the object, only integers require this function, which can be surprising.

In the end, we decided to go with option 3. because we assume that it worked well for scipy and should thus also serve sklearn well.

Another decision that I made while working on the feature is not to change randint method calls where the object is known to be a RandomState. E.g. there are many tests that go like:

random_state = RandomState(...)
i = random_state.randint(...)

or

random_state = check_random_state(0)
i = random_state.randint(...)

Therefore, grepping through the repo for randint still reveals many direct calls, but unless I overlooked something, they should all be safe.

Caveats

It's almost impossible to have a complete test coverage for this feature. The reason is that even though we check all estimators that support random_state, we don't know if the code path that actually uses random_state is being taken or not, since it might depend on hyper-parameters. A similar argument applies to splitters and other functions.

Done:

- Added tests for estimators
- Made tests for estimators pass

Missing

- splitters
- other components, e.g. for creating random datasets
- documentation
- docstrings

Caveats

It's almost impossible to have a complete test coverage for this
feature. The reason is that even though we check all estimators that
support random_state, we don't know if the code path that actually uses
random_state is being taken or not, since it might depend on
hyper-parameters.
@@ -21,6 +21,13 @@
from .._config import config_context, get_config
from ..externals._packaging.version import parse as parse_version

# below copied verbatim from scipy._lib._util.py to be used in rng_integers
try:
from numpy.random import Generator as Generator
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our minimum supported NumPy version is 1.17.3, so we can assume that Generators can be imported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed that 👍

@@ -175,7 +176,7 @@
# NB: despite their names X_sparse_* are numpy arrays (and not sparse matrices)
X_sparse_pos = random_state.uniform(size=(20, 5))
X_sparse_pos[X_sparse_pos <= 0.8] = 0.0
y_random = random_state.randint(0, 4, size=(20,))
y_random = rng_integers(random_state, 0, 4, size=(20,))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this PR smaller, I prefer to leave the test unchanged. Currently, the tests are always using a RandomState object.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@@ -73,8 +74,8 @@ def make_sparse_random_data(n_samples, n_features, n_nonzeros, random_state=None
(
rng.randn(n_nonzeros),
(
rng.randint(n_samples, size=n_nonzeros),
rng.randint(n_features, size=n_nonzeros),
rng_integers(rng, n_samples, size=n_nonzeros),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here regarding not needing to change files in the main sklearn files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

- Import can assume that Generator exists
- Revert rng_integers use where not necessary
Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The linting error and CI failure looks related to this PR.

@BenjaminBossan
Copy link
Contributor Author

The linting error and CI failure looks related to this PR.

If it's okay, I would address the linting problems later, before creating the non-draft PR.

@thomasjpfan
Copy link
Member

If it's okay, I would address the linting problems later, before creating the non-draft PR.

The CI does not fully run unless linting passes. This makes it harder to evaluate the PR even as a draft.

@BenjaminBossan
Copy link
Contributor Author

@thomasjpfan Regarding the failing tests, I think we have an interesting problem here. RandomState.randint returns int but Generator.integers returns np.int*. These tests directly check if the instance is int and thus fail. We could allow np.int*. But is it not also problematic that these estimators change the random_state attribute set by the user?

@thomasjpfan
Copy link
Member

This is WIP and I discussed with @thomasjpfan that it would make sense to share the current progress to evaluate if the scope is sufficiently small for a single PR or if we need to split it.

For me, it is about the scope. This PR turns on Generators on everywhere, which touches a lot of estimators all at once.

I prefer to incrementally turn on Generator support for each estimator and have a common test skip estimators that is not yet supported. This way we can be deliberate about using Generator specific features such as SeedSequences or dtype support in Generator.standard_normal (RandomState.standard_normal does not have dtype support).

- Remove rng_integers call where not necessary
- Isinstance check for random ints also accepts np.int_
- Black formatting
@BenjaminBossan
Copy link
Contributor Author

These tests directly check if the instance is int and thus fail.

I changed the tests to also accept np.int_, LMK if that's not the preferred solution.

I prefer to incrementally turn on Generator support for each estimator and have a common test skip estimators that is not yet supported.

This is certainly feasible. A disadvantage would be that random_state, which so far has been a very standardized argument, suddenly works differently for different classes and functions, which would be very surprising for the user. In a sense, even if we allow Generators everywhere, it is already opt-in since the user has to explicitly pass a Generator for it to be used. IMHO activating it on in steps would make sense if we're afraid that something breaks or if we want to later change the behavior when using Generators (as discussed could be possible for SeedSequence, which we considered adding in an accompanying PR).

@BenjaminBossan
Copy link
Contributor Author

There are still problems stemming from the integer dtype, e.g. here:

https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=44874&view=logs&j=aabdcdc3-bb64-5414-b357-ed024fe8659e&t=b7b3ba55-d585-563b-a032-f235636c22b0&l=1574

I believe those go back to the problem I mentioned earlier:

RandomState.randint returns int but Generator.integers returns np.int*

There are some possible solutions to that problem but I'm not sure which one to take.

@thomasjpfan
Copy link
Member

thomasjpfan commented Jul 21, 2022

A disadvantage would be that random_state, which so far has been a very standardized argument, suddenly works differently for different classes and functions, which would be very surprising for the user.

I am okay with that as long as we document which estimator supports generators in their docstrings. We can incrementally update the docstrings of random_state docs as we add support for Generators.

In a sense, even if we allow Generators everywhere, it is already opt-in since the user has to explicitly pass a Generator for it to be used.

I'm thinking more about estimators opting into to Generator support and not about user opt-in. Let's say I want MDS to opt into generators then we change the parameter constraints to:

    "random_state": ["random_state", np.random.Generator],

and include it in the common test. During review, we can look at MDS's code to make sure estimator is configured in a way that actually uses the random state.

With this PR turning Generator support everywhere, it is confirm that all estimators is configured to actually use the generator. For me, this makes it harder to review.

believe those go back to the problem I mentioned earlier:

Pass dtype=int to rng_integers to match the default dtype for RandomState.randint?

We decided to opt in estimators (and all the rest) step by step into
using Generators. Therefore, I reverted all the changes in the actual
estimators that were necessary to accomodate Generators, which comes
down to the use of the rng_integers for now.

The common test has been adjusted to have a long list of excluded
estimators -- currently containing all estimators -- that are skipped
for testing. The idea is that if a new PR comes along that opts an
estimator in, it should be as easy as crossing that estimator off the
list to be able to check if it still works.

Note for future developers. The "random_state" variable is sometimes
also referred to "rng" or "rnd" (and maybe others that I missed), so a
simple grep for "random_state" is not enough.
@BenjaminBossan
Copy link
Contributor Author

BenjaminBossan commented Jul 26, 2022

Updated status

After discussion with maintainers, we decided that estimators, splitters, and other functions should be opted in step by step into allowing Generators as random_state. Therefore, I reverted all changes required to make estimators work with Generators, which came down to removing the usage of rng_integers and some resulting dtype checks.

Currently, this PR thus contains tests for estimators that check if they can be fitted with Generators, as well as if they can call predict, predict_proba, decision_function, or transform (if they have those methods). There is also a big list of estimators excluded from these tests, which, at the moment, contains all estimators.

Guide how to opt an estimator into allowing Generators

[WIP]

I think we should provide a guide for others to opt an estimator in, additionally to what the standard steps for an sklearn PR. OTOH:

  1. Remove the estimator from the _estimators_excluded_from_check_random_state list in sklearn/tests/test_common.py
  2. Update the estimator's docstring to include Generators as a possible type for random_state. TODO we should provide a standard text here.
  3. Check the actual code being run to see if, and under what circumstances, the random state is being used. This is important to actually cover those code paths in the tests, otherwise they could be missed. If the standard test_check_random_state_type does not cover that code path, add a specific test for that estimator in that estimator's test module.
  4. Allow Generators in the parameter constraints, e.g. by setting _parameter_constraints = {..., "random_state": ["random_state", np.random.Generator]}.
  5. If a method on random_state is used whose API has changed, e.g. RandomState.randint, use a compatibility function that supports both new and old methods. For randint, it is already provided as rng_integers, thus the change would be:
from sklearn.utils.fixes import rng_integers
...
- i = random_state.randint(*args **kwargs)
+ i = rng_integers(random_state, *args, **kwargs) 

(with *args, **kwargs replaced by the actual arguments being used)

  1. If parallelism is being used, make use of SeedSequence (docs); how exactly is yet to be determined.

Tip: When grepping for the use of random_state, note that the variable is also sometimes referred to rng, rs, rnd and maybe other names.

Please let me know if the steps should be updated and how to proceed for splitters and other functions using random state.

No actual changes to how the code works, since KBinsDiscretizer only
uses the 'choice' method, which is backwards compatible.
@BenjaminBossan
Copy link
Contributor Author

BenjaminBossan commented Jul 29, 2022

TODOs

Here is a list of classes and functions I could find that use a random_state argument. This includes estimators that currently don't have a random_state argument. LMK if I missed something and if this list should be put somewhere else:

Estimators

  • sklearn.cluster._affinity_propagation.AffinityPropagation
  • sklearn.cluster._bicluster.SpectralBiclustering
  • sklearn.cluster._bicluster.SpectralCoclustering
  • sklearn.cluster._bisect_k_means.BisectingKMeans
  • sklearn.cluster._kmeans.KMeans
  • sklearn.cluster._kmeans.MiniBatchKMeans
  • sklearn.cluster._spectral.SpectralClustering
  • sklearn.covariance._elliptic_envelope.EllipticEnvelope
  • sklearn.covariance._robust_covariance.MinCovDet
  • sklearn.decomposition._dict_learning.DictionaryLearning
  • sklearn.decomposition._dict_learning.MiniBatchDictionaryLearning
  • sklearn.decomposition._factor_analysis.FactorAnalysis
  • sklearn.decomposition._fastica.FastICA
  • sklearn.decomposition._kernel_pca.KernelPCA
  • sklearn.decomposition._lda.LatentDirichletAllocation
  • sklearn.decomposition._nmf.MiniBatchNMF
  • sklearn.decomposition._nmf.NMF
  • sklearn.decomposition._pca.PCA
  • sklearn.decomposition._sparse_pca.MiniBatchSparsePCA
  • sklearn.decomposition._sparse_pca.SparsePCA
  • sklearn.decomposition._truncated_svd.TruncatedSVD
  • sklearn.dummy.DummyClassifier
  • sklearn.ensemble._bagging.BaggingClassifier
  • sklearn.ensemble._bagging.BaggingRegressor
  • sklearn.ensemble._forest.ExtraTreesClassifier
  • sklearn.ensemble._forest.ExtraTreesRegressor
  • sklearn.ensemble._forest.RandomForestClassifier
  • sklearn.ensemble._forest.RandomForestRegressor
  • sklearn.ensemble._forest.RandomTreesEmbedding
  • sklearn.ensemble._gb.GradientBoostingClassifier
  • sklearn.ensemble._gb.GradientBoostingRegressor
  • sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingClassifier
  • sklearn.ensemble._hist_gradient_boosting.gradient_boosting.HistGradientBoostingRegressor
  • sklearn.ensemble._iforest.IsolationForest
  • sklearn.ensemble._weight_boosting.AdaBoostClassifier
  • sklearn.ensemble._weight_boosting.AdaBoostRegressor
  • sklearn.feature_extraction.image.PatchExtractor
  • sklearn.gaussian_process._gpc.GaussianProcessClassifier
  • sklearn.gaussian_process._gpr.GaussianProcessRegressor
  • sklearn.impute._iterative.IterativeImputer
  • sklearn.kernel_approximation.Nystroem
  • sklearn.kernel_approximation.PolynomialCountSketch
  • sklearn.kernel_approximation.RBFSampler
  • sklearn.kernel_approximation.SkewedChi2Sampler
  • sklearn.linear_model._coordinate_descent.ElasticNet
  • sklearn.linear_model._coordinate_descent.ElasticNetCV
  • sklearn.linear_model._coordinate_descent.Lasso
  • sklearn.linear_model._coordinate_descent.LassoCV
  • sklearn.linear_model._coordinate_descent.MultiTaskElasticNet
  • sklearn.linear_model._coordinate_descent.MultiTaskElasticNetCV
  • sklearn.linear_model._coordinate_descent.MultiTaskLasso
  • sklearn.linear_model._coordinate_descent.MultiTaskLassoCV
  • sklearn.linear_model._least_angle.Lars
  • sklearn.linear_model._least_angle.LassoLars
  • sklearn.linear_model._logistic.LogisticRegression
  • sklearn.linear_model._logistic.LogisticRegressionCV
  • sklearn.linear_model._passive_aggressive.PassiveAggressiveClassifier
  • sklearn.linear_model._passive_aggressive.PassiveAggressiveRegressor
  • sklearn.linear_model._perceptron.Perceptron
  • sklearn.linear_model._ransac.RANSACRegressor
  • sklearn.linear_model._ridge.Ridge
  • sklearn.linear_model._ridge.RidgeClassifier
  • sklearn.linear_model._stochastic_gradient.SGDClassifier
  • sklearn.linear_model._stochastic_gradient.SGDOneClassSVM
  • sklearn.linear_model._stochastic_gradient.SGDRegressor
  • sklearn.linear_model._theil_sen.TheilSenRegressor
  • sklearn.manifold._locally_linear.LocallyLinearEmbedding
  • sklearn.manifold._mds.MDS
  • sklearn.manifold._spectral_embedding.SpectralEmbedding
  • sklearn.manifold._t_sne.TSNE
  • sklearn.mixture._bayesian_mixture.BayesianGaussianMixture
  • sklearn.mixture._gaussian_mixture.GaussianMixture
  • sklearn.multiclass.OutputCodeClassifier
  • sklearn.multioutput.ClassifierChain
  • sklearn.multioutput.RegressorChain
  • sklearn.neighbors._nca.NeighborhoodComponentsAnalysis
  • sklearn.neural_network._multilayer_perceptron.MLPClassifier
  • sklearn.neural_network._multilayer_perceptron.MLPRegressor
  • sklearn.neural_network._rbm.BernoulliRBM
  • sklearn.preprocessing._data.QuantileTransformer
  • sklearn.preprocessing._discretization.KBinsDiscretizer
  • sklearn.random_projection.GaussianRandomProjection
  • sklearn.random_projection.SparseRandomProjection
  • sklearn.svm._classes.LinearSVC
  • sklearn.svm._classes.LinearSVR
  • sklearn.svm._classes.NuSVC
  • sklearn.svm._classes.SVC
  • sklearn.tree._classes.DecisionTreeClassifier
  • sklearn.tree._classes.DecisionTreeRegressor
  • sklearn.tree._classes.ExtraTreeClassifier
  • sklearn.tree._classes.ExtraTreeRegressor

Splitters

  • sklearn.model_selection.GroupShuffleSplit
  • sklearn.model_selection.KFold
  • sklearn.model_selection.RepeatedKFold
  • sklearn.model_selection.RepeatedStratifiedKFold
  • sklearn.model_selection.ShuffleSplit
  • sklearn.model_selection.StratifiedGroupKFold
  • sklearn.model_selection.StratifiedKFold
  • sklearn.model_selection.StratifiedShuffleSplit

Rest

  • sklearn.covariance._robust_covariance.fast_mcd
  • sklearn.datasets._base.load_files
  • sklearn.datasets._covtype.fetch_covtype
  • sklearn.datasets._kddcup99.fetch_kddcup99
  • sklearn.datasets._olivetti_faces.fetch_olivetti_faces
  • sklearn.datasets._rcv1.fetch_rcv1
  • sklearn.datasets._samples_generator.make_biclusters
  • sklearn.datasets._samples_generator.make_blobs
  • sklearn.datasets._samples_generator.make_checkerboard
  • sklearn.datasets._samples_generator.make_circles
  • sklearn.datasets._samples_generator.make_classification
  • sklearn.datasets._samples_generator.make_friedman1
  • sklearn.datasets._samples_generator.make_friedman2
  • sklearn.datasets._samples_generator.make_friedman3
  • sklearn.datasets._samples_generator.make_gaussian_quantiles
  • sklearn.datasets._samples_generator.make_hastie_10_2
  • sklearn.datasets._samples_generator.make_low_rank_matrix
  • sklearn.datasets._samples_generator.make_moons
  • sklearn.datasets._samples_generator.make_multilabel_classification
  • sklearn.datasets._samples_generator.make_regression
  • sklearn.datasets._samples_generator.make_s_curve
  • sklearn.datasets._samples_generator.make_sparse_coded_signal
  • sklearn.datasets._samples_generator.make_sparse_spd_matrix
  • sklearn.datasets._samples_generator.make_sparse_uncorrelated
  • sklearn.datasets._samples_generator.make_spd_matrix
  • sklearn.datasets._samples_generator.make_swiss_roll
  • sklearn.datasets._twenty_newsgroups.fetch_20newsgroups
  • sklearn.decomposition._dict_learning.dict_learning
  • sklearn.decomposition._dict_learning.dict_learning_online
  • sklearn.decomposition._fastica.fastica
  • sklearn.decomposition._nmf.non_negative_factorization
  • sklearn.linear_model._ridge.ridge_regression
  • sklearn.metrics.cluster._unsupervised.silhouette_score
  • sklearn.model_selection._search.ParameterSampler
  • sklearn.model_selection._split.train_test_split
  • sklearn.model_selection._validation.learning_curve
  • sklearn.model_selection._validation.permutation_test_score
  • sklearn.preprocessing._data.quantile_transform
  • sklearn.utils.extmath.randomized_svd
  • sklearn.utils.resample
  • sklearn.utils.shuffle

Documentation

Besides the individual docstrings of the classes/functions mentioned above, the documentation should be adjusted here:

@BenjaminBossan
Copy link
Contributor Author

@thomasjpfan as discussed, I changed the PR to only test a single estimator to decrease the review burden. That estimator is KBinsDiscretizer, which was very easy to test.

As for the updated docstring, for now I went with this very simple change:

- random_state : int, RandomState instance or None, default=None
+ random_state : int, RandomState/Generator instance or None, default=None

The reason is that this line is already quite long and from what I can tell, it is not desired to have very long lines for parameter types (or line breaks for that matter). The body itself has not been altered. LMK if we want to do that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support numpy.random.Generator and/or BitGenerator for random number generation
2 participants