diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000000..3e0dd6b4f5b --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: dhardy diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000000..093433dffb3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,20 @@ +--- +name: Bug report +about: Something doesn't work as expected +title: '' +labels: X-bug +assignees: '' + +--- + +## Summary + +A clear and concise description of what the bug is. + +What behaviour is expected, and why? + +## Code sample + +```rust +// Code demonstrating the problem +``` diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml new file mode 100644 index 00000000000..368023aba9d --- /dev/null +++ b/.github/workflows/benches.yml @@ -0,0 +1,42 @@ +name: Benches + +on: + push: + branches: [ master ] + paths-ignore: + - "**.md" + - "examples/**" + pull_request: + branches: [ master ] + paths-ignore: + - "**.md" + - "examples/**" + +defaults: + run: + working-directory: ./benches + +jobs: + clippy-fmt: + name: "Benches: Check Clippy and rustfmt" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Rustfmt + run: cargo fmt -- --check + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + benches: + name: "Benches: Test" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - name: Test + run: RUSTFLAGS=-Dwarnings cargo test --benches diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 6c78ff56baf..1d83a77bd7f 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -30,12 +30,12 @@ jobs: RUSTDOCFLAGS: --cfg doc_cfg # --all builds all crates, but with default features for other crates (okay in this case) run: | - cargo doc --all --features nightly,serde1,getrandom,small_rng + cargo doc --all --all-features --no-deps cp utils/redirect.html target/doc/index.html rm target/doc/.lock - name: Setup Pages - uses: actions/configure-pages@v4 + uses: actions/configure-pages@v5 - name: Upload artifact uses: actions/upload-pages-artifact@v3 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9a4e860aec0..ad0cf1425cc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,28 +1,54 @@ -name: Tests +name: Main tests on: push: branches: [ master, '0.[0-9]+' ] + paths-ignore: + - "**.md" + - "benches/**" pull_request: branches: [ master, '0.[0-9]+' ] + paths-ignore: + - "**.md" + - "benches/**" permissions: contents: read # to fetch code (actions/checkout) jobs: + clippy-fmt: + name: Check Clippy and rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Check Clippy + run: cargo clippy --workspace -- -D warnings + - name: Check rustfmt + run: cargo fmt --all -- --check + check-doc: name: Check doc runs-on: ubuntu-latest + env: + RUSTDOCFLAGS: "-Dwarnings --cfg docsrs -Zunstable-options --generate-link-to-definition" steps: - uses: actions/checkout@v4 - name: Install toolchain - uses: dtolnay/rust-toolchain@nightly - - run: cargo install cargo-deadlinks - - name: doc (rand) - env: - RUSTDOCFLAGS: --cfg doc_cfg - # --all builds all crates, but with default features for other crates (okay in this case) - run: cargo deadlinks --ignore-fragments -- --all --features nightly,serde1,getrandom,small_rng + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - name: rand + run: cargo doc --all-features --no-deps + - name: rand_core + run: cargo doc --all-features --package rand_core --no-deps + - name: rand_chacha + run: cargo doc --all-features --package rand_chacha --no-deps + - name: rand_pcg + run: cargo doc --all-features --package rand_pcg --no-deps test: runs-on: ${{ matrix.os }} @@ -47,7 +73,7 @@ jobs: - os: ubuntu-latest target: x86_64-unknown-linux-gnu variant: MSRV - toolchain: 1.60.0 + toolchain: 1.63.0 - os: ubuntu-latest deps: sudo apt-get update ; sudo apt install gcc-multilib target: i686-unknown-linux-gnu @@ -72,39 +98,30 @@ jobs: if: ${{ matrix.variant == 'minimal_versions' }} run: | cargo generate-lockfile -Z minimal-versions - # Overrides for dependencies with incorrect requirements (may need periodic updating) - cargo update -p regex --precise 1.5.1 - name: Maybe nightly if: ${{ matrix.toolchain == 'nightly' }} run: | cargo test --target ${{ matrix.target }} --features=nightly cargo test --target ${{ matrix.target }} --all-features - cargo test --target ${{ matrix.target }} --benches --features=small_rng,nightly - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --benches cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - name: Test rand run: | cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - cargo build --target ${{ matrix.target }} --no-default-features --features alloc,getrandom,small_rng,unbiased - cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features=alloc,getrandom,small_rng + cargo build --target ${{ matrix.target }} --no-default-features --features alloc,os_rng,small_rng,unbiased + cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features=alloc,os_rng,small_rng cargo test --target ${{ matrix.target }} --examples - name: Test rand (all stable features) run: | - cargo test --target ${{ matrix.target }} --features=serde1,log,small_rng + cargo test --target ${{ matrix.target }} --features=serde,log,small_rng - name: Test rand_core run: | cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features - cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=alloc,getrandom - - name: Test rand_distr - run: | - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde1 - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features --features=std,std_math + cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=os_rng - name: Test rand_pcg - run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 + run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde - name: Test rand_chacha - run: cargo test --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml + run: cargo test --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml --features=serde test-cross: runs-on: ${{ matrix.os }} @@ -133,11 +150,10 @@ jobs: - name: Test run: | # all stable features: - cross test --no-fail-fast --target ${{ matrix.target }} --features=serde1,log,small_rng + cross test --no-fail-fast --target ${{ matrix.target }} --features=serde,log,small_rng cross test --no-fail-fast --target ${{ matrix.target }} --examples cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde1 - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 + cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml test-miri: @@ -154,10 +170,9 @@ jobs: cargo miri test --no-default-features --lib --tests cargo miri test --features=log,small_rng cargo miri test --manifest-path rand_core/Cargo.toml - cargo miri test --manifest-path rand_core/Cargo.toml --features=serde1 + cargo miri test --manifest-path rand_core/Cargo.toml --features=serde cargo miri test --manifest-path rand_core/Cargo.toml --no-default-features - #cargo miri test --manifest-path rand_distr/Cargo.toml # no unsafe and lots of slow tests - cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde1 + cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde cargo miri test --manifest-path rand_chacha/Cargo.toml --no-default-features test-no-std: diff --git a/CHANGELOG.md b/CHANGELOG.md index aaccde2ce79..891db26a9f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,44 +8,125 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md). You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful. -## [0.9.0-alpha.0] - 2024-02-18 -This is a pre-release. To depend on this version, use `rand = "=0.9.0-alpha.0"` to prevent automatic updates (which can be expected to include breaking changes). +## [Unreleased] +### Deprecated +- Deprecate `rand::rngs::mock` module and `StepRng` generator (#1634) -### Generators -- Change `SmallRng::seed_from_u64` implementation (#1203) -- Replace `SeedableRng` impl for `SmallRng` with inherent methods, excluding `fn from_seed` (#1368) +## [0.9.1] - 2025-04-17 +### Security and unsafe +- Revise "not a crypto library" policy again (#1565) +- Remove `zerocopy` dependency from `rand` (#1579) -### Sequences -- Simpler and faster implementation of Floyd's F2 (#1277). This - changes some outputs from `rand::seq::index::sample` and - `rand::seq::SliceRandom::choose_multiple`. -- New, faster algorithms for `IteratorRandom::choose` and `choose_stable` (#1268) -- New, faster algorithms for `SliceRandom::shuffle` and `partial_shuffle` (#1272) -- Re-introduce `Rng::gen_iter` (#1305) +### Fixes +- Fix feature `simd_support` for recent nightly rust (#1586) + +### Changes +- Allow `fn rand::seq::index::sample_weighted` and `fn IndexedRandom::choose_multiple_weighted` to return fewer than `amount` results (#1623), reverting an undocumented change (#1382) to the previous release. + +### Additions +- Add `rand::distr::Alphabetic` distribution. (#1587) +- Re-export `rand_core` (#1604) + +## [0.9.0] - 2025-01-27 +### Security and unsafe +- Policy: "rand is not a crypto library" (#1514) +- Remove fork-protection from `ReseedingRng` and `ThreadRng`. Instead, it is recommended to call `ThreadRng::reseed` on fork. (#1379) +- Use `zerocopy` to replace some `unsafe` code (#1349, #1393, #1446, #1502) + +### Dependencies +- Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `rand_core` v0.9.0 (#1558) + +### Features +- Support `std` feature without `getrandom` or `rand_chacha` (#1354) +- Enable feature `small_rng` by default (#1455) +- Remove implicit feature `rand_chacha`; use `std_rng` instead. (#1473) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) +- Add feature `thread_rng` (#1547) + +### API changes: rand_core traits +- Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) +- Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) +- Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) +- Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) +- Add bounds `Clone` and `AsRef` to associated type `SeedableRng::Seed` (#1491) + +### API changes: Rng trait and top-level fns +- Rename fn `rand::thread_rng()` to `rand::rng()` and remove from the prelude (#1506) +- Remove fn `rand::random()` from the prelude (#1506) +- Add top-level fns `random_iter`, `random_range`, `random_bool`, `random_ratio`, `fill` (#1488) +- Re-introduce fn `Rng::gen_iter` as `random_iter` (#1305, #1500) +- Rename fn `Rng::gen` to `random` to avoid conflict with the new `gen` keyword in Rust 2024 (#1438) +- Rename fns `Rng::gen_range` to `random_range`, `gen_bool` to `random_bool`, `gen_ratio` to `random_ratio` (#1505) +- Annotate panicking methods with `#[track_caller]` (#1442, #1447) + +### API changes: RNGs +- Fix `::Seed` size to 256 bits (#1455) +- Remove first parameter (`rng`) of `ReseedingRng::new` (#1533) + +### API changes: Sequences - Split trait `SliceRandom` into `IndexedRandom`, `IndexedMutRandom`, `SliceRandom` (#1382) +- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453, #1469) -### Distributions -- `{Uniform, UniformSampler}::{new, new_inclusive}` return a `Result` (instead of potentially panicking) (#1229) -- `Uniform` implements `TryFrom` instead of `From` for ranges (#1229) -- `Uniform` now uses Canon's method (single sampling) / Lemire's method (distribution sampling) for faster sampling (breaks value stability; #1287) -- Relax `Sized` bound on `Distribution for &D` (#1278) -- Explicit impl of `sample_single_inclusive` (+~20% perf) (#1289) -- Impl `DistString` for `Slice` and `Uniform` (#1315) -- Let `Standard` support all `NonZero*` types (#1332) -- Add `trait Weight`, allowing `WeightedIndex` to trap overflow (#1353) -- Rename `WeightedError` to `WeightError`, revising variants (#1382) +### API changes: Distributions: renames +- Rename module `rand::distributions` to `rand::distr` (#1470) +- Rename distribution `Standard` to `StandardUniform` (#1526) +- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548) +- Rename trait `distr::DistString` -> `distr::SampleString` (#1548) +- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548) -### SIMD +### API changes: Distributions +- Relax `Sized` bound on `Distribution for &D` (#1278) +- Remove impl of `Distribution>` for `StandardUniform` (#1526) +- Let distribution `StandardUniform` support all `NonZero*` types (#1332) +- Fns `{Uniform, UniformSampler}::{new, new_inclusive}` return a `Result` (instead of potentially panicking) (#1229) +- Distribution `Uniform` implements `TryFrom` instead of `From` for ranges (#1229) +- Add `UniformUsize` (#1487) +- Remove support for generating `isize` and `usize` values with `StandardUniform`, `Uniform` (except via `UniformUsize`) and `Fill` and usage as a `WeightedAliasIndex` weight (#1487) +- Add impl `DistString` for distributions `Slice` and `Uniform` (#1315) +- Add fn `Slice::num_choices` (#1402) +- Add fn `p()` for distribution `Bernoulli` to access probability (#1481) + +### API changes: Weighted distributions +- Add `pub` module `rand::distr::weighted`, moving `WeightedIndex` there (#1548) +- Add trait `weighted::Weight`, allowing `WeightedIndex` to trap overflow (#1353) +- Add fns `weight, weights, total_weight` to distribution `WeightedIndex` (#1420) +- Rename enum `WeightedError` to `weighted::Error`, revising variants (#1382) and mark as `#[non_exhaustive]` (#1480) + +### API changes: SIMD - Switch to `std::simd`, expand SIMD & docs (#1239) -- Optimise SIMD widening multipy (#1247) + +### Reproducibility-breaking changes +- Make `ReseedingRng::reseed` discard remaining data from the last block generated (#1379) +- Change fn `SmallRng::seed_from_u64` implementation (#1203) +- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462) +- Fix portability of distribution `Slice` (#1469) +- Make `Uniform` for `usize` portable via `UniformUsize` (#1487) +- Fix `IndexdRandom::choose_multiple_weighted` for very small seeds and optimize for large input length / low memory (#1530) + +### Reproducibility-breaking optimisations +- Optimize fn `sample_floyd`, affecting output of `rand::seq::index::sample` and `rand::seq::SliceRandom::choose_multiple` (#1277) +- New, faster algorithms for `IteratorRandom::choose` and `choose_stable` (#1268) +- New, faster algorithms for `SliceRandom::shuffle` and `partial_shuffle` (#1272) +- Optimize distribution `Uniform`: use Canon's method (single sampling) / Lemire's method (distribution sampling) for faster sampling (breaks value stability; #1287) +- Optimize fn `sample_single_inclusive` for floats (+~20% perf) (#1289) + +### Other optimisations +- Improve `SmallRng` initialization performance (#1482) +- Optimise SIMD widening multiply (#1247) ### Other -- Bump MSRV to 1.60.0 (#1207, #1246, #1269, #1341) -- Improve `thread_rng` related docs (#1257) - Add `Cargo.lock.msrv` file (#1275) +- Reformat with `rustfmt` and enforce (#1448) +- Apply Clippy suggestions and enforce (#1448, #1474) +- Move all benchmarks to new `benches` crate (#1329, #1439) and migrate to Criterion (#1490) + +### Documentation +- Improve `ThreadRng` related docs (#1257) - Docs: enable experimental `--generate-link-to-definition` feature (#1327) -- Use `zerocopy` to replace some `unsafe` code (#1349) -- Support `std` feature without `getrandom` or `rand_chacha` (#1354) +- Better doc of crate features, use `doc_auto_cfg` (#1411, #1450) ## [0.8.5] - 2021-08-20 ### Fixes diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv index b173eb6d866..66921820c1e 100644 --- a/Cargo.lock.msrv +++ b/Cargo.lock.msrv @@ -3,20 +3,18 @@ version = 3 [[package]] -name = "anes" -version = "0.1.6" +name = "android-tzdata" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" [[package]] -name = "atty" -version = "0.2.14" +name = "android_system_properties" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" dependencies = [ - "hermit-abi", "libc", - "winapi", ] [[package]] @@ -27,15 +25,21 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "average" -version = "0.13.1" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843ec791d3f24503bbf72bbd5e49a3ab4dbb4bcd0a8ef6b0c908efa73caa27b1" +checksum = "3a237a6822e1c3c98e700b6db5b293eb341b7524dcb8d227941245702b7431dc" dependencies = [ "easy-cast", "float-ord", "num-traits", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "bincode" version = "1.3.3" @@ -45,23 +49,17 @@ dependencies = [ "serde", ] -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "bumpalo" -version = "3.11.1" +version = "3.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" +checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] -name = "cast" -version = "0.3.0" +name = "cc" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" [[package]] name = "cfg-if" @@ -70,146 +68,126 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] -name = "ciborium" -version = "0.2.0" +name = "chrono" +version = "0.4.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c137568cc60b904a7724001b35ce2630fd00d5d84805fbb608ab89509d788f" +checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" dependencies = [ - "ciborium-io", - "ciborium-ll", + "android-tzdata", + "iana-time-zone", + "num-traits", "serde", + "windows-targets", ] [[package]] -name = "ciborium-io" -version = "0.2.0" +name = "core-foundation-sys" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "346de753af073cc87b52b2083a506b38ac176a44cfb05497b622e27be899b369" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] -name = "ciborium-ll" -version = "0.2.0" +name = "crossbeam-channel" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "213030a2b5a4e0c0892b6652260cf6ccac84827b83a85a534e178e3906c4cf1b" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" dependencies = [ - "ciborium-io", - "half", + "crossbeam-utils", ] [[package]] -name = "clap" -version = "3.2.5" +name = "crossbeam-deque" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d53da17d37dba964b9b3ecb5c5a1f193a2762c700e6829201e645b9381c99dc7" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" dependencies = [ - "bitflags", - "clap_lex", - "indexmap", - "textwrap", + "crossbeam-epoch", + "crossbeam-utils", ] [[package]] -name = "clap_lex" -version = "0.2.2" +name = "crossbeam-epoch" +version = "0.9.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5538cd660450ebeb4234cfecf8f2284b844ffc4c50531e66d584ad5b91293613" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" dependencies = [ - "os_str_bytes", + "crossbeam-utils", ] [[package]] -name = "criterion" -version = "0.4.0" +name = "crossbeam-utils" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb" -dependencies = [ - "anes", - "atty", - "cast", - "ciborium", - "clap", - "criterion-plot", - "itertools", - "lazy_static", - "num-traits", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "walkdir", -] +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] -name = "criterion-plot" -version = "0.5.0" +name = "darling" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" dependencies = [ - "cast", - "itertools", + "darling_core", + "darling_macro", ] [[package]] -name = "crossbeam-channel" -version = "0.5.6" +name = "darling_core" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" dependencies = [ - "cfg-if", - "crossbeam-utils", + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", ] [[package]] -name = "crossbeam-deque" -version = "0.8.2" +name = "darling_macro" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" dependencies = [ - "cfg-if", - "crossbeam-epoch", - "crossbeam-utils", + "darling_core", + "quote", + "syn", ] [[package]] -name = "crossbeam-epoch" -version = "0.9.13" +name = "deranged" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ - "autocfg", - "cfg-if", - "crossbeam-utils", - "memoffset", - "scopeguard", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f" -dependencies = [ - "cfg-if", + "powerfmt", + "serde", ] [[package]] name = "easy-cast" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bd102ee8c418348759919b83b81cdbdc933ffe29740b903df448b4bafaa348e" +checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6" dependencies = [ "libm", ] [[package]] name = "either" -version = "1.8.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + +[[package]] +name = "fast_polynomial" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62eea6ee590b08a5f8b1139f4d6caee195b646d0c07e4b1808fbd5c4dea4829a" +dependencies = [ + "num-traits", +] [[package]] name = "float-ord" @@ -217,23 +195,23 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", "wasi", ] -[[package]] -name = "half" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" - [[package]] name = "hashbrown" version = "0.12.3" @@ -242,88 +220,110 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hermit-abi" -version = "0.1.19" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "iana-time-zone" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ - "libc", + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", ] [[package]] -name = "indexmap" -version = "1.9.2" +name = "iana-time-zone-haiku" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" dependencies = [ - "autocfg", - "hashbrown", + "cc", ] [[package]] -name = "itertools" -version = "0.10.5" +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "indexmap" +version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ - "either", + "autocfg", + "hashbrown", + "serde", ] [[package]] name = "itoa" -version = "1.0.4" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.60" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" dependencies = [ "wasm-bindgen", ] [[package]] -name = "lazy_static" -version = "1.4.0" +name = "lambert_w" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "dd8852c2190439a46c77861aca230080cc9db4064be7f9de8ee81816d6c72c25" +dependencies = [ + "fast_polynomial", + "libm", +] [[package]] name = "libc" -version = "0.2.138" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6d7e329c562c5dfab7a46a2afabc8b987ab9a4834c9d1ca04dc54c1546cef8" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libm" -version = "0.2.6" +version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "log" -version = "0.4.17" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] -name = "memoffset" -version = "0.7.1" +name = "num-conv" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" -dependencies = [ - "autocfg", -] +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -331,9 +331,9 @@ dependencies = [ [[package]] name = "num_cpus" -version = "1.14.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ "hermit-abi", "libc", @@ -341,49 +341,15 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860" - -[[package]] -name = "oorandom" -version = "11.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" - -[[package]] -name = "os_str_bytes" -version = "6.1.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21326818e99cfe6ce1e524c2a805c189a99b5ae555a35d19f9a284b427d86afa" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] -name = "plotters" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "193228616381fecdc1224c62e96946dfbc73ff4384fba576e052ff8c1bea8142" - -[[package]] -name = "plotters-svg" -version = "0.3.3" +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9a81d2759aae1dae668f783c308bc5c8ebd191ff4184aaa1b37f65a6ae5a56f" -dependencies = [ - "plotters-backend", -] +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -393,40 +359,39 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.47" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.21" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbe448f377a7d6961e30f5955f9b8d106c3f5e449d493ee1b125c1d43c2b5179" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] [[package]] name = "rand" -version = "0.9.0" +version = "0.9.0-beta.0" dependencies = [ "bincode", - "criterion", - "libc", "log", "rand_chacha", "rand_core", "rand_pcg", "rayon", "serde", + "zerocopy", ] [[package]] name = "rand_chacha" -version = "0.4.0" +version = "0.9.0-beta.0" dependencies = [ "ppv-lite86", "rand_core", @@ -436,27 +401,29 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.7.0" +version = "0.9.0-beta.0" dependencies = [ "getrandom", "serde", + "zerocopy", ] [[package]] name = "rand_distr" -version = "0.5.0" +version = "0.5.0-beta.0" dependencies = [ "average", "num-traits", "rand", "rand_pcg", "serde", + "serde_with", "special", ] [[package]] name = "rand_pcg" -version = "0.4.0" +version = "0.9.0-beta.0" dependencies = [ "bincode", "rand_core", @@ -465,9 +432,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" dependencies = [ "either", "rayon-core", @@ -475,9 +442,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" dependencies = [ "crossbeam-channel", "crossbeam-deque", @@ -485,56 +452,26 @@ dependencies = [ "num_cpus", ] -[[package]] -name = "regex" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a" -dependencies = [ - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.6.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" - [[package]] name = "ryu" -version = "1.0.11" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" - -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "scopeguard" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "serde" -version = "1.0.149" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256b9932320c590e707b94576e3cc1f7c9024d0ee6612dfbcf1cb106cbe8e055" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.149" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4eae9b04cbffdfd550eb462ed33bc6a1b68c935127d008b27444d08380f94e4" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", @@ -543,9 +480,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.89" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -553,58 +490,97 @@ dependencies = [ ] [[package]] -name = "special" -version = "0.8.1" +name = "serde_with" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a65e074159b75dcf173a4733ab2188baac24967b5c8ec9ed87ae15fcbc7636" +checksum = "9f02d8aa6e3c385bf084924f660ce2a3a6bd333ba55b35e8590b321f35d88513" dependencies = [ - "libc", + "base64", + "chrono", + "hex", + "indexmap", + "serde", + "serde_json", + "serde_with_macros", + "time", ] [[package]] -name = "syn" -version = "1.0.105" +name = "serde_with_macros" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b9b43d45702de4c839cb9b51d9f529c5dd26a4aff255b42b1ebc03e88ee908" +checksum = "edc7d5d3932fb12ce722ee5e64dd38c504efba37567f0c402f6ca728c3b8b070" dependencies = [ + "darling", "proc-macro2", "quote", - "unicode-ident", + "syn", ] [[package]] -name = "textwrap" -version = "0.15.2" +name = "special" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7b3e525a49ec206798b40326a44121291b530c963cfb01018f63e135bac543d" +checksum = "98d279079c3ddec4e7851337070c1055a18b8f606bba0b1aeb054bc059fc2e27" +dependencies = [ + "lambert_w", + "libm", +] [[package]] -name = "tinytemplate" -version = "1.2.1" +name = "strsim" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "syn" +version = "2.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "time" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", "serde", - "serde_json", + "time-core", + "time-macros", ] [[package]] -name = "unicode-ident" -version = "1.0.5" +name = "time-core" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] -name = "walkdir" -version = "2.3.2" +name = "time-macros" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ - "same-file", - "winapi", - "winapi-util", + "num-conv", + "time-core", ] +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -613,9 +589,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.83" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -623,9 +599,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.83" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" dependencies = [ "bumpalo", "log", @@ -638,9 +614,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.83" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -648,9 +624,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.83" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", @@ -661,47 +637,92 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.83" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] -name = "web-sys" -version = "0.3.60" +name = "windows-core" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "js-sys", - "wasm-bindgen", + "windows-targets", ] [[package]] -name = "winapi" -version = "0.3.9" +name = "windows-targets" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", ] [[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" +name = "windows_aarch64_gnullvm" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] -name = "winapi-util" -version = "0.1.5" +name = "windows_aarch64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" + +[[package]] +name = "zerocopy" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +checksum = "a65238aacd5fb83fb03fcaf94823e71643e937000ec03c46e7da94234b10c870" dependencies = [ - "winapi", + "zerocopy-derive", ] [[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" +name = "zerocopy-derive" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +checksum = "3ca22c4ad176b37bd81a565f66635bde3d654fe6832730c3e52e1018ae1655ee" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 1d8cfbe057c..523c8d3c867 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand" -version = "0.9.0-alpha.0" +version = "0.9.1" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -14,90 +14,70 @@ keywords = ["random", "rng"] categories = ["algorithms", "no-std"] autobenches = true edition = "2021" -rust-version = "1.60" +rust-version = "1.63" include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] [package.metadata.docs.rs] # To build locally: -# RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --generate-link-to-definition --open +# RUSTDOCFLAGS="--cfg docsrs -Zunstable-options --generate-link-to-definition" cargo +nightly doc --all --all-features --no-deps --open all-features = true -rustdoc-args = ["--cfg", "doc_cfg", "--generate-link-to-definition"] +rustdoc-args = ["--generate-link-to-definition"] [package.metadata.playground] -features = ["small_rng", "serde1"] +features = ["small_rng", "serde"] [features] # Meta-features: -default = ["std", "std_rng", "getrandom"] +default = ["std", "std_rng", "os_rng", "small_rng", "thread_rng"] nightly = [] # some additions requiring nightly Rust -serde1 = ["serde", "rand_core/serde1"] +serde = ["dep:serde", "rand_core/serde"] # Option (enabled by default): without "std" rand uses libcore; this option # enables functionality expected to be available on a standard platform. -std = ["rand_core/std", "rand_chacha?/std", "alloc", "libc"] +std = ["rand_core/std", "rand_chacha?/std", "alloc"] # Option: "alloc" enables support for Vec and Box when not using "std" -alloc = ["rand_core/alloc"] +alloc = [] -# Option: use getrandom package for seeding -getrandom = ["rand_core/getrandom"] +# Option: enable OsRng +os_rng = ["rand_core/os_rng"] # Option (requires nightly Rust): experimental SIMD support -simd_support = ["zerocopy/simd-nightly"] +simd_support = [] # Option (enabled by default): enable StdRng -std_rng = ["rand_chacha"] +std_rng = ["dep:rand_chacha"] # Option: enable SmallRng small_rng = [] +# Option: enable ThreadRng and rng() +thread_rng = ["std", "std_rng", "os_rng"] + # Option: use unbiased sampling for algorithms supporting this option: Uniform distribution. # By default, bias affecting no more than one in 2^48 samples is accepted. # Note: enabling this option is expected to affect reproducibility of results. unbiased = [] +# Option: enable logging +log = ["dep:log"] + [workspace] members = [ "rand_core", - "rand_distr", "rand_chacha", "rand_pcg", ] +exclude = ["benches", "distr_test"] [dependencies] -rand_core = { path = "rand_core", version = "=0.9.0-alpha.0", default-features = false } +rand_core = { path = "rand_core", version = "0.9.0", default-features = false } log = { version = "0.4.4", optional = true } serde = { version = "1.0.103", features = ["derive"], optional = true } -rand_chacha = { path = "rand_chacha", version = "=0.9.0-alpha.0", default-features = false, optional = true } -zerocopy = { version = "=0.8.0-alpha.5", default-features = false, features = ["simd"] } - -[target.'cfg(unix)'.dependencies] -# Used for fork protection (reseeding.rs) -libc = { version = "0.2.22", optional = true, default-features = false } +rand_chacha = { path = "rand_chacha", version = "0.9.0", default-features = false, optional = true } [dev-dependencies] -rand_pcg = { path = "rand_pcg", version = "=0.9.0-alpha.0" } -# Only to test serde1 +rand_pcg = { path = "rand_pcg", version = "0.9.0" } +# Only to test serde bincode = "1.2.1" -rayon = "1.5.3" -criterion = { version = "0.4" } - -[[bench]] -name = "uniform" -path = "benches/uniform.rs" -harness = false - -[[bench]] -name = "seq_choose" -path = "benches/seq_choose.rs" -harness = false - -[[bench]] -name = "shuffle" -path = "benches/shuffle.rs" -harness = false - -[[bench]] -name = "uniform_float" -path = "benches/uniform_float.rs" -harness = false +rayon = "1.7" diff --git a/README.md b/README.md index 30c697922bf..e8b6fe3d337 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,49 @@ # Rand -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Crate](https://img.shields.io/crates/v/rand.svg)](https://crates.io/crates/rand) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand) [![API](https://docs.rs/rand/badge.svg)](https://docs.rs/rand) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.60+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) - -A Rust library for random number generation, featuring: - -- Easy random value generation and usage via the [`Rng`](https://docs.rs/rand/*/rand/trait.Rng.html), - [`SliceRandom`](https://docs.rs/rand/*/rand/seq/trait.SliceRandom.html) and - [`IteratorRandom`](https://docs.rs/rand/*/rand/seq/trait.IteratorRandom.html) traits -- Secure seeding via the [`getrandom` crate](https://crates.io/crates/getrandom) - and fast, convenient generation via [`thread_rng`](https://docs.rs/rand/*/rand/fn.thread_rng.html) -- A modular design built over [`rand_core`](https://crates.io/crates/rand_core) - ([see the book](https://rust-random.github.io/book/crates.html)) -- Fast implementations of the best-in-class [cryptographic](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and - [non-cryptographic](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators -- A flexible [`distributions`](https://docs.rs/rand/*/rand/distributions/index.html) module -- Samplers for a large number of random number distributions via our own + +Rand is a set of crates supporting (pseudo-)random generators: + +- Built over a standard RNG trait: [`rand_core::RngCore`](https://docs.rs/rand_core/latest/rand_core/trait.RngCore.html) +- With fast implementations of both [strong](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and + [small](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators: [`rand::rngs`](https://docs.rs/rand/latest/rand/rngs/index.html), and more RNGs: [`rand_chacha`](https://docs.rs/rand_chacha), [`rand_xoshiro`](https://docs.rs/rand_xoshiro/), [`rand_pcg`](https://docs.rs/rand_pcg/), [rngs repo](https://github.com/rust-random/rngs/) +- [`rand::rng`](https://docs.rs/rand/latest/rand/fn.rng.html) is an asymptotically-fast, automatically-seeded and reasonably strong generator available on all `std` targets +- Direct support for seeding generators from the [getrandom] crate + +With broad support for random value generation and random processes: + +- [`StandardUniform`](https://docs.rs/rand/latest/rand/distr/struct.StandardUniform.html) random value sampling, + [`Uniform`](https://docs.rs/rand/latest/rand/distr/struct.Uniform.html)-ranged value sampling + and [more](https://docs.rs/rand/latest/rand/distr/index.html) +- Samplers for a large number of non-uniform random number distributions via our own [`rand_distr`](https://docs.rs/rand_distr) and via - the [`statrs`](https://docs.rs/statrs/0.13.0/statrs/) + the [`statrs`](https://docs.rs/statrs) +- Random processes (mostly choose and shuffle) via [`rand::seq`](https://docs.rs/rand/latest/rand/seq/index.html) traits + +All with: + - [Portably reproducible output](https://rust-random.github.io/book/portability.html) - `#[no_std]` compatibility (partial) -- *Many* performance optimisations +- *Many* performance optimisations thanks to contributions from the wide + user-base -It's also worth pointing out what `rand` *is not*: +Rand **is not**: -- Small. Most low-level crates are small, but the higher-level `rand` and - `rand_distr` each contain a lot of functionality. +- Small (LoC). Most low-level crates are small, but the higher-level `rand` + and `rand_distr` each contain a lot of functionality. - Simple (implementation). We have a strong focus on correctness, speed and flexibility, but not simplicity. If you prefer a small-and-simple library, there are alternatives including [fastrand](https://crates.io/crates/fastrand) and [oorandom](https://crates.io/crates/oorandom). -- Slow. We take performance seriously, with considerations also for set-up - time of new distributions, commonly-used parameters, and parameters of the - current sampler. +- Primarily a cryptographic library. `rand` does provide some generators which + aim to support unpredictable value generation under certain constraints; + see [SECURITY.md](https://github.com/rust-random/rand/blob/master/SECURITY.md) for details. + Users are expected to determine for themselves + whether `rand`'s functionality meets their own security requirements. Documentation: @@ -45,60 +52,14 @@ Documentation: - [API reference (docs.rs)](https://docs.rs/rand) -## Usage - -Add this to your `Cargo.toml`: - -```toml -[dependencies] -rand = "0.8.5" -``` - -To get started using Rand, see [The Book](https://rust-random.github.io/book). - - ## Versions Rand is *mature* (suitable for general usage, with infrequent breaking releases -which minimise breakage) but not yet at 1.0. We maintain compatibility with -pinned versions of the Rust compiler (see below). - -Current Rand versions are: - -- Version 0.7 was released in June 2019, moving most non-uniform distributions - to an external crate, moving `from_entropy` to `SeedableRng`, and many small - changes and fixes. -- Version 0.8 was released in December 2020 with many small changes. - -A detailed [changelog](CHANGELOG.md) is available for releases. - -When upgrading to the next minor series (especially 0.4 → 0.5), we recommend -reading the [Upgrade Guide](https://rust-random.github.io/book/update.html). +which minimise breakage) but not yet at 1.0. Current `MAJOR.MINOR` versions are: -Rand has not yet reached 1.0 implying some breaking changes may arrive in the -future ([SemVer](https://semver.org/) allows each 0.x.0 release to include -breaking changes), but is considered *mature*: breaking changes are minimised -and breaking releases are infrequent. +- Version 0.9 was released in January 2025. -Rand libs have inter-dependencies and make use of the -[semver trick](https://github.com/dtolnay/semver-trick/) in order to make traits -compatible across crate versions. (This is especially important for `RngCore` -and `SeedableRng`.) A few crate releases are thus compatibility shims, -depending on the *next* lib version (e.g. `rand_core` versions `0.2.2` and -`0.3.1`). This means, for example, that `rand_core_0_4_0::SeedableRng` and -`rand_core_0_3_0::SeedableRng` are distinct, incompatible traits, which can -cause build errors. Usually, running `cargo update` is enough to fix any issues. - -### Yanked versions - -Some versions of Rand crates have been yanked ("unreleased"). Where this occurs, -the crate's CHANGELOG *should* be updated with a rationale, and a search on the -issue tracker with the keyword `yank` *should* uncover the motivation. - -### Rust version requirements - -The Minimum Supported Rust Version (MSRV) is `rustc >= 1.60.0`. -Older releases may work (depending on feature configuration) but are untested. +See the [CHANGELOG](https://github.com/rust-random/rand/blob/master/CHANGELOG.md) or [Upgrade Guide](https://rust-random.github.io/book/update.html) for more details. ## Crate Features @@ -106,47 +67,54 @@ Rand is built with these features enabled by default: - `std` enables functionality dependent on the `std` lib - `alloc` (implied by `std`) enables functionality requiring an allocator -- `getrandom` (implied by `std`) is an optional dependency providing the code - behind `rngs::OsRng` -- `std_rng` enables inclusion of `StdRng`, `thread_rng` and `random` - (the latter two *also* require that `std` be enabled) +- `os_rng` (implied by `std`) enables `rngs::OsRng`, using the [getrandom] crate +- `std_rng` enables inclusion of `StdRng`, `ThreadRng` +- `small_rng` enables inclusion of the `SmallRng` PRNG Optionally, the following dependencies can be enabled: -- `log` enables logging via the `log` crate +- `log` enables logging via [log](https://crates.io/crates/log) Additionally, these features configure Rand: -- `small_rng` enables inclusion of the `SmallRng` PRNG - `nightly` includes some additions requiring nightly Rust - `simd_support` (experimental) enables sampling of SIMD values (uniformly random SIMD integers and floats), requiring nightly Rust +- `unbiased` use unbiased sampling for algorithms supporting this option: Uniform distribution. + + (By default, bias affecting no more than one in 2^48 samples is accepted.) + + Note: enabling this option is expected to affect reproducibility of results. Note that nightly features are not stable and therefore not all library and compiler versions will be compatible. This is especially true of Rand's experimental `simd_support` feature. Rand supports limited functionality in `no_std` mode (enabled via -`default-features = false`). In this case, `OsRng` and `from_entropy` are -unavailable (unless `getrandom` is enabled), large parts of `seq` are -unavailable (unless `alloc` is enabled), and `thread_rng` and `random` are -unavailable. - -### WASM support - -Seeding entropy from OS on WASM target `wasm32-unknown-unknown` is not -*automatically* supported by `rand` or `getrandom`. If you are fine with -seeding the generator manually, you can disable the `getrandom` feature -and use the methods on the `SeedableRng` trait. To enable seeding from OS, -either use a different target such as `wasm32-wasi` or add a direct -dependency on `getrandom` with the `js` feature (if the target supports -JavaScript). See -[getrandom#WebAssembly support](https://docs.rs/getrandom/latest/getrandom/#webassembly-support). +`default-features = false`). In this case, `OsRng` and `from_os_rng` are +unavailable (unless `os_rng` is enabled), large parts of `seq` are +unavailable (unless `alloc` is enabled), and `ThreadRng` is unavailable. + +## Portability and platform support + +Many (but not all) algorithms are intended to have reproducible output. Read more in the book: [Portability](https://rust-random.github.io/book/portability.html). + +The Rand library supports a variety of CPU architectures. Platform integration is outsourced to [getrandom]. + +### WebAssembly support + +The [WASI](https://github.com/WebAssembly/WASI/tree/main) and Emscripten +targets are directly supported. The `wasm32-unknown-unknown` target is not +*automatically* supported. To enable support for this target, refer to the +[`getrandom` documentation for WebAssembly](https://docs.rs/getrandom/latest/getrandom/#webassembly-support). +Alternatively, the `os_rng` feature may be disabled. # License Rand is distributed under the terms of both the MIT license and the Apache License (Version 2.0). -See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and -[COPYRIGHT](COPYRIGHT) for details. +See [LICENSE-APACHE](https://github.com/rust-random/rand/blob/master/LICENSE-APACHE) and [LICENSE-MIT](https://github.com/rust-random/rand/blob/master/LICENSE-MIT), and +[COPYRIGHT](https://github.com/rust-random/rand/blob/master/COPYRIGHT) for details. + +[getrandom]: https://crates.io/crates/getrandom diff --git a/SECURITY.md b/SECURITY.md index a31b4e23fd3..f1a61b0d208 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,69 +1,89 @@ # Security Policy -## No guarantees +## Disclaimer -Support is provided on a best-effort bases only. -No binding guarantees can be provided. +Rand is a community project and cannot provide legally-binding guarantees of +security. ## Security premises -Rand provides the trait `rand_core::CryptoRng` aka `rand::CryptoRng` as a marker -trait. Generators implementing `RngCore` *and* `CryptoRng`, and given the -additional constraints that: +### Marker traits -- Instances of seedable RNGs (those implementing `SeedableRng`) are - constructed with cryptographically secure seed values -- The state (memory) of the RNG and its seed value are not be exposed +Rand provides the marker traits `CryptoRng`, `TryCryptoRng` and +`CryptoBlockRng`. Generators (RNGs) implementing one of these traits which are +used according to these additional constraints: + +- The generator may be constructed using `std::default::Default` where the + generator supports this trait. Note that generators should *only* support + `Default` where the `default()` instance is appropriately seeded: for + example `OsRng` has no state and thus has a trivial `default()` instance + while `ThreadRng::default()` returns a handle to a thread-local instance + seeded using `OsRng`. +- The generator may be constructed using `rand_core::SeedableRng` in any of + the following ways where the generator supports this trait: + + - Via `SeedableRng::from_seed` using a cryptographically secure seed value + - Via `SeedableRng::from_rng` or `try_from_rng` using a cryptographically + secure source `rng` + - Via `SeedableRng::from_os_rng` or `try_from_os_rng` +- The state (memory) of the generator and its seed value (or source `rng`) are + not exposed are expected to provide the following: -- An attacker can gain no advantage over chance (50% for each bit) in - predicting the RNG output, even with full knowledge of all prior outputs. +- An attacker cannot predict the output with more accuracy than what would be + expected through pure chance since each possible output value of any method + under the above traits which generates output bytes (including + `RngCore::next_u32`, `RngCore::next_u64`, `RngCore::fill_bytes`, + `TryRngCore::try_next_u32`, `TryRngCore::try_next_u64`, + `TryRngCore::try_fill_bytes` and `BlockRngCore::generate`) should be equally + likely +- Knowledge of prior outputs from the generator does not aid an attacker in + predicting future outputs -For some RNGs, notably `OsRng`, `ThreadRng` and those wrapped by `ReseedingRng`, -we provide limited mitigations against side-channel attacks: +### Specific generators -- After a process fork on Unix, there is an upper-bound on the number of bits - output by the RNG before the processes diverge, after which outputs from - each process's RNG are uncorrelated -- After the state (memory) of an RNG is leaked, there is an upper-bound on the - number of bits of output by the RNG before prediction of output by an - observer again becomes computationally-infeasible +`OsRng` is a stateless "generator" implemented via [getrandom]. As such, it has +no possible state to leak and cannot be improperly seeded. -Additionally, derivations from such an RNG (including the `Rng` trait, -implementations of the `Distribution` trait, and `seq` algorithms) should not -introduce significant bias other than that expected from the operation in -question (e.g. bias from a weighted distribution). +`StdRng` is a `CryptoRng` and `SeedableRng` using a pseudo-random algorithm +selected for good security and performance qualities. Since it does not offer +reproducibility of output, its algorithm may be changed in any release version. -## Supported Versions +`ChaCha12Rng` and `ChaCha20Rng` are selected pseudo-random generators +distributed by the `rand` project which meet the requirements of the `CryptoRng` +trait and implement `SeedableRng` with a commitment to reproducibility of +results. -We will attempt to uphold these premises in the following crate versions, -provided that only the latest patch version is used, and with potential -exceptions for theoretical issues without a known exploit: +`ThreadRng` is a conveniently-packaged generator over `StdRng` offering +automatic seeding from `OsRng`, periodic reseeding and thread locality. +This random source is intended to offer a good compromise between cryptographic +security, fast generation with reasonably low memory and initialization cost +overheads, and robustness against misuse. -| Crate | Versions | Exceptions | -| ----- | -------- | ---------- | -| `rand` | 0.8 | | -| `rand` | 0.7 | | -| `rand` | 0.5, 0.6 | Jitter | -| `rand` | 0.4 | Jitter, ISAAC | -| `rand_core` | 0.2 - 0.6 | | -| `rand_chacha` | 0.1 - 0.3 | | +[getrandom]: https://crates.io/crates/getrandom -Explanation of exceptions: +### Distributions -- Jitter: `JitterRng` is used as an entropy source when the primary source - fails; this source may not be secure against side-channel attacks, see #699. -- ISAAC: the [ISAAC](https://burtleburtle.net/bob/rand/isaacafa.html) RNG used - to implement `thread_rng` is difficult to analyse and thus cannot provide - strong assertions of security. +Methods of the `Rng` trait, functionality of the `rand::seq` module and +implementators of the `Distribution` trait are expected, while using a +cryptographically secure `CryptoRng` instance meeting the above constraints, +to not introduce significant bias to their operation beyond what would be +expected of the operation. Note that the usage of 'significant' here permits +some bias, as noted for example in the documentation of the `Uniform` +distribution. -## Known issues +## Supported Versions -In `rand` version 0.3 (0.3.18 and later), if `OsRng` fails, `thread_rng` is -seeded from the system time in an insecure manner. +We aim to provide security fixes in the form of a new patch version for the +latest release version of `rand` and its dependencies `rand_core` and +`rand_chacha`, as well as for prior major and minor releases which were, at some +time during the previous 12 months, the latest release version. ## Reporting a Vulnerability -To report a vulnerability, [open a new issue](https://github.com/rust-random/rand/issues/new). -Once the issue is resolved, the vulnerability should be [reported to RustSec](https://github.com/RustSec/advisory-db/blob/master/CONTRIBUTING.md). +If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. + +Please disclose it at [security advisory](https://github.com/rust-random/rand/security/advisories/new). + +This project is maintained by a team of volunteers on a reasonable-effort basis. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/benches/Cargo.toml b/benches/Cargo.toml new file mode 100644 index 00000000000..adb9aadd84b --- /dev/null +++ b/benches/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "benches" +version = "0.1.0" +edition = "2021" +publish = false + +[features] +# Option (requires nightly Rust): experimental SIMD support +simd_support = ["rand/simd_support"] + +[dependencies] + +[dev-dependencies] +rand = { path = "..", features = ["small_rng", "nightly"] } +rand_pcg = { path = "../rand_pcg" } +rand_chacha = { path = "../rand_chacha" } +criterion = "0.5" +criterion-cycles-per-byte = "0.6" + +[[bench]] +name = "array" +harness = false + +[[bench]] +name = "bool" +harness = false + +[[bench]] +name = "generators" +harness = false + +[[bench]] +name = "seq_choose" +harness = false + +[[bench]] +name = "shuffle" +harness = false + +[[bench]] +name = "simd" +harness = false + +[[bench]] +name = "standard" +harness = false + +[[bench]] +name = "uniform" +harness = false + +[[bench]] +name = "uniform_float" +harness = false + +[[bench]] +name = "weighted" +harness = false diff --git a/benches/benches/array.rs b/benches/benches/array.rs new file mode 100644 index 00000000000..063516337bf --- /dev/null +++ b/benches/benches/array.rs @@ -0,0 +1,94 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating/filling arrays and iterators of output + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::distr::StandardUniform; +use rand::prelude::*; +use rand_pcg::Pcg64Mcg; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("random_1kb"); + g.throughput(criterion::Throughput::Bytes(1024)); + + g.bench_function("u16_iter_repeat", |b| { + use core::iter; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = iter::repeat(()).map(|()| rng.random()).take(512).collect(); + v + }); + }); + + g.bench_function("u16_sample_iter", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = StandardUniform.sample_iter(&mut rng).take(512).collect(); + v + }); + }); + + g.bench_function("u16_gen_array", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: [u16; 512] = rng.random(); + v + }); + }); + + g.bench_function("u16_fill", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let mut buf = [0u16; 512]; + b.iter(|| { + rng.fill(&mut buf[..]); + buf + }); + }); + + g.bench_function("u64_iter_repeat", |b| { + use core::iter; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = iter::repeat(()).map(|()| rng.random()).take(128).collect(); + v + }); + }); + + g.bench_function("u64_sample_iter", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = StandardUniform.sample_iter(&mut rng).take(128).collect(); + v + }); + }); + + g.bench_function("u64_gen_array", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: [u64; 128] = rng.random(); + v + }); + }); + + g.bench_function("u64_fill", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let mut buf = [0u64; 128]; + b.iter(|| { + rng.fill(&mut buf[..]); + buf + }); + }); +} diff --git a/benches/benches/bool.rs b/benches/benches/bool.rs new file mode 100644 index 00000000000..8ff8c676024 --- /dev/null +++ b/benches/benches/bool.rs @@ -0,0 +1,69 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating/filling arrays and iterators of output + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::distr::Bernoulli; +use rand::prelude::*; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("random_bool"); + g.sample_size(1000); + g.warm_up_time(core::time::Duration::from_millis(500)); + g.measurement_time(core::time::Duration::from_millis(1000)); + + g.bench_function("standard", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.sample::(rand::distr::StandardUniform)) + }); + + g.bench_function("const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.random_bool(0.18)) + }); + + g.bench_function("var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let p = rng.random(); + b.iter(|| rng.random_bool(p)) + }); + + g.bench_function("ratio_const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.random_ratio(2, 3)) + }); + + g.bench_function("ratio_var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let d = rng.random_range(1..=100); + let n = rng.random_range(0..=d); + b.iter(|| rng.random_ratio(n, d)); + }); + + g.bench_function("bernoulli_const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let d = Bernoulli::new(0.18).unwrap(); + b.iter(|| rng.sample(d)) + }); + + g.bench_function("bernoulli_var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let p = rng.random(); + let d = Bernoulli::new(p).unwrap(); + b.iter(|| rng.sample(d)) + }); +} diff --git a/benches/benches/generators.rs b/benches/benches/generators.rs new file mode 100644 index 00000000000..31f08a02408 --- /dev/null +++ b/benches/benches/generators.rs @@ -0,0 +1,218 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::time::Duration; +use criterion::measurement::WallTime; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use rand::prelude::*; +use rand::rngs::OsRng; +use rand::rngs::ReseedingRng; +use rand_chacha::rand_core::UnwrapErr; +use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; +use rand_pcg::{Pcg32, Pcg64, Pcg64Dxsm, Pcg64Mcg}; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = random_bytes, random_u32, random_u64, init_gen, init_from_u64, init_from_seed, reseeding_bytes +); +criterion_main!(benches); + +pub fn random_bytes(c: &mut Criterion) { + let mut g = c.benchmark_group("random_bytes"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(1024)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + let mut buf = [0u8; 1024]; + b.iter(|| { + rng.fill_bytes(&mut buf); + black_box(buf); + }); + }); + } + + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn random_u32(c: &mut Criterion) { + let mut g = c.benchmark_group("random_u32"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(4)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + b.iter(|| rng.random::()); + }); + } + + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn random_u64(c: &mut Criterion) { + let mut g = c.benchmark_group("random_u64"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(8)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + b.iter(|| rng.random::()); + }); + } + + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn init_gen(c: &mut Criterion) { + let mut g = c.benchmark_group("init_gen"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| R::from_rng(&mut rng)); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn init_from_u64(c: &mut Criterion) { + let mut g = c.benchmark_group("init_from_u64"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let seed = rng.random(); + b.iter(|| R::seed_from_u64(black_box(seed))); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn init_from_seed(c: &mut Criterion) { + let mut g = c.benchmark_group("init_from_seed"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) + where + rand::distr::StandardUniform: Distribution<::Seed>, + { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let seed = rng.random(); + b.iter(|| R::from_seed(black_box(seed.clone()))); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn reseeding_bytes(c: &mut Criterion) { + let mut g = c.benchmark_group("reseeding_bytes"); + g.warm_up_time(Duration::from_millis(500)); + g.throughput(criterion::Throughput::Bytes(1024 * 1024)); + + fn bench(g: &mut BenchmarkGroup, thresh: u64) { + let name = format!("chacha20_{}k", thresh); + g.bench_function(name.as_str(), |b| { + let mut rng = ReseedingRng::::new(thresh * 1024, OsRng).unwrap(); + let mut buf = [0u8; 1024 * 1024]; + b.iter(|| { + rng.fill_bytes(&mut buf); + black_box(&buf); + }); + }); + } + + bench(&mut g, 4); + bench(&mut g, 16); + bench(&mut g, 32); + bench(&mut g, 64); + bench(&mut g, 256); + bench(&mut g, 1024); + + g.finish() +} diff --git a/benches/benches/seq_choose.rs b/benches/benches/seq_choose.rs new file mode 100644 index 00000000000..56223dd0a62 --- /dev/null +++ b/benches/benches/seq_choose.rs @@ -0,0 +1,180 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("seq_slice_choose_1_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&mut buf); + + b.iter(|| x.choose(&mut rng).unwrap()); + }); + + let lens = [(1, 1000), (950, 1000), (10, 100), (90, 100)]; + for (amount, len) in lens { + let name = format!("seq_slice_choose_multiple_{}_of_{}", amount, len); + c.bench_function(name.as_str(), |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 1000]; + rng.fill(&mut buf); + let x = black_box(&buf[..len]); + + let mut results_buf = [0i32; 950]; + let y = black_box(&mut results_buf[..amount]); + let amount = black_box(amount); + + b.iter(|| { + // Collect full result to prevent unwanted shortcuts getting + // first element (in case sample_indices returns an iterator). + for (slot, sample) in y.iter_mut().zip(x.choose_multiple(&mut rng, amount)) { + *slot = *sample; + } + y[amount - 1] + }) + }); + } + + let lens = [(1, 1000), (950, 1000), (10, 100), (90, 100)]; + for (amount, len) in lens { + let name = format!("seq_slice_choose_multiple_weighted_{}_of_{}", amount, len); + c.bench_function(name.as_str(), |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 1000]; + rng.fill(&mut buf); + let x = black_box(&buf[..len]); + + let mut results_buf = [0i32; 950]; + let y = black_box(&mut results_buf[..amount]); + let amount = black_box(amount); + + b.iter(|| { + // Collect full result to prevent unwanted shortcuts getting + // first element (in case sample_indices returns an iterator). + let samples_iter = x.choose_multiple_weighted(&mut rng, amount, |_| 1.0).unwrap(); + for (slot, sample) in y.iter_mut().zip(samples_iter) { + *slot = *sample; + } + y[amount - 1] + }) + }); + } + + c.bench_function("seq_iter_choose_multiple_10_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&buf); + b.iter(|| x.iter().cloned().choose_multiple(&mut rng, 10)) + }); + + c.bench_function("seq_iter_choose_multiple_fill_10_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&buf); + let mut buf = [0; 10]; + b.iter(|| x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf)) + }); + + bench_rng::(c, "ChaCha20"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000].map(black_box) { + let name = format!("choose_size-hinted_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_size_hinted(length, &mut rng)) + }); + + let name = format!("choose_stable_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_stable(length, &mut rng)) + }); + + let name = format!("choose_unhinted_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_unhinted(length, &mut rng)) + }); + + let name = format!("choose_windowed_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_windowed(length, 7, &mut rng)) + }); + } +} + +fn choose_size_hinted(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose(rng) +} + +fn choose_stable(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose_stable(rng) +} + +fn choose_unhinted(max: usize, rng: &mut R) -> Option { + let iterator = UnhintedIterator { iter: (0..max) }; + iterator.choose(rng) +} + +fn choose_windowed(max: usize, window_size: usize, rng: &mut R) -> Option { + let iterator = WindowHintedIterator { + iter: (0..max), + window_size, + }; + iterator.choose(rng) +} + +#[derive(Clone)] +struct UnhintedIterator { + iter: I, +} +impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[derive(Clone)] +struct WindowHintedIterator { + iter: I, + window_size: usize, +} +impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + (core::cmp::min(self.iter.len(), self.window_size), None) + } +} diff --git a/benches/shuffle.rs b/benches/benches/shuffle.rs similarity index 60% rename from benches/shuffle.rs rename to benches/benches/shuffle.rs index 4d6e31fa38c..c2f37daaeab 100644 --- a/benches/shuffle.rs +++ b/benches/benches/shuffle.rs @@ -5,18 +5,31 @@ // , at your // option. This file may not be copied, modified, or distributed // except according to those terms. + use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::prelude::*; use rand::SeedableRng; +use rand_pcg::Pcg32; criterion_group!( -name = benches; -config = Criterion::default(); -targets = bench + name = benches; + config = Criterion::default(); + targets = bench ); criterion_main!(benches); pub fn bench(c: &mut Criterion) { + c.bench_function("seq_shuffle_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&mut buf); + b.iter(|| { + x.shuffle(&mut rng); + x[0] + }) + }); + bench_rng::(c, "ChaCha12"); bench_rng::(c, "Pcg32"); bench_rng::(c, "Pcg64"); @@ -34,17 +47,15 @@ fn bench_rng(c: &mut Criterion, rng_name: &'static s }); if length >= 10 { - c.bench_function( - format!("partial_shuffle_{length}_{rng_name}").as_str(), - |b| { - let mut rng = Rng::seed_from_u64(123); - let mut vec: Vec = (0..length).collect(); - b.iter(|| { - vec.partial_shuffle(&mut rng, length / 2); - vec[0] - }) - }, - ); + let name = format!("partial_shuffle_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.partial_shuffle(&mut rng, length / 2); + vec[0] + }) + }); } } } diff --git a/benches/benches/simd.rs b/benches/benches/simd.rs new file mode 100644 index 00000000000..f1723245977 --- /dev/null +++ b/benches/benches/simd.rs @@ -0,0 +1,76 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating SIMD / wide types + +#![cfg_attr(feature = "simd_support", feature(portable_simd))] + +use criterion::{criterion_group, criterion_main, Criterion}; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = simd +); +criterion_main!(benches); + +#[cfg(not(feature = "simd_support"))] +pub fn simd(_: &mut Criterion) {} + +#[cfg(feature = "simd_support")] +pub fn simd(c: &mut Criterion) { + use rand::prelude::*; + use rand_pcg::Pcg64Mcg; + + let mut g = c.benchmark_group("random_simd"); + + g.bench_function("u128", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("m128i", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("m256i", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("m512i", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u64x2", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u32x4", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u32x8", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u16x8", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u8x16", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); +} diff --git a/benches/benches/standard.rs b/benches/benches/standard.rs new file mode 100644 index 00000000000..de95fb5ba69 --- /dev/null +++ b/benches/benches/standard.rs @@ -0,0 +1,64 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::time::Duration; +use criterion::measurement::WallTime; +use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use rand::distr::{Alphabetic, Alphanumeric, Open01, OpenClosed01, StandardUniform}; +use rand::prelude::*; +use rand_pcg::Pcg64Mcg; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +fn bench_ty(g: &mut BenchmarkGroup, name: &str) +where + D: Distribution + Default, +{ + g.throughput(criterion::Throughput::Bytes(core::mem::size_of::() as u64)); + g.bench_function(name, |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + + b.iter(|| rng.sample::(D::default())); + }); +} + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("StandardUniform"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + macro_rules! do_ty { + ($t:ty) => { + bench_ty::<$t, StandardUniform>(&mut g, stringify!($t)); + }; + ($t:ty, $($tt:ty),*) => { + do_ty!($t); + do_ty!($($tt),*); + }; + } + + do_ty!(i8, i16, i32, i64, i128); + do_ty!(f32, f64); + do_ty!(char); + + bench_ty::(&mut g, "Alphabetic"); + bench_ty::(&mut g, "Alphanumeric"); + + bench_ty::(&mut g, "Open01/f32"); + bench_ty::(&mut g, "Open01/f64"); + bench_ty::(&mut g, "OpenClosed01/f32"); + bench_ty::(&mut g, "OpenClosed01/f64"); + + g.finish(); +} diff --git a/benches/benches/uniform.rs b/benches/benches/uniform.rs new file mode 100644 index 00000000000..1f1ed49681d --- /dev/null +++ b/benches/benches/uniform.rs @@ -0,0 +1,126 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Implement benchmarks for uniform distributions over integer types + +#![cfg_attr(feature = "simd_support", feature(portable_simd))] + +use core::time::Duration; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::distr::uniform::{SampleRange, Uniform}; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_pcg::{Pcg32, Pcg64}; +#[cfg(feature = "simd_support")] +use std::simd::{num::SimdUint, Simd}; + +const WARM_UP_TIME: Duration = Duration::from_millis(1000); +const MEASUREMENT_TIME: Duration = Duration::from_secs(3); +const SAMPLE_SIZE: usize = 100_000; +const N_RESAMPLES: usize = 10_000; + +macro_rules! sample { + (@range $T:ty, $U:ty, 1, $rng:ident) => {{ + assert_eq!(<$T>::BITS, <$U>::BITS); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let x = $rng.random::<$U>(); + ((x >> bits) * (x & mask)) as $T + }}; + + (@range $T:ty, $U:ty, $len:tt, $rng:ident) => {{ + let bits = (<$T>::BITS / 2); + let mask = Simd::splat((1 as $U).wrapping_neg() >> bits); + let bits = Simd::splat(bits as $U); + let x = $rng.random::>(); + ((x >> bits) * (x & mask)).cast() + }}; + + (@MIN $T:ty, 1) => { + <$T>::MIN + }; + + (@MIN $T:ty, $len:tt) => { + Simd::<$T, $len>::splat(<$T>::MIN) + }; + + (@wrapping_add $lhs:expr, $rhs:expr, 1) => { + $lhs.wrapping_add($rhs) + }; + + (@wrapping_add $lhs:expr, $rhs:expr, $len:tt) => { + ($lhs + $rhs) + }; + + ($R:ty, $T:ty, $U:ty, $len:tt, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($R), "single"), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let range = sample!(@range $T, $U, $len, rng); + let low = sample!(@MIN $T, $len); + let high = sample!(@wrapping_add low, range, $len); + + b.iter(|| (low..=high).sample_single(&mut rng)); + }); + + $g.bench_function(BenchmarkId::new(stringify!($R), "distr"), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let range = sample!(@range $T, $U, $len, rng); + let low = sample!(@MIN $T, $len); + let high = sample!(@wrapping_add low, range, $len); + let dist = Uniform::new_inclusive(low, high).unwrap(); + + b.iter(|| dist.sample(&mut rng)); + }); + }; + + // Entrypoint: + // $T is the output type (integer) + // $U is the unsigned version of the output type + // $len is the width for SIMD or 1 for non-SIMD + ($c:expr, $T:ty, $U:ty, $len:tt) => {{ + let mut g = $c.benchmark_group(concat!("sample_", stringify!($T), "x", stringify!($len))); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + sample!(SmallRng, $T, $U, $len, g); + sample!(ChaCha8Rng, $T, $U, $len, g); + sample!(Pcg32, $T, $U, $len, g); + sample!(Pcg64, $T, $U, $len, g); + g.finish(); + }}; +} + +fn sample(c: &mut Criterion) { + sample!(c, i8, u8, 1); + sample!(c, i16, u16, 1); + sample!(c, i32, u32, 1); + sample!(c, i64, u64, 1); + sample!(c, i128, u128, 1); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 8); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 16); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 32); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 64); + #[cfg(feature = "simd_support")] + sample!(c, i16, u16, 8); + #[cfg(feature = "simd_support")] + sample!(c, i16, u16, 16); + #[cfg(feature = "simd_support")] + sample!(c, i16, u16, 32); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = sample +} +criterion_main!(benches); diff --git a/benches/uniform_float.rs b/benches/benches/uniform_float.rs similarity index 87% rename from benches/uniform_float.rs rename to benches/benches/uniform_float.rs index 957ff1b8ecf..03a434fc228 100644 --- a/benches/uniform_float.rs +++ b/benches/benches/uniform_float.rs @@ -14,7 +14,7 @@ use core::time::Duration; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use rand::distributions::uniform::{SampleUniform, Uniform, UniformSampler}; +use rand::distr::uniform::{SampleUniform, Uniform, UniformSampler}; use rand::prelude::*; use rand_chacha::ChaCha8Rng; use rand_pcg::{Pcg32, Pcg64}; @@ -27,11 +27,11 @@ const N_RESAMPLES: usize = 10_000; macro_rules! single_random { ($R:ty, $T:ty, $g:expr) => { $g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| { - let mut rng = <$R>::from_rng(thread_rng()).unwrap(); + let mut rng = <$R>::from_rng(&mut rand::rng()); let (mut low, mut high); loop { - low = <$T>::from_bits(rng.gen()); - high = <$T>::from_bits(rng.gen()); + low = <$T>::from_bits(rng.random()); + high = <$T>::from_bits(rng.random()); if (low < high) && (high - low).is_normal() { break; } @@ -63,10 +63,10 @@ fn single_random(c: &mut Criterion) { macro_rules! distr_random { ($R:ty, $T:ty, $g:expr) => { $g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| { - let mut rng = <$R>::from_rng(thread_rng()).unwrap(); + let mut rng = <$R>::from_rng(&mut rand::rng()); let dist = loop { - let low = <$T>::from_bits(rng.gen()); - let high = <$T>::from_bits(rng.gen()); + let low = <$T>::from_bits(rng.random()); + let high = <$T>::from_bits(rng.random()); if let Ok(dist) = Uniform::<$T>::new_inclusive(low, high) { break dist; } diff --git a/benches/benches/weighted.rs b/benches/benches/weighted.rs new file mode 100644 index 00000000000..69576b3608d --- /dev/null +++ b/benches/benches/weighted.rs @@ -0,0 +1,60 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::distr::weighted::WeightedIndex; +use rand::prelude::*; +use rand::seq::index::sample_weighted; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("weighted_index_creation", |b| { + let mut rng = rand::rng(); + let weights = black_box([1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]); + b.iter(|| { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + rng.sample(distr) + }) + }); + + c.bench_function("weighted_index_modification", |b| { + let mut rng = rand::rng(); + let weights = black_box([1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + b.iter(|| { + distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); + rng.sample(&distr) + }) + }); + + let lens = [ + (1, 1000, "1k"), + (10, 1000, "1k"), + (100, 1000, "1k"), + (100, 1_000_000, "1M"), + (200, 1_000_000, "1M"), + (400, 1_000_000, "1M"), + (600, 1_000_000, "1M"), + (1000, 1_000_000, "1M"), + ]; + for (amount, length, len_name) in lens { + let name = format!("weighted_sample_indices_{}_of_{}", amount, len_name); + c.bench_function(name.as_str(), |b| { + let length = black_box(length); + let amount = black_box(amount); + let mut rng = SmallRng::from_rng(&mut rand::rng()); + b.iter(|| sample_weighted(&mut rng, length, |idx| (1 + (idx % 100)) as u32, amount)) + }); + } +} diff --git a/benches/distributions.rs b/benches/distributions.rs deleted file mode 100644 index f637fe4ae47..00000000000 --- a/benches/distributions.rs +++ /dev/null @@ -1,440 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(custom_inner_attributes)] -#![feature(test)] - -// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable -#![rustfmt::skip] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use rand::distributions::{Alphanumeric, Open01, OpenClosed01, Standard, Uniform}; -use rand::distributions::uniform::{UniformInt, UniformSampler}; -use core::mem::size_of; -use core::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8}; -use core::time::Duration; -use test::{Bencher, black_box}; - -use rand::prelude::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -macro_rules! distr_int { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_nz_int { - ($fnn:ident, $tynz:ty, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $tynz = distr.sample(&mut rng); - accum = accum.wrapping_add(x.get()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_float { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum += x; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_duration { - ($fnn:ident, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = Duration::new(0, 0); - for _ in 0..RAND_BENCH_N { - let x: Duration = distr.sample(&mut rng); - accum = accum - .checked_add(x) - .unwrap_or(Duration::new(u64::max_value(), 999_999_999)); - } - accum - }); - b.bytes = size_of::() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0u32; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x as u32); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -// uniform -distr_int!(distr_uniform_i8, i8, Uniform::new(20i8, 100).unwrap()); -distr_int!(distr_uniform_i16, i16, Uniform::new(-500i16, 2000).unwrap()); -distr_int!(distr_uniform_i32, i32, Uniform::new(-200_000_000i32, 800_000_000).unwrap()); -distr_int!(distr_uniform_i64, i64, Uniform::new(3i64, 123_456_789_123).unwrap()); -distr_int!(distr_uniform_i128, i128, Uniform::new(-123_456_789_123i128, 123_456_789_123_456_789).unwrap()); -distr_int!(distr_uniform_usize16, usize, Uniform::new(0usize, 0xb9d7).unwrap()); -distr_int!(distr_uniform_usize32, usize, Uniform::new(0usize, 0x548c0f43).unwrap()); -#[cfg(target_pointer_width = "64")] -distr_int!(distr_uniform_usize64, usize, Uniform::new(0usize, 0x3a42714f2bf927a8).unwrap()); -distr_int!(distr_uniform_isize, isize, Uniform::new(-1060478432isize, 1858574057).unwrap()); - -distr_float!(distr_uniform_f32, f32, Uniform::new(2.26f32, 2.319).unwrap()); -distr_float!(distr_uniform_f64, f64, Uniform::new(2.26f64, 2.319).unwrap()); - -const LARGE_SEC: u64 = u64::max_value() / 1000; - -distr_duration!(distr_uniform_duration_largest, - Uniform::new_inclusive(Duration::new(0, 0), Duration::new(u64::max_value(), 999_999_999)).unwrap() -); -distr_duration!(distr_uniform_duration_large, - Uniform::new(Duration::new(0, 0), Duration::new(LARGE_SEC, 1_000_000_000 / 2)).unwrap() -); -distr_duration!(distr_uniform_duration_one, - Uniform::new(Duration::new(0, 0), Duration::new(1, 0)).unwrap() -); -distr_duration!(distr_uniform_duration_variety, - Uniform::new(Duration::new(10000, 423423), Duration::new(200000, 6969954)).unwrap() -); -distr_duration!(distr_uniform_duration_edge, - Uniform::new_inclusive(Duration::new(LARGE_SEC, 999_999_999), Duration::new(LARGE_SEC + 1, 1)).unwrap() -); - -// standard -distr_int!(distr_standard_i8, i8, Standard); -distr_int!(distr_standard_i16, i16, Standard); -distr_int!(distr_standard_i32, i32, Standard); -distr_int!(distr_standard_i64, i64, Standard); -distr_int!(distr_standard_i128, i128, Standard); -distr_nz_int!(distr_standard_nz8, NonZeroU8, u8, Standard); -distr_nz_int!(distr_standard_nz16, NonZeroU16, u16, Standard); -distr_nz_int!(distr_standard_nz32, NonZeroU32, u32, Standard); -distr_nz_int!(distr_standard_nz64, NonZeroU64, u64, Standard); -distr_nz_int!(distr_standard_nz128, NonZeroU128, u128, Standard); - -distr!(distr_standard_bool, bool, Standard); -distr!(distr_standard_alphanumeric, u8, Alphanumeric); -distr!(distr_standard_codepoint, char, Standard); - -distr_float!(distr_standard_f32, f32, Standard); -distr_float!(distr_standard_f64, f64, Standard); -distr_float!(distr_open01_f32, f32, Open01); -distr_float!(distr_open01_f64, f64, Open01); -distr_float!(distr_openclosed01_f32, f32, OpenClosed01); -distr_float!(distr_openclosed01_f64, f64, OpenClosed01); - -// construct and sample from a range -macro_rules! gen_range_int { - ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - - b.iter(|| { - let mut high = $high; - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen_range($low..high)); - // force recalculation of range each time - high = high.wrapping_add(1) & core::$ty::MAX; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -// Algorithms such as Fisher–Yates shuffle often require uniform values from an -// incrementing range 0..n. We use -1..n here to prevent wrapping in the test -// from generating a 0-sized range. -gen_range_int!(gen_range_i8_low, i8, -1i8, 0); -gen_range_int!(gen_range_i16_low, i16, -1i16, 0); -gen_range_int!(gen_range_i32_low, i32, -1i32, 0); -gen_range_int!(gen_range_i64_low, i64, -1i64, 0); -gen_range_int!(gen_range_i128_low, i128, -1i128, 0); - -// These were the initially tested ranges. They are likely to see fewer -// rejections than the low tests. -gen_range_int!(gen_range_i8_high, i8, -20i8, 100); -gen_range_int!(gen_range_i16_high, i16, -500i16, 2000); -gen_range_int!(gen_range_i32_high, i32, -200_000_000i32, 800_000_000); -gen_range_int!(gen_range_i64_high, i64, 3i64, 123_456_789_123); -gen_range_int!(gen_range_i128_high, i128, -12345678901234i128, 123_456_789_123_456_789); - -// construct and sample from a floating-point range -macro_rules! gen_range_float { - ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - - b.iter(|| { - let mut high = $high; - let mut low = $low; - let mut accum: $ty = 0.0; - for _ in 0..RAND_BENCH_N { - accum += rng.gen_range(low..high); - // force recalculation of range each time - low += 0.9; - high += 1.1; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -gen_range_float!(gen_range_f32, f32, -20000.0f32, 100000.0); -gen_range_float!(gen_range_f64, f64, 123.456f64, 7890.12); - - -// In src/distributions/uniform.rs, we say: -// Implementation of [`uniform_single`] is optional, and is only useful when -// the implementation can be faster than `Self::new(low, high).sample(rng)`. - -// `UniformSampler::uniform_single` compromises on the rejection range to be -// faster. This benchmark demonstrates both the speed gain of doing this, and -// the worst case behavior. - -/// Sample random values from a pre-existing distribution. This uses the -/// half open `new` to be equivalent to the behavior of `uniform_single`. -macro_rules! uniform_sample { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..10 { - let dist = UniformInt::<$type>::new(low, high).unwrap(); - for _ in 0..$count { - black_box(dist.sample(&mut rng)); - } - } - }); - } - }; -} - -macro_rules! uniform_inclusive { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..10 { - let dist = UniformInt::<$type>::new_inclusive(low, high).unwrap(); - for _ in 0..$count { - black_box(dist.sample(&mut rng)); - } - } - }); - } - }; -} - -/// Use `uniform_single` to create a one-off random value -macro_rules! uniform_single { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..(10 * $count) { - black_box(UniformInt::<$type>::sample_single(low, high, &mut rng).unwrap()); - } - }); - } - }; -} - - -// Benchmark: -// n: can use the full generated range -// (n-1): only the max value is rejected: expect this to be fast -// n/2+1: almost half of the values are rejected, and we can do no better -// n/2: approximation rejects half the values but powers of 2 could have no rejection -// n/2-1: only a few values are rejected: expect this to be fast -// 6: approximation rejects 25% of values but could be faster. However modulo by -// low numbers is typically more expensive - -// With the use of u32 as the minimum generated width, the worst-case u16 range -// (32769) will only reject 32769 / 4294967296 samples. -const HALF_16_BIT_UNSIGNED: u16 = 1 << 15; - -uniform_sample!(uniform_u16x1_allm1_new, u16, 0, u16::max_value(), 1); -uniform_sample!(uniform_u16x1_halfp1_new, u16, 0, HALF_16_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u16x1_half_new, u16, 0, HALF_16_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u16x1_halfm1_new, u16, 0, HALF_16_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u16x1_6_new, u16, 0, 6u16, 1); - -uniform_single!(uniform_u16x1_allm1_single, u16, 0, u16::max_value(), 1); -uniform_single!(uniform_u16x1_halfp1_single, u16, 0, HALF_16_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u16x1_half_single, u16, 0, HALF_16_BIT_UNSIGNED, 1); -uniform_single!(uniform_u16x1_halfm1_single, u16, 0, HALF_16_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u16x1_6_single, u16, 0, 6u16, 1); - -uniform_inclusive!(uniform_u16x10_all_new_inclusive, u16, 0, u16::max_value(), 10); -uniform_sample!(uniform_u16x10_allm1_new, u16, 0, u16::max_value(), 10); -uniform_sample!(uniform_u16x10_halfp1_new, u16, 0, HALF_16_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u16x10_half_new, u16, 0, HALF_16_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u16x10_halfm1_new, u16, 0, HALF_16_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u16x10_6_new, u16, 0, 6u16, 10); - -uniform_single!(uniform_u16x10_allm1_single, u16, 0, u16::max_value(), 10); -uniform_single!(uniform_u16x10_halfp1_single, u16, 0, HALF_16_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u16x10_half_single, u16, 0, HALF_16_BIT_UNSIGNED, 10); -uniform_single!(uniform_u16x10_halfm1_single, u16, 0, HALF_16_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u16x10_6_single, u16, 0, 6u16, 10); - - -const HALF_32_BIT_UNSIGNED: u32 = 1 << 31; - -uniform_sample!(uniform_u32x1_allm1_new, u32, 0, u32::max_value(), 1); -uniform_sample!(uniform_u32x1_halfp1_new, u32, 0, HALF_32_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u32x1_half_new, u32, 0, HALF_32_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u32x1_halfm1_new, u32, 0, HALF_32_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u32x1_6_new, u32, 0, 6u32, 1); - -uniform_single!(uniform_u32x1_allm1_single, u32, 0, u32::max_value(), 1); -uniform_single!(uniform_u32x1_halfp1_single, u32, 0, HALF_32_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u32x1_half_single, u32, 0, HALF_32_BIT_UNSIGNED, 1); -uniform_single!(uniform_u32x1_halfm1_single, u32, 0, HALF_32_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u32x1_6_single, u32, 0, 6u32, 1); - -uniform_inclusive!(uniform_u32x10_all_new_inclusive, u32, 0, u32::max_value(), 10); -uniform_sample!(uniform_u32x10_allm1_new, u32, 0, u32::max_value(), 10); -uniform_sample!(uniform_u32x10_halfp1_new, u32, 0, HALF_32_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u32x10_half_new, u32, 0, HALF_32_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u32x10_halfm1_new, u32, 0, HALF_32_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u32x10_6_new, u32, 0, 6u32, 10); - -uniform_single!(uniform_u32x10_allm1_single, u32, 0, u32::max_value(), 10); -uniform_single!(uniform_u32x10_halfp1_single, u32, 0, HALF_32_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u32x10_half_single, u32, 0, HALF_32_BIT_UNSIGNED, 10); -uniform_single!(uniform_u32x10_halfm1_single, u32, 0, HALF_32_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u32x10_6_single, u32, 0, 6u32, 10); - -const HALF_64_BIT_UNSIGNED: u64 = 1 << 63; - -uniform_sample!(uniform_u64x1_allm1_new, u64, 0, u64::max_value(), 1); -uniform_sample!(uniform_u64x1_halfp1_new, u64, 0, HALF_64_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u64x1_half_new, u64, 0, HALF_64_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u64x1_halfm1_new, u64, 0, HALF_64_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u64x1_6_new, u64, 0, 6u64, 1); - -uniform_single!(uniform_u64x1_allm1_single, u64, 0, u64::max_value(), 1); -uniform_single!(uniform_u64x1_halfp1_single, u64, 0, HALF_64_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u64x1_half_single, u64, 0, HALF_64_BIT_UNSIGNED, 1); -uniform_single!(uniform_u64x1_halfm1_single, u64, 0, HALF_64_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u64x1_6_single, u64, 0, 6u64, 1); - -uniform_inclusive!(uniform_u64x10_all_new_inclusive, u64, 0, u64::max_value(), 10); -uniform_sample!(uniform_u64x10_allm1_new, u64, 0, u64::max_value(), 10); -uniform_sample!(uniform_u64x10_halfp1_new, u64, 0, HALF_64_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u64x10_half_new, u64, 0, HALF_64_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u64x10_halfm1_new, u64, 0, HALF_64_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u64x10_6_new, u64, 0, 6u64, 10); - -uniform_single!(uniform_u64x10_allm1_single, u64, 0, u64::max_value(), 10); -uniform_single!(uniform_u64x10_halfp1_single, u64, 0, HALF_64_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u64x10_half_single, u64, 0, HALF_64_BIT_UNSIGNED, 10); -uniform_single!(uniform_u64x10_halfm1_single, u64, 0, HALF_64_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u64x10_6_single, u64, 0, 6u64, 10); - -const HALF_128_BIT_UNSIGNED: u128 = 1 << 127; - -uniform_sample!(uniform_u128x1_allm1_new, u128, 0, u128::max_value(), 1); -uniform_sample!(uniform_u128x1_halfp1_new, u128, 0, HALF_128_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u128x1_half_new, u128, 0, HALF_128_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u128x1_halfm1_new, u128, 0, HALF_128_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u128x1_6_new, u128, 0, 6u128, 1); - -uniform_single!(uniform_u128x1_allm1_single, u128, 0, u128::max_value(), 1); -uniform_single!(uniform_u128x1_halfp1_single, u128, 0, HALF_128_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u128x1_half_single, u128, 0, HALF_128_BIT_UNSIGNED, 1); -uniform_single!(uniform_u128x1_halfm1_single, u128, 0, HALF_128_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u128x1_6_single, u128, 0, 6u128, 1); - -uniform_inclusive!(uniform_u128x10_all_new_inclusive, u128, 0, u128::max_value(), 10); -uniform_sample!(uniform_u128x10_allm1_new, u128, 0, u128::max_value(), 10); -uniform_sample!(uniform_u128x10_halfp1_new, u128, 0, HALF_128_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u128x10_half_new, u128, 0, HALF_128_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u128x10_halfm1_new, u128, 0, HALF_128_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u128x10_6_new, u128, 0, 6u128, 10); - -uniform_single!(uniform_u128x10_allm1_single, u128, 0, u128::max_value(), 10); -uniform_single!(uniform_u128x10_halfp1_single, u128, 0, HALF_128_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u128x10_half_single, u128, 0, HALF_128_BIT_UNSIGNED, 10); -uniform_single!(uniform_u128x10_halfm1_single, u128, 0, HALF_128_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u128x10_6_single, u128, 0, 6u128, 10); diff --git a/benches/generators.rs b/benches/generators.rs deleted file mode 100644 index 12b8460f0b5..00000000000 --- a/benches/generators.rs +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] -#![allow(non_snake_case)] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; -const BYTES_LEN: usize = 1024; - -use core::mem::size_of; -use test::{black_box, Bencher}; - -use rand::prelude::*; -use rand::rngs::adapter::ReseedingRng; -use rand::rngs::{mock::StepRng, OsRng}; -use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; -use rand_pcg::{Pcg32, Pcg64, Pcg64Mcg, Pcg64Dxsm}; - -macro_rules! gen_bytes { - ($fnn:ident, $gen:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = $gen; - let mut buf = [0u8; BYTES_LEN]; - b.iter(|| { - for _ in 0..RAND_BENCH_N { - rng.fill_bytes(&mut buf); - black_box(buf); - } - }); - b.bytes = BYTES_LEN as u64 * RAND_BENCH_N; - } - }; -} - -gen_bytes!(gen_bytes_step, StepRng::new(0, 1)); -gen_bytes!(gen_bytes_pcg32, Pcg32::from_entropy()); -gen_bytes!(gen_bytes_pcg64, Pcg64::from_entropy()); -gen_bytes!(gen_bytes_pcg64mcg, Pcg64Mcg::from_entropy()); -gen_bytes!(gen_bytes_pcg64dxsm, Pcg64Dxsm::from_entropy()); -gen_bytes!(gen_bytes_chacha8, ChaCha8Rng::from_entropy()); -gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_entropy()); -gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_entropy()); -gen_bytes!(gen_bytes_std, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_bytes!(gen_bytes_small, SmallRng::from_thread_rng()); -gen_bytes!(gen_bytes_os, OsRng); - -macro_rules! gen_uint { - ($fnn:ident, $ty:ty, $gen:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = $gen; - b.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen::<$ty>()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -gen_uint!(gen_u32_step, u32, StepRng::new(0, 1)); -gen_uint!(gen_u32_pcg32, u32, Pcg32::from_entropy()); -gen_uint!(gen_u32_pcg64, u32, Pcg64::from_entropy()); -gen_uint!(gen_u32_pcg64mcg, u32, Pcg64Mcg::from_entropy()); -gen_uint!(gen_u32_pcg64dxsm, u32, Pcg64Dxsm::from_entropy()); -gen_uint!(gen_u32_chacha8, u32, ChaCha8Rng::from_entropy()); -gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_entropy()); -gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_entropy()); -gen_uint!(gen_u32_std, u32, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_uint!(gen_u32_small, u32, SmallRng::from_thread_rng()); -gen_uint!(gen_u32_os, u32, OsRng); - -gen_uint!(gen_u64_step, u64, StepRng::new(0, 1)); -gen_uint!(gen_u64_pcg32, u64, Pcg32::from_entropy()); -gen_uint!(gen_u64_pcg64, u64, Pcg64::from_entropy()); -gen_uint!(gen_u64_pcg64mcg, u64, Pcg64Mcg::from_entropy()); -gen_uint!(gen_u64_pcg64dxsm, u64, Pcg64Dxsm::from_entropy()); -gen_uint!(gen_u64_chacha8, u64, ChaCha8Rng::from_entropy()); -gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_entropy()); -gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_entropy()); -gen_uint!(gen_u64_std, u64, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_uint!(gen_u64_small, u64, SmallRng::from_thread_rng()); -gen_uint!(gen_u64_os, u64, OsRng); - -macro_rules! init_gen { - ($fnn:ident, $gen:ident) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg32::from_entropy(); - b.iter(|| { - let r2 = $gen::from_rng(&mut rng).unwrap(); - r2 - }); - } - }; -} - -init_gen!(init_pcg32, Pcg32); -init_gen!(init_pcg64, Pcg64); -init_gen!(init_pcg64mcg, Pcg64Mcg); -init_gen!(init_pcg64dxsm, Pcg64Dxsm); -init_gen!(init_chacha, ChaCha20Rng); - -const RESEEDING_BYTES_LEN: usize = 1024 * 1024; -const RESEEDING_BENCH_N: u64 = 16; - -macro_rules! reseeding_bytes { - ($fnn:ident, $thresh:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = ReseedingRng::new(ChaCha20Core::from_entropy(), $thresh * 1024, OsRng); - let mut buf = [0u8; RESEEDING_BYTES_LEN]; - b.iter(|| { - for _ in 0..RESEEDING_BENCH_N { - rng.fill_bytes(&mut buf); - black_box(&buf); - } - }); - b.bytes = RESEEDING_BYTES_LEN as u64 * RESEEDING_BENCH_N; - } - }; -} - -reseeding_bytes!(reseeding_chacha20_4k, 4); -reseeding_bytes!(reseeding_chacha20_16k, 16); -reseeding_bytes!(reseeding_chacha20_32k, 32); -reseeding_bytes!(reseeding_chacha20_64k, 64); -reseeding_bytes!(reseeding_chacha20_256k, 256); -reseeding_bytes!(reseeding_chacha20_1M, 1024); - - -macro_rules! threadrng_uint { - ($fnn:ident, $ty:ty) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = thread_rng(); - b.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen::<$ty>()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -threadrng_uint!(thread_rng_u32, u32); -threadrng_uint!(thread_rng_u64, u64); diff --git a/benches/misc.rs b/benches/misc.rs deleted file mode 100644 index f0b761f99ed..00000000000 --- a/benches/misc.rs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use test::Bencher; - -use rand::distributions::{Bernoulli, Distribution, Standard}; -use rand::prelude::*; -use rand_pcg::{Pcg32, Pcg64Mcg}; - -#[bench] -fn misc_gen_bool_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_bool(0.18); - } - accum - }) -} - -#[bench] -fn misc_gen_bool_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - let mut p = 0.18; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_bool(p); - p += 0.0001; - } - accum - }) -} - -#[bench] -fn misc_gen_ratio_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_ratio(2, 3); - } - accum - }) -} - -#[bench] -fn misc_gen_ratio_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for i in 2..(crate::RAND_BENCH_N as u32 + 2) { - accum ^= rng.gen_ratio(i, i + 1); - } - accum - }) -} - -#[bench] -fn misc_bernoulli_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let d = rand::distributions::Bernoulli::new(0.18).unwrap(); - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.sample(d); - } - accum - }) -} - -#[bench] -fn misc_bernoulli_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - let mut p = 0.18; - for _ in 0..crate::RAND_BENCH_N { - let d = Bernoulli::new(p).unwrap(); - accum ^= rng.sample(d); - p += 0.0001; - } - accum - }) -} - -#[bench] -fn gen_1kb_u16_iter_repeat(b: &mut Bencher) { - use core::iter; - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = iter::repeat(()).map(|()| rng.gen()).take(512).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_sample_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = Standard.sample_iter(&mut rng).take(512).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_gen_array(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - // max supported array length is 32! - let v: [[u16; 32]; 16] = rng.gen(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_fill(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let mut buf = [0u16; 512]; - b.iter(|| { - rng.fill(&mut buf[..]); - buf - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_iter_repeat(b: &mut Bencher) { - use core::iter; - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = iter::repeat(()).map(|()| rng.gen()).take(128).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_sample_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = Standard.sample_iter(&mut rng).take(128).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_gen_array(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - // max supported array length is 32! - let v: [[u64; 32]; 4] = rng.gen(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_fill(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let mut buf = [0u64; 128]; - b.iter(|| { - rng.fill(&mut buf[..]); - buf - }); - b.bytes = 1024; -} diff --git a/benches/rustfmt.toml b/benches/rustfmt.toml new file mode 100644 index 00000000000..b64fd7ad0e6 --- /dev/null +++ b/benches/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +fn_call_width = 108 diff --git a/benches/seq.rs b/benches/seq.rs deleted file mode 100644 index 3d57d4872e6..00000000000 --- a/benches/seq.rs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] -#![allow(non_snake_case)] - -extern crate test; - -use test::Bencher; - -use core::mem::size_of; -use rand::prelude::*; -use rand::seq::*; - -// We force use of 32-bit RNG since seq code is optimised for use with 32-bit -// generators on all platforms. -use rand_pcg::Pcg32 as SmallRng; - -const RAND_BENCH_N: u64 = 1000; - -#[bench] -fn seq_shuffle_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 100]; - b.iter(|| { - x.shuffle(&mut rng); - x[0] - }) -} - -#[bench] -fn seq_slice_choose_1_of_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 1000]; - for (i, r) in x.iter_mut().enumerate() { - *r = i; - } - b.iter(|| { - let mut s = 0; - for _ in 0..RAND_BENCH_N { - s += x.choose(&mut rng).unwrap(); - } - s - }); - b.bytes = size_of::() as u64 * crate::RAND_BENCH_N; -} - -macro_rules! seq_slice_choose_multiple { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[i32] = &[$amount; $length]; - let mut result = [0i32; $amount]; - b.iter(|| { - // Collect full result to prevent unwanted shortcuts getting - // first element (in case sample_indices returns an iterator). - for (slot, sample) in result.iter_mut().zip(x.choose_multiple(&mut rng, $amount)) { - *slot = *sample; - } - result[$amount - 1] - }) - } - }; -} - -seq_slice_choose_multiple!(seq_slice_choose_multiple_1_of_1000, 1, 1000); -seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000); -seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100); -seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100); - -#[bench] -fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 100]; - b.iter(|| x.iter().cloned().choose_multiple(&mut rng, 10)) -} - -#[bench] -fn seq_iter_choose_multiple_fill_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 100]; - let mut buf = [0; 10]; - b.iter(|| x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf)) -} - -macro_rules! sample_indices { - ($name:ident, $fn:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| index::$fn(&mut rng, $length, $amount)) - } - }; -} - -sample_indices!(misc_sample_indices_1_of_1k, sample, 1, 1000); -sample_indices!(misc_sample_indices_10_of_1k, sample, 10, 1000); -sample_indices!(misc_sample_indices_100_of_1k, sample, 100, 1000); -sample_indices!(misc_sample_indices_100_of_1M, sample, 100, 1_000_000); -sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1_000_000_000); -sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1_000_000_000); -sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1_000_000_000); -sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1_000_000_000); - -macro_rules! sample_indices_rand_weights { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| { - index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount) - }) - } - }; -} - -sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1_000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1_000_000); diff --git a/benches/seq_choose.rs b/benches/seq_choose.rs deleted file mode 100644 index ccf7e5825aa..00000000000 --- a/benches/seq_choose.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2018-2023 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rand::prelude::*; -use rand::SeedableRng; - -criterion_group!( -name = benches; -config = Criterion::default(); -targets = bench -); -criterion_main!(benches); - -pub fn bench(c: &mut Criterion) { - bench_rng::(c, "ChaCha20"); - bench_rng::(c, "Pcg32"); - bench_rng::(c, "Pcg64"); -} - -fn bench_rng(c: &mut Criterion, rng_name: &'static str) { - for length in [1, 2, 3, 10, 100, 1000].map(black_box) { - c.bench_function( - format!("choose_size-hinted_from_{length}_{rng_name}").as_str(), - |b| { - let mut rng = Rng::seed_from_u64(123); - b.iter(|| choose_size_hinted(length, &mut rng)) - }, - ); - - c.bench_function( - format!("choose_stable_from_{length}_{rng_name}").as_str(), - |b| { - let mut rng = Rng::seed_from_u64(123); - b.iter(|| choose_stable(length, &mut rng)) - }, - ); - - c.bench_function( - format!("choose_unhinted_from_{length}_{rng_name}").as_str(), - |b| { - let mut rng = Rng::seed_from_u64(123); - b.iter(|| choose_unhinted(length, &mut rng)) - }, - ); - - c.bench_function( - format!("choose_windowed_from_{length}_{rng_name}").as_str(), - |b| { - let mut rng = Rng::seed_from_u64(123); - b.iter(|| choose_windowed(length, 7, &mut rng)) - }, - ); - } -} - -fn choose_size_hinted(max: usize, rng: &mut R) -> Option { - let iterator = 0..max; - iterator.choose(rng) -} - -fn choose_stable(max: usize, rng: &mut R) -> Option { - let iterator = 0..max; - iterator.choose_stable(rng) -} - -fn choose_unhinted(max: usize, rng: &mut R) -> Option { - let iterator = UnhintedIterator { iter: (0..max) }; - iterator.choose(rng) -} - -fn choose_windowed(max: usize, window_size: usize, rng: &mut R) -> Option { - let iterator = WindowHintedIterator { - iter: (0..max), - window_size, - }; - iterator.choose(rng) -} - -#[derive(Clone)] -struct UnhintedIterator { - iter: I, -} -impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } -} - -#[derive(Clone)] -struct WindowHintedIterator { - iter: I, - window_size: usize, -} -impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - (core::cmp::min(self.iter.len(), self.window_size), None) - } -} diff --git a/benches/uniform.rs b/benches/uniform.rs deleted file mode 100644 index 0ed0f2cde40..00000000000 --- a/benches/uniform.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Implement benchmarks for uniform distributions over integer types - -use core::time::Duration; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use rand::distributions::uniform::{SampleRange, Uniform}; -use rand::prelude::*; -use rand_chacha::ChaCha8Rng; -use rand_pcg::{Pcg32, Pcg64}; - -const WARM_UP_TIME: Duration = Duration::from_millis(1000); -const MEASUREMENT_TIME: Duration = Duration::from_secs(3); -const SAMPLE_SIZE: usize = 100_000; -const N_RESAMPLES: usize = 10_000; - -macro_rules! sample { - ($R:ty, $T:ty, $U:ty, $g:expr) => { - $g.bench_function(BenchmarkId::new(stringify!($R), "single"), |b| { - let mut rng = <$R>::from_rng(thread_rng()).unwrap(); - let x = rng.gen::<$U>(); - let bits = (<$T>::BITS / 2); - let mask = (1 as $U).wrapping_neg() >> bits; - let range = (x >> bits) * (x & mask); - let low = <$T>::MIN; - let high = low.wrapping_add(range as $T); - - b.iter(|| (low..=high).sample_single(&mut rng)); - }); - - $g.bench_function(BenchmarkId::new(stringify!($R), "distr"), |b| { - let mut rng = <$R>::from_rng(thread_rng()).unwrap(); - let x = rng.gen::<$U>(); - let bits = (<$T>::BITS / 2); - let mask = (1 as $U).wrapping_neg() >> bits; - let range = (x >> bits) * (x & mask); - let low = <$T>::MIN; - let high = low.wrapping_add(range as $T); - let dist = Uniform::<$T>::new_inclusive(<$T>::MIN, high).unwrap(); - - b.iter(|| dist.sample(&mut rng)); - }); - }; - - ($c:expr, $T:ty, $U:ty) => {{ - let mut g = $c.benchmark_group(concat!("sample", stringify!($T))); - g.sample_size(SAMPLE_SIZE); - g.warm_up_time(WARM_UP_TIME); - g.measurement_time(MEASUREMENT_TIME); - g.nresamples(N_RESAMPLES); - sample!(SmallRng, $T, $U, g); - sample!(ChaCha8Rng, $T, $U, g); - sample!(Pcg32, $T, $U, g); - sample!(Pcg64, $T, $U, g); - g.finish(); - }}; -} - -fn sample(c: &mut Criterion) { - sample!(c, i8, u8); - sample!(c, i16, u16); - sample!(c, i32, u32); - sample!(c, i64, u64); - sample!(c, i128, u128); -} - -criterion_group! { - name = benches; - config = Criterion::default(); - targets = sample -} -criterion_main!(benches); diff --git a/benches/weighted.rs b/benches/weighted.rs deleted file mode 100644 index 68722908a9e..00000000000 --- a/benches/weighted.rs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] - -extern crate test; - -use rand::distributions::WeightedIndex; -use rand::Rng; -use test::Bencher; - -#[bench] -fn weighted_index_creation(b: &mut Bencher) { - let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; - b.iter(|| { - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - rng.sample(distr) - }) -} - -#[bench] -fn weighted_index_modification(b: &mut Bencher) { - let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - b.iter(|| { - distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); - rng.sample(&distr) - }) -} diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000000..14793c52048 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,2 @@ +# Don't warn about these identifiers when using clippy::doc_markdown. +doc-valid-idents = ["ChaCha", "ChaCha12", "SplitMix64", "ZiB", ".."] diff --git a/examples/monte-carlo.rs b/examples/monte-carlo.rs index a72cc1e9f47..d5b898f17f0 100644 --- a/examples/monte-carlo.rs +++ b/examples/monte-carlo.rs @@ -23,14 +23,11 @@ //! We can use the above fact to estimate the value of π: pick many points in //! the square at random, calculate the fraction that fall within the circle, //! and multiply this fraction by 4. - -#![cfg(all(feature = "std", feature = "std_rng"))] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; fn main() { let range = Uniform::new(-1.0f64, 1.0).unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let total = 1_000_000; let mut in_circle = 0; diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index 7499193bcea..0a6d033739c 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -26,9 +26,7 @@ //! //! [Monty Hall Problem]: https://en.wikipedia.org/wiki/Monty_Hall_problem -#![cfg(all(feature = "std", feature = "std_rng"))] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand::Rng; struct SimulationResult { @@ -47,7 +45,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) -> SimulationResult let open = game_host_open(car, choice, rng); // Shall we switch? - let switch = rng.gen(); + let switch = rng.random(); if switch { choice = switch_door(choice, open); } @@ -79,7 +77,7 @@ fn main() { // The estimation will be more accurate with more simulations let num_simulations = 10000; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let random_door = Uniform::new(0u32, 3).unwrap(); let (mut switch_wins, mut switch_losses) = (0, 0); diff --git a/examples/rayon-monte-carlo.rs b/examples/rayon-monte-carlo.rs index 7e703c01d2d..31d8e681067 100644 --- a/examples/rayon-monte-carlo.rs +++ b/examples/rayon-monte-carlo.rs @@ -38,9 +38,7 @@ //! over BATCH_SIZE trials. Manually batching also turns out to be faster //! for the nondeterministic version of this program as well. -#![cfg(all(feature = "std", feature = "std_rng"))] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; use rayon::prelude::*; diff --git a/rand_chacha/CHANGELOG.md b/rand_chacha/CHANGELOG.md index dcc9d2e688c..7965cf7640e 100644 --- a/rand_chacha/CHANGELOG.md +++ b/rand_chacha/CHANGELOG.md @@ -4,11 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.9.0-alpha.0] - 2024-02-18 -This is a pre-release. To depend on this version, use `rand_chacha = "=0.9.0-alpha.0"` to prevent automatic updates (which can be expected to include breaking changes). - -- Made `rand_chacha` propagate the `std` feature down to `rand_core` (#1153) +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Update to `rand_core` v0.9.0 (#1558) +- Feature `std` now implies feature `rand_core/std` (#1153) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### Other changes - Remove usage of `unsafe` in `fn generate` (#1181) then optimise for AVX2 (~4-7%) (#1192) +- Revise crate docs (#1454) ## [0.3.1] - 2021-06-09 - add getters corresponding to existing setters: `get_seed`, `get_stream` (#1124) diff --git a/rand_chacha/Cargo.toml b/rand_chacha/Cargo.toml index bcc09f61cc0..e2f313d2e8e 100644 --- a/rand_chacha/Cargo.toml +++ b/rand_chacha/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_chacha" -version = "0.9.0-alpha.0" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers", "The CryptoCorrosion Contributors"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -13,22 +13,24 @@ ChaCha random number generator keywords = ["random", "rng", "chacha"] categories = ["algorithms", "no-std"] edition = "2021" -rust-version = "1.60" +rust-version = "1.63" [package.metadata.docs.rs] +all-features = true rustdoc-args = ["--generate-link-to-definition"] [dependencies] -rand_core = { path = "../rand_core", version = "=0.9.0-alpha.0" } +rand_core = { path = "../rand_core", version = "0.9.0" } ppv-lite86 = { version = "0.2.14", default-features = false, features = ["simd"] } serde = { version = "1.0", features = ["derive"], optional = true } [dev-dependencies] -# Only to test serde1 -serde_json = "1.0" +# Only to test serde +serde_json = "1.0.120" +rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } [features] default = ["std"] +os_rng = ["rand_core/os_rng"] std = ["ppv-lite86/std", "rand_core/std"] -simd = [] # deprecated -serde1 = ["serde"] +serde = ["dep:serde"] diff --git a/rand_chacha/README.md b/rand_chacha/README.md index 0fd1b64c0d4..167417f85c8 100644 --- a/rand_chacha/README.md +++ b/rand_chacha/README.md @@ -1,11 +1,10 @@ # rand_chacha -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_chacha.svg)](https://crates.io/crates/rand_chacha) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_chacha) [![API](https://docs.rs/rand_chacha/badge.svg)](https://docs.rs/rand_chacha) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.60+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) A cryptographically secure random number generator that uses the ChaCha algorithm. @@ -37,7 +36,7 @@ Links: `rand_chacha` is `no_std` compatible when disabling default features; the `std` feature can be explicitly required to re-enable `std` support. Using `std` allows detection of CPU features and thus better optimisation. Using `std` -also enables `getrandom` functionality, such as `ChaCha20Rng::from_entropy()`. +also enables `os_rng` functionality, such as `ChaCha20Rng::from_os_rng()`. # License diff --git a/rand_chacha/src/chacha.rs b/rand_chacha/src/chacha.rs index ebc28a8ab04..91d3cd628d2 100644 --- a/rand_chacha/src/chacha.rs +++ b/rand_chacha/src/chacha.rs @@ -8,15 +8,13 @@ //! The ChaCha random number generator. -#[cfg(not(feature = "std"))] use core; -#[cfg(feature = "std")] use std as core; - -use self::core::fmt; use crate::guts::ChaCha; +use core::fmt; use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Serialize, Deserialize, Serializer, Deserializer}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; // NB. this must remain consistent with some currently hard-coded numbers in this module const BUF_BLOCKS: u8 = 4; @@ -26,7 +24,8 @@ const BLOCK_WORDS: u8 = 16; #[repr(transparent)] pub struct Array64([T; 64]); impl Default for Array64 -where T: Default +where + T: Default, { #[rustfmt::skip] fn default() -> Self { @@ -53,7 +52,8 @@ impl AsMut<[T]> for Array64 { } } impl Clone for Array64 -where T: Copy + Default +where + T: Copy + Default, { fn clone(&self) -> Self { let mut new = Self::default(); @@ -68,7 +68,7 @@ impl fmt::Debug for Array64 { } macro_rules! chacha_impl { - ($ChaChaXCore:ident, $ChaChaXRng:ident, $rounds:expr, $doc:expr, $abst:ident) => { + ($ChaChaXCore:ident, $ChaChaXRng:ident, $rounds:expr, $doc:expr, $abst:ident,) => { #[doc=$doc] #[derive(Clone, PartialEq, Eq)] pub struct $ChaChaXCore { @@ -85,6 +85,7 @@ macro_rules! chacha_impl { impl BlockRngCore for $ChaChaXCore { type Item = u32; type Results = Array64; + #[inline] fn generate(&mut self, r: &mut Self::Results) { self.state.refill4($rounds, &mut r.0); @@ -93,9 +94,12 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXCore { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { - $ChaChaXCore { state: ChaCha::new(&seed, &[0u8; 8]) } + $ChaChaXCore { + state: ChaCha::new(&seed, &[0u8; 8]), + } } } @@ -146,6 +150,7 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXRng { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { let core = $ChaChaXCore::from_seed(seed); @@ -160,18 +165,16 @@ macro_rules! chacha_impl { fn next_u32(&mut self) -> u32 { self.rng.next_u32() } + #[inline] fn next_u64(&mut self) -> u64 { self.rng.next_u64() } + #[inline] fn fill_bytes(&mut self, bytes: &mut [u8]) { self.rng.fill_bytes(bytes) } - #[inline] - fn try_fill_bytes(&mut self, bytes: &mut [u8]) -> Result<(), Error> { - self.rng.try_fill_bytes(bytes) - } } impl $ChaChaXRng { @@ -209,11 +212,9 @@ macro_rules! chacha_impl { #[inline] pub fn set_word_pos(&mut self, word_offset: u128) { let block = (word_offset / u128::from(BLOCK_WORDS)) as u64; + self.rng.core.state.set_block_pos(block); self.rng - .core - .state - .set_block_pos(block); - self.rng.generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); + .generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); } /// Set the stream number. @@ -229,10 +230,7 @@ macro_rules! chacha_impl { /// indirectly via `set_word_pos`), but this is not directly supported. #[inline] pub fn set_stream(&mut self, stream: u64) { - self.rng - .core - .state - .set_nonce(stream); + self.rng.core.state.set_nonce(stream); if self.rng.index() != 64 { let wp = self.get_word_pos(); self.set_word_pos(wp); @@ -242,19 +240,13 @@ macro_rules! chacha_impl { /// Get the stream number. #[inline] pub fn get_stream(&self) -> u64 { - self.rng - .core - .state - .get_nonce() + self.rng.core.state.get_nonce() } /// Get the seed. #[inline] pub fn get_seed(&self) -> [u8; 32] { - self.rng - .core - .state - .get_seed() + self.rng.core.state.get_seed() } } @@ -277,31 +269,34 @@ macro_rules! chacha_impl { } impl Eq for $ChaChaXRng {} - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] impl Serialize for $ChaChaXRng { fn serialize(&self, s: S) -> Result - where S: Serializer { + where + S: Serializer, + { $abst::$ChaChaXRng::from(self).serialize(s) } } - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] impl<'de> Deserialize<'de> for $ChaChaXRng { - fn deserialize(d: D) -> Result where D: Deserializer<'de> { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { $abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x)) } } mod $abst { - #[cfg(feature = "serde1")] use serde::{Serialize, Deserialize}; + #[cfg(feature = "serde")] + use serde::{Deserialize, Serialize}; // The abstract state of a ChaCha stream, independent of implementation choices. The // comparison and serialization of this object is considered a semver-covered part of // the API. #[derive(Debug, PartialEq, Eq)] - #[cfg_attr( - feature = "serde1", - derive(Serialize, Deserialize), - )] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub(crate) struct $ChaChaXRng { seed: [u8; 32], stream: u64, @@ -331,27 +326,46 @@ macro_rules! chacha_impl { } } } - } + }; } -chacha_impl!(ChaCha20Core, ChaCha20Rng, 10, "ChaCha with 20 rounds", abstract20); -chacha_impl!(ChaCha12Core, ChaCha12Rng, 6, "ChaCha with 12 rounds", abstract12); -chacha_impl!(ChaCha8Core, ChaCha8Rng, 4, "ChaCha with 8 rounds", abstract8); +chacha_impl!( + ChaCha20Core, + ChaCha20Rng, + 10, + "ChaCha with 20 rounds", + abstract20, +); +chacha_impl!( + ChaCha12Core, + ChaCha12Rng, + 6, + "ChaCha with 12 rounds", + abstract12, +); +chacha_impl!( + ChaCha8Core, + ChaCha8Rng, + 4, + "ChaCha with 8 rounds", + abstract8, +); #[cfg(test)] mod test { use rand_core::{RngCore, SeedableRng}; - #[cfg(feature = "serde1")] use super::{ChaCha20Rng, ChaCha12Rng, ChaCha8Rng}; + #[cfg(feature = "serde")] + use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng}; type ChaChaRng = super::ChaCha20Rng; - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] #[test] fn test_chacha_serde_roundtrip() { let seed = [ - 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, 0, 0, - 0, 2, 92, + 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, + 0, 0, 0, 2, 92, ]; let mut rng1 = ChaCha20Rng::from_seed(seed); let mut rng2 = ChaCha12Rng::from_seed(seed); @@ -384,11 +398,11 @@ mod test { // However testing for equivalence of serialized data is difficult, and there shouldn't be any // reason we need to violate the stronger-than-needed condition, e.g. by changing the field // definition order. - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] #[test] fn test_chacha_serde_format_stability() { let j = r#"{"seed":[4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8],"stream":27182818284,"word_pos":314159265359}"#; - let r: ChaChaRng = serde_json::from_str(&j).unwrap(); + let r: ChaChaRng = serde_json::from_str(j).unwrap(); let j1 = serde_json::to_string(&r).unwrap(); assert_eq!(j, j1); } @@ -402,7 +416,7 @@ mod test { let mut rng1 = ChaChaRng::from_seed(seed); assert_eq!(rng1.next_u32(), 137206642); - let mut rng2 = ChaChaRng::from_rng(rng1).unwrap(); + let mut rng2 = ChaChaRng::from_rng(&mut rng1); assert_eq!(rng2.next_u32(), 1325750369); } @@ -598,7 +612,7 @@ mod test { #[test] fn test_chacha_word_pos_wrap_exact() { - use super::{BUF_BLOCKS, BLOCK_WORDS}; + use super::{BLOCK_WORDS, BUF_BLOCKS}; let mut rng = ChaChaRng::from_seed(Default::default()); // refilling the buffer in set_word_pos will wrap the block counter to 0 let last_block = (1 << 68) - u128::from(BUF_BLOCKS * BLOCK_WORDS); diff --git a/rand_chacha/src/guts.rs b/rand_chacha/src/guts.rs index 797ded6fa73..d077225c625 100644 --- a/rand_chacha/src/guts.rs +++ b/rand_chacha/src/guts.rs @@ -12,7 +12,9 @@ use ppv_lite86::{dispatch, dispatch_light128}; pub use ppv_lite86::Machine; -use ppv_lite86::{vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector}; +use ppv_lite86::{ + vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector, +}; pub(crate) const BLOCK: usize = 16; pub(crate) const BLOCK64: u64 = BLOCK as u64; @@ -140,14 +142,18 @@ fn add_pos(m: Mach, d: Mach::u32x4, i: u64) -> Mach::u32x4 { #[cfg(target_endian = "little")] fn d0123(m: Mach, d: vec128_storage) -> Mach::u32x4x4 { let d0: Mach::u64x2 = m.unpack(d); - let incr = Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); + let incr = + Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into()) } #[allow(clippy::many_single_char_names)] #[inline(always)] fn refill_wide_impl( - m: Mach, state: &mut ChaCha, drounds: u32, out: &mut [u32; BUFSZ], + m: Mach, + state: &mut ChaCha, + drounds: u32, + out: &mut [u32; BUFSZ], ) { let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]); let b = m.unpack(state.b); diff --git a/rand_chacha/src/lib.rs b/rand_chacha/src/lib.rs index f4b526b8f64..24ddd601d27 100644 --- a/rand_chacha/src/lib.rs +++ b/rand_chacha/src/lib.rs @@ -6,7 +6,79 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The ChaCha random number generator. +//! The ChaCha random number generators. +//! +//! These are native Rust implementations of RNGs derived from the +//! [ChaCha stream ciphers] by D J Bernstein. +//! +//! ## Generators +//! +//! This crate provides 8-, 12- and 20-round variants of generators via a "core" +//! implementation (of [`BlockRngCore`]), each with an associated "RNG" type +//! (implementing [`RngCore`]). +//! +//! These generators are all deterministic and portable (see [Reproducibility] +//! in the book), with testing against reference vectors. +//! +//! ## Cryptographic (secure) usage +//! +//! Where secure unpredictable generators are required, it is suggested to use +//! [`ChaCha12Rng`] or [`ChaCha20Rng`] and to seed via +//! [`SeedableRng::from_os_rng`]. +//! +//! See also the [Security] chapter in the rand book. The crate is provided +//! "as is", without any form of guarantee, and without a security audit. +//! +//! ## Seeding (construction) +//! +//! Generators implement the [`SeedableRng`] trait. Any method may be used, +//! but note that `seed_from_u64` is not suitable for usage where security is +//! important. Some suggestions: +//! +//! 1. With a fresh seed, **direct from the OS** (implies a syscall): +//! ``` +//! # use {rand_core::SeedableRng, rand_chacha::ChaCha12Rng}; +//! let rng = ChaCha12Rng::from_os_rng(); +//! # let _: ChaCha12Rng = rng; +//! ``` +//! 2. **From a master generator.** This could be [`rand::rng`] +//! (effectively a fresh seed without the need for a syscall on each usage) +//! or a deterministic generator such as [`ChaCha20Rng`]. +//! Beware that should a weak master generator be used, correlations may be +//! detectable between the outputs of its child generators. +//! ```ignore +//! let rng = ChaCha12Rng::from_rng(&mut rand::rng()); +//! ``` +//! +//! See also [Seeding RNGs] in the book. +//! +//! ## Generation +//! +//! Generators implement [`RngCore`], whose methods may be used directly to +//! generate unbounded integer or byte values. +//! ``` +//! use rand_core::{SeedableRng, RngCore}; +//! use rand_chacha::ChaCha12Rng; +//! +//! let mut rng = ChaCha12Rng::from_seed(Default::default()); +//! let x = rng.next_u64(); +//! assert_eq!(x, 0x53f955076a9af49b); +//! ``` +//! +//! It is often more convenient to use the [`rand::Rng`] trait, which provides +//! further functionality. See also the [Random Values] chapter in the book. +//! +//! [ChaCha stream ciphers]: https://cr.yp.to/chacha.html +//! [Reproducibility]: https://rust-random.github.io/book/crate-reprod.html +//! [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +//! [Security]: https://rust-random.github.io/book/guide-rngs.html#security +//! [Random Values]: https://rust-random.github.io/book/guide-values.html +//! [`BlockRngCore`]: rand_core::block::BlockRngCore +//! [`RngCore`]: rand_core::RngCore +//! [`SeedableRng`]: rand_core::SeedableRng +//! [`SeedableRng::from_os_rng`]: rand_core::SeedableRng::from_os_rng +//! [`rand::rng`]: https://docs.rs/rand/latest/rand/fn.rng.html +//! [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html #![doc( html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", diff --git a/rand_core/CHANGELOG.md b/rand_core/CHANGELOG.md index b7f00895622..7318dffa878 100644 --- a/rand_core/CHANGELOG.md +++ b/rand_core/CHANGELOG.md @@ -4,14 +4,36 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.9.0-alpha.0] - 2024-02-18 -This is a pre-release. To depend on this version, use `rand_core = "=0.9.0-alpha.0"` to prevent automatic updates (which can be expected to include breaking changes). - -- Bump MSRV to 1.60.0 (#1207, #1246, #1269, #1341) +## [0.9.3] — 2025-02-29 +### Other +- Remove `zerocopy` dependency (#1607) +- Deprecate `rand_core::impls::fill_via_u32_chunks`, `fill_via_u64_chunks` (#1607) + +## [0.9.2] - 2025-02-22 +### API changes +- Relax `Sized` bound on impls of `TryRngCore`, `TryCryptoRng` and `UnwrapMut` (#1593) +- Add `UnwrapMut::re` to reborrow the inner rng with a tighter lifetime (#1595) + +## [0.9.1] - 2025-02-16 +### API changes +- Add `TryRngCore::unwrap_mut`, providing an impl of `RngCore` over `&mut rng` (#1589) + +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `getrandom` v0.3.0 (#1558) +- Use `zerocopy` to replace some `unsafe` code (#1349, #1393, #1446, #1502) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### API changes - Allow `rand_core::impls::fill_via_u*_chunks` to mutate source (#1182) -- Add `fn RngCore::read_adapter` implementing `std::io::Read` (#1267) -- Add `trait CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) -- Use `zerocopy` to replace some `unsafe` code (#1349, #1393) +- Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` replacing `CryptoRngCore` (#1273) +- Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) +- Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) +- Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) +- Add bounds `Clone` and `AsRef` to associated type `SeedableRng::Seed` (#1491) ## [0.6.4] - 2022-09-15 - Fix unsoundness in `::next_u32` (#1160) diff --git a/rand_core/Cargo.toml b/rand_core/Cargo.toml index 28845f3d661..899c359554c 100644 --- a/rand_core/Cargo.toml +++ b/rand_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_core" -version = "0.9.0-alpha.0" +version = "0.9.3" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -13,23 +13,22 @@ Core random number generator traits and tools for implementation. keywords = ["random", "rng"] categories = ["algorithms", "no-std"] edition = "2021" -rust-version = "1.60" +rust-version = "1.63" [package.metadata.docs.rs] # To build locally: -# RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --open +# RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --no-deps --open all-features = true -rustdoc-args = ["--cfg", "doc_cfg", "--generate-link-to-definition"] +rustdoc-args = ["--generate-link-to-definition"] [package.metadata.playground] all-features = true [features] -std = ["alloc", "getrandom?/std"] -alloc = [] # enables Vec and Box support without std -serde1 = ["serde"] # enables serde for BlockRng wrapper +std = ["getrandom?/std"] +os_rng = ["dep:getrandom"] +serde = ["dep:serde"] # enables serde for BlockRng wrapper [dependencies] serde = { version = "1", features = ["derive"], optional = true } -getrandom = { version = "0.2", optional = true } -zerocopy = { version = "=0.8.0-alpha.5", default-features = false } +getrandom = { version = "0.3.0", optional = true } diff --git a/rand_core/README.md b/rand_core/README.md index a08f7c99251..05d9fbf6cb0 100644 --- a/rand_core/README.md +++ b/rand_core/README.md @@ -1,11 +1,10 @@ # rand_core -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_core.svg)](https://crates.io/crates/rand_core) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_core) [![API](https://docs.rs/rand_core/badge.svg)](https://docs.rs/rand_core) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.60+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Core traits and error types of the [rand] library, plus tools for implementing RNGs. @@ -42,34 +41,10 @@ The traits and error types are also available via `rand`. ## Versions The current version is: -``` -rand_core = "0.6.4" -``` -Rand libs have inter-dependencies and make use of the -[semver trick](https://github.com/dtolnay/semver-trick/) in order to make traits -compatible across crate versions. (This is especially important for `RngCore` -and `SeedableRng`.) A few crate releases are thus compatibility shims, -depending on the *next* lib version (e.g. `rand_core` versions `0.2.2` and -`0.3.1`). This means, for example, that `rand_core_0_4_0::SeedableRng` and -`rand_core_0_3_0::SeedableRng` are distinct, incompatible traits, which can -cause build errors. Usually, running `cargo update` is enough to fix any issues. - -## Crate Features - -`rand_core` supports `no_std` and `alloc`-only configurations, as well as full -`std` functionality. The differences between `no_std` and full `std` are small, -comprising `RngCore` support for `Box` types where `R: RngCore`, -`std::io::Read` support for types supporting `RngCore`, and -extensions to the `Error` type's functionality. - -The `std` feature is *not enabled by default*. This is primarily to avoid build -problems where one crate implicitly requires `rand_core` with `std` support and -another crate requires `rand` *without* `std` support. However, the `rand` crate -continues to enable `std` support by default, both for itself and `rand_core`. - -The `serde1` feature can be used to derive `Serialize` and `Deserialize` for RNG -implementations that use the `BlockRng` or `BlockRng64` wrappers. +```toml +rand_core = "0.9.3" +``` # License diff --git a/rand_core/src/.le.rs.kate-swp b/rand_core/src/.le.rs.kate-swp new file mode 100644 index 00000000000..0debd30bbe9 Binary files /dev/null and b/rand_core/src/.le.rs.kate-swp differ diff --git a/rand_core/src/block.rs b/rand_core/src/block.rs index a8cefc8e40c..667cc0bca6a 100644 --- a/rand_core/src/block.rs +++ b/rand_core/src/block.rs @@ -53,11 +53,10 @@ //! [`BlockRngCore`]: crate::block::BlockRngCore //! [`fill_bytes`]: RngCore::fill_bytes -use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks}; -use crate::{Error, CryptoRng, RngCore, SeedableRng}; -use core::convert::AsRef; +use crate::impls::fill_via_chunks; +use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore}; use core::fmt; -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// A trait for RNGs which do not generate random numbers individually, but in @@ -81,7 +80,7 @@ pub trait BlockRngCore { /// supposed to be cryptographically secure. /// /// See [`CryptoRng`] docs for more information. -pub trait CryptoBlockRng: BlockRngCore { } +pub trait CryptoBlockRng: BlockRngCore {} /// A wrapper type implementing [`RngCore`] for some type implementing /// [`BlockRngCore`] with `u32` array buffer; i.e. this can be used to implement @@ -98,16 +97,15 @@ pub trait CryptoBlockRng: BlockRngCore { } /// `BlockRng` has heavily optimized implementations of the [`RngCore`] methods /// reading values from the results buffer, as well as /// calling [`BlockRngCore::generate`] directly on the output array when -/// [`fill_bytes`] / [`try_fill_bytes`] is called on a large array. These methods -/// also handle the bookkeeping of when to generate a new batch of values. +/// [`fill_bytes`] is called on a large array. These methods also handle +/// the bookkeeping of when to generate a new batch of values. /// /// No whole generated `u32` values are thrown away and all values are consumed /// in-order. [`next_u32`] simply takes the next available `u32` value. /// [`next_u64`] is implemented by combining two `u32` values, least -/// significant first. [`fill_bytes`] and [`try_fill_bytes`] consume a whole -/// number of `u32` values, converting each `u32` to a byte slice in -/// little-endian order. If the requested byte length is not a multiple of 4, -/// some bytes will be discarded. +/// significant first. [`fill_bytes`] consume a whole number of `u32` values, +/// converting each `u32` to a byte slice in little-endian order. If the requested byte +/// length is not a multiple of 4, some bytes will be discarded. /// /// See also [`BlockRng64`] which uses `u64` array buffers. Currently there is /// no direct support for other buffer types. @@ -117,16 +115,15 @@ pub trait CryptoBlockRng: BlockRngCore { } /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr( - feature = "serde1", + feature = "serde", serde( - bound = "for<'x> R: Serialize + Deserialize<'x> + Sized, for<'x> R::Results: Serialize + Deserialize<'x>" + bound = "for<'x> R: Serialize + Deserialize<'x>, for<'x> R::Results: Serialize + Deserialize<'x>" ) )] -pub struct BlockRng { +pub struct BlockRng { results: R::Results, index: usize, /// The *core* part of the RNG, implementing the `generate` function. @@ -200,7 +197,7 @@ impl> RngCore for BlockRng { fn next_u64(&mut self) -> u64 { let read_u64 = |results: &[u32], index| { let data = &results[index..=index + 1]; - u64::from(data[1]) << 32 | u64::from(data[0]) + (u64::from(data[1]) << 32) | u64::from(data[0]) }; let len = self.results.as_ref().len(); @@ -229,18 +226,12 @@ impl> RngCore for BlockRng { self.generate_and_set(0); } let (consumed_u32, filled_u8) = - fill_via_u32_chunks(&mut self.results.as_mut()[self.index..], &mut dest[read_len..]); + fill_via_chunks(&self.results.as_mut()[self.index..], &mut dest[read_len..]); self.index += consumed_u32; read_len += filled_u8; } } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } impl SeedableRng for BlockRng { @@ -257,8 +248,13 @@ impl SeedableRng for BlockRng { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: &mut impl RngCore) -> Self { + Self::new(R::from_rng(rng)) + } + + #[inline(always)] + fn try_from_rng(rng: &mut S) -> Result { + R::try_from_rng(rng).map(Self::new) } } @@ -278,16 +274,14 @@ impl> CryptoRng for BlockRng {} /// then the other half is then consumed, however both [`next_u64`] and /// [`fill_bytes`] discard the rest of any half-consumed `u64`s when called. /// -/// [`fill_bytes`] and [`try_fill_bytes`] consume a whole number of `u64` -/// values. If the requested length is not a multiple of 8, some bytes will be -/// discarded. +/// [`fill_bytes`] consumes a whole number of `u64` values. If the requested length +/// is not a multiple of 8, some bytes will be discarded. /// /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct BlockRng64 { results: R::Results, index: usize, @@ -394,21 +388,13 @@ impl> RngCore for BlockRng64 { self.index = 0; } - let (consumed_u64, filled_u8) = fill_via_u64_chunks( - &mut self.results.as_mut()[self.index..], - &mut dest[read_len..], - ); + let (consumed_u64, filled_u8) = + fill_via_chunks(&self.results.as_mut()[self.index..], &mut dest[read_len..]); self.index += consumed_u64; read_len += filled_u8; } } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } impl SeedableRng for BlockRng64 { @@ -425,8 +411,13 @@ impl SeedableRng for BlockRng64 { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: &mut impl RngCore) -> Self { + Self::new(R::from_rng(rng)) + } + + #[inline(always)] + fn try_from_rng(rng: &mut S) -> Result { + R::try_from_rng(rng).map(Self::new) } } @@ -434,8 +425,8 @@ impl> CryptoRng for BlockRng64 { #[cfg(test)] mod test { - use crate::{SeedableRng, RngCore}; use crate::block::{BlockRng, BlockRng64, BlockRngCore}; + use crate::{RngCore, SeedableRng}; #[derive(Debug, Clone)] struct DummyRng { @@ -444,7 +435,6 @@ mod test { impl BlockRngCore for DummyRng { type Item = u32; - type Results = [u32; 16]; fn generate(&mut self, results: &mut Self::Results) { @@ -459,7 +449,9 @@ mod test { type Seed = [u8; 4]; fn from_seed(seed: Self::Seed) -> Self { - DummyRng { counter: u32::from_le_bytes(seed) } + DummyRng { + counter: u32::from_le_bytes(seed), + } } } @@ -494,7 +486,6 @@ mod test { impl BlockRngCore for DummyRng64 { type Item = u64; - type Results = [u64; 8]; fn generate(&mut self, results: &mut Self::Results) { @@ -509,7 +500,9 @@ mod test { type Seed = [u8; 8]; fn from_seed(seed: Self::Seed) -> Self { - DummyRng64 { counter: u64::from_le_bytes(seed) } + DummyRng64 { + counter: u64::from_le_bytes(seed), + } } } diff --git a/rand_core/src/error.rs b/rand_core/src/error.rs deleted file mode 100644 index 1a5092fe82b..00000000000 --- a/rand_core/src/error.rs +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Error types - -use core::fmt; -use core::num::NonZeroU32; - -#[cfg(feature = "std")] use std::boxed::Box; - -/// Error type of random number generators -/// -/// In order to be compatible with `std` and `no_std`, this type has two -/// possible implementations: with `std` a boxed `Error` trait object is stored, -/// while with `no_std` we merely store an error code. -pub struct Error { - #[cfg(feature = "std")] - inner: Box, - #[cfg(not(feature = "std"))] - code: NonZeroU32, -} - -impl Error { - /// Codes at or above this point can be used by users to define their own - /// custom errors. - /// - /// This has a fixed value of `(1 << 31) + (1 << 30) = 0xC000_0000`, - /// therefore the number of values available for custom codes is `1 << 30`. - /// - /// This is identical to [`getrandom::Error::CUSTOM_START`](https://docs.rs/getrandom/latest/getrandom/struct.Error.html#associatedconstant.CUSTOM_START). - pub const CUSTOM_START: u32 = (1 << 31) + (1 << 30); - /// Codes below this point represent OS Errors (i.e. positive i32 values). - /// Codes at or above this point, but below [`Error::CUSTOM_START`] are - /// reserved for use by the `rand` and `getrandom` crates. - /// - /// This is identical to [`getrandom::Error::INTERNAL_START`](https://docs.rs/getrandom/latest/getrandom/struct.Error.html#associatedconstant.INTERNAL_START). - pub const INTERNAL_START: u32 = 1 << 31; - - /// Construct from any type supporting `std::error::Error` - /// - /// Available only when configured with `std`. - /// - /// See also `From`, which is available with and without `std`. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn new(err: E) -> Self - where E: Into> { - Error { inner: err.into() } - } - - /// Reference the inner error (`std` only) - /// - /// When configured with `std`, this is a trivial operation and never - /// panics. Without `std`, this method is simply unavailable. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn inner(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { - &*self.inner - } - - /// Unwrap the inner error (`std` only) - /// - /// When configured with `std`, this is a trivial operation and never - /// panics. Without `std`, this method is simply unavailable. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn take_inner(self) -> Box { - self.inner - } - - /// Extract the raw OS error code (if this error came from the OS) - /// - /// This method is identical to `std::io::Error::raw_os_error()`, except - /// that it works in `no_std` contexts. If this method returns `None`, the - /// error value can still be formatted via the `Display` implementation. - #[inline] - pub fn raw_os_error(&self) -> Option { - #[cfg(feature = "std")] - { - if let Some(e) = self.inner.downcast_ref::() { - return e.raw_os_error(); - } - } - match self.code() { - Some(code) if u32::from(code) < Self::INTERNAL_START => Some(u32::from(code) as i32), - _ => None, - } - } - - /// Retrieve the error code, if any. - /// - /// If this `Error` was constructed via `From`, then this method - /// will return this `NonZeroU32` code (for `no_std` this is always the - /// case). Otherwise, this method will return `None`. - #[inline] - pub fn code(&self) -> Option { - #[cfg(feature = "std")] - { - self.inner.downcast_ref::().map(|c| c.0) - } - #[cfg(not(feature = "std"))] - { - Some(self.code) - } - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "std")] - { - write!(f, "Error {{ inner: {:?} }}", self.inner) - } - #[cfg(all(feature = "getrandom", not(feature = "std")))] - { - getrandom::Error::from(self.code).fmt(f) - } - #[cfg(not(any(feature = "getrandom", feature = "std")))] - { - write!(f, "Error {{ code: {} }}", self.code) - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "std")] - { - write!(f, "{}", self.inner) - } - #[cfg(all(feature = "getrandom", not(feature = "std")))] - { - getrandom::Error::from(self.code).fmt(f) - } - #[cfg(not(any(feature = "getrandom", feature = "std")))] - { - write!(f, "error code {}", self.code) - } - } -} - -impl From for Error { - #[inline] - fn from(code: NonZeroU32) -> Self { - #[cfg(feature = "std")] - { - Error { - inner: Box::new(ErrorCode(code)), - } - } - #[cfg(not(feature = "std"))] - { - Error { code } - } - } -} - -#[cfg(feature = "getrandom")] -impl From for Error { - #[inline] - fn from(error: getrandom::Error) -> Self { - #[cfg(feature = "std")] - { - Error { - inner: Box::new(error), - } - } - #[cfg(not(feature = "std"))] - { - Error { code: error.code() } - } - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error { - #[inline] - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.inner.source() - } -} - -#[cfg(feature = "std")] -impl From for std::io::Error { - #[inline] - fn from(error: Error) -> Self { - if let Some(code) = error.raw_os_error() { - std::io::Error::from_raw_os_error(code) - } else { - std::io::Error::new(std::io::ErrorKind::Other, error) - } - } -} - -#[cfg(feature = "std")] -#[derive(Debug, Copy, Clone)] -struct ErrorCode(NonZeroU32); - -#[cfg(feature = "std")] -impl fmt::Display for ErrorCode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "error code {}", self.0) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for ErrorCode {} - -#[cfg(test)] -mod test { - #[cfg(feature = "getrandom")] - #[test] - fn test_error_codes() { - // Make sure the values are the same as in `getrandom`. - assert_eq!(super::Error::CUSTOM_START, getrandom::Error::CUSTOM_START); - assert_eq!(super::Error::INTERNAL_START, getrandom::Error::INTERNAL_START); - } -} diff --git a/rand_core/src/impls.rs b/rand_core/src/impls.rs index a8fc1a7e8c6..4cced0d6f0b 100644 --- a/rand_core/src/impls.rs +++ b/rand_core/src/impls.rs @@ -18,8 +18,6 @@ //! non-reproducible sources (e.g. `OsRng`) need not bother with it. use crate::RngCore; -use core::cmp::min; -use zerocopy::{IntoBytes, NoCell}; /// Implement `next_u64` via `next_u32`, little-endian order. pub fn next_u64_via_u32(rng: &mut R) -> u64 { @@ -53,41 +51,52 @@ pub fn fill_bytes_via_next(rng: &mut R, dest: &mut [u8]) { } } -trait Observable: IntoBytes + NoCell + Copy { - fn to_le(self) -> Self; +pub(crate) trait Observable: Copy { + type Bytes: Sized + AsRef<[u8]>; + fn to_le_bytes(self) -> Self::Bytes; } impl Observable for u32 { - fn to_le(self) -> Self { - self.to_le() + type Bytes = [u8; 4]; + + fn to_le_bytes(self) -> Self::Bytes { + Self::to_le_bytes(self) } } impl Observable for u64 { - fn to_le(self) -> Self { - self.to_le() + type Bytes = [u8; 8]; + + fn to_le_bytes(self) -> Self::Bytes { + Self::to_le_bytes(self) } } /// Fill dest from src /// -/// Returns `(n, byte_len)`. `src[..n]` is consumed (and possibly mutated), +/// Returns `(n, byte_len)`. `src[..n]` is consumed, /// `dest[..byte_len]` is filled. `src[n..]` and `dest[byte_len..]` are left /// unaltered. -fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usize) { +pub(crate) fn fill_via_chunks(src: &[T], dest: &mut [u8]) -> (usize, usize) { let size = core::mem::size_of::(); - let byte_len = min(src.len() * size, dest.len()); - let num_chunks = (byte_len + size - 1) / size; - - // Byte-swap for portability of results. This must happen before copying - // since the size of dest is not guaranteed to be a multiple of T or to be - // sufficiently aligned. - if cfg!(target_endian = "big") { - for x in &mut src[..num_chunks] { - *x = x.to_le(); - } - } - dest[..byte_len].copy_from_slice(&<[T]>::as_bytes(&src[..num_chunks])[..byte_len]); + // Always use little endian for portability of results. + let mut dest = dest.chunks_exact_mut(size); + let mut src = src.iter(); + + let zipped = dest.by_ref().zip(src.by_ref()); + let num_chunks = zipped.len(); + zipped.for_each(|(dest, src)| dest.copy_from_slice(src.to_le_bytes().as_ref())); + + let byte_len = num_chunks * size; + if let Some(src) = src.next() { + // We have consumed all full chunks of dest, but not src. + let dest = dest.into_remainder(); + let n = dest.len(); + if n > 0 { + dest.copy_from_slice(&src.to_le_bytes().as_ref()[..n]); + return (num_chunks + 1, byte_len + n); + } + } (num_chunks, byte_len) } @@ -96,8 +105,8 @@ fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usi /// /// The return values are `(consumed_u32, filled_u8)`. /// -/// On big-endian systems, endianness of `src[..consumed_u32]` values is -/// swapped. No other adjustments to `src` are made. +/// `src` is not modified; it is taken as a `&mut` reference for backward +/// compatibility with previous versions that did change it. /// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. @@ -124,6 +133,7 @@ fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usi /// } /// } /// ``` +#[deprecated(since = "0.9.3", note = "use BlockRng instead")] pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { fill_via_chunks(src, dest) } @@ -133,8 +143,8 @@ pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { /// /// The return values are `(consumed_u64, filled_u8)`. /// -/// On big-endian systems, endianness of `src[..consumed_u64]` values is -/// swapped. No other adjustments to `src` are made. +/// `src` is not modified; it is taken as a `&mut` reference for backward +/// compatibility with previous versions that did change it. /// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. @@ -142,6 +152,7 @@ pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { /// as `filled_u8 / 8` rounded up. /// /// See `fill_via_u32_chunks` for an example. +#[deprecated(since = "0.9.3", note = "use BlockRng64 instead")] pub fn fill_via_u64_chunks(src: &mut [u64], dest: &mut [u8]) -> (usize, usize) { fill_via_chunks(src, dest) } @@ -166,41 +177,41 @@ mod test { #[test] fn test_fill_via_u32_chunks() { - let src_orig = [1, 2, 3]; + let src_orig = [1u32, 2, 3]; let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 11)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (3, 11)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 13]; - assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 12)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (3, 12)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (2, 5)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (2, 5)); assert_eq!(dst, [1, 0, 0, 0, 2]); } #[test] fn test_fill_via_u64_chunks() { - let src_orig = [1, 2]; + let src_orig = [1u64, 2]; let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 11)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (2, 11)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 17]; - assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 16)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (2, 16)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (1, 5)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (1, 5)); assert_eq!(dst, [1, 0, 0, 0, 0]); } } diff --git a/rand_core/src/le.rs b/rand_core/src/le.rs index ed42e57f478..6c4d7c82ad0 100644 --- a/rand_core/src/le.rs +++ b/rand_core/src/le.rs @@ -11,10 +11,16 @@ //! Little-Endian order has been chosen for internal usage; this makes some //! useful functions available. -use core::convert::TryInto; - -/// Reads unsigned 32 bit integers from `src` into `dst`. +/// Fills `dst: &mut [u32]` from `src` +/// +/// Reads use Little-Endian byte order, allowing portable reproduction of `dst` +/// from a byte slice. +/// +/// # Panics +/// +/// If `src` has insufficient length (if `src.len() < 4*dst.len()`). #[inline] +#[track_caller] pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { assert!(src.len() >= 4 * dst.len()); for (out, chunk) in dst.iter_mut().zip(src.chunks_exact(4)) { @@ -22,8 +28,13 @@ pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { } } -/// Reads unsigned 64 bit integers from `src` into `dst`. +/// Fills `dst: &mut [u64]` from `src` +/// +/// # Panics +/// +/// If `src` has insufficient length (if `src.len() < 8*dst.len()`). #[inline] +#[track_caller] pub fn read_u64_into(src: &[u8], dst: &mut [u64]) { assert!(src.len() >= 8 * dst.len()); for (out, chunk) in dst.iter_mut().zip(src.chunks_exact(8)) { diff --git a/rand_core/src/lib.rs b/rand_core/src/lib.rs index 292c57ffb8b..6c007797806 100644 --- a/rand_core/src/lib.rs +++ b/rand_core/src/lib.rs @@ -19,9 +19,6 @@ //! [`SeedableRng`] is an extension trait for construction from fixed seeds and //! other random number generators. //! -//! [`Error`] is provided for error-handling. It is safe to use in `no_std` -//! environments. -//! //! The [`impls`] and [`le`] sub-modules include a few small functions to assist //! implementation of [`RngCore`]. //! @@ -34,33 +31,30 @@ )] #![deny(missing_docs)] #![deny(missing_debug_implementations)] +#![deny(clippy::undocumented_unsafe_blocks)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![no_std] -use core::convert::AsMut; -use core::default::Default; - -#[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "std")] extern crate std; -#[cfg(feature = "alloc")] use alloc::boxed::Box; - -pub use error::Error; -#[cfg(feature = "getrandom")] pub use os::OsRng; +#[cfg(feature = "std")] +extern crate std; +use core::{fmt, ops::DerefMut}; pub mod block; -mod error; pub mod impls; pub mod le; -#[cfg(feature = "getrandom")] mod os; +#[cfg(feature = "os_rng")] +mod os; +#[cfg(feature = "os_rng")] +pub use os::{OsError, OsRng}; -/// The core of a random number generator. +/// Implementation-level interface for RNGs /// /// This trait encapsulates the low-level functionality common to all /// generators, and is the "back end", to be implemented by generators. -/// End users should normally use the `Rng` trait from the [`rand`] crate, +/// End users should normally use the [`rand::Rng`] trait /// which is automatically implemented for every type implementing `RngCore`. /// /// Three different methods for generating random data are provided since the @@ -71,11 +65,6 @@ pub mod le; /// [`next_u32`] and [`next_u64`] methods, implementations may discard some /// random bits for efficiency. /// -/// The [`try_fill_bytes`] method is a variant of [`fill_bytes`] allowing error -/// handling; it is not deemed sufficiently useful to add equivalents for -/// [`next_u32`] or [`next_u64`] since the latter methods are almost always used -/// with algorithmic generators (PRNGs), which are normally infallible. -/// /// Implementers should produce bits uniformly. Pathological RNGs (e.g. always /// returning the same value, or never setting certain bits) can break rejection /// sampling used by random distributions, and also break other RNGs when @@ -90,6 +79,10 @@ pub mod le; /// in this trait directly, then use the helper functions from the /// [`impls`] module to implement the other methods. /// +/// Note that implementors of [`RngCore`] also automatically implement +/// the [`TryRngCore`] trait with the `Error` associated type being +/// equal to [`Infallible`]. +/// /// It is recommended that implementations also implement: /// /// - `Debug` with a custom implementation which *does not* print any internal @@ -110,7 +103,7 @@ pub mod le; /// /// ``` /// #![allow(dead_code)] -/// use rand_core::{RngCore, Error, impls}; +/// use rand_core::{RngCore, impls}; /// /// struct CountingRng(u64); /// @@ -124,21 +117,17 @@ pub mod le; /// self.0 /// } /// -/// fn fill_bytes(&mut self, dest: &mut [u8]) { -/// impls::fill_bytes_via_next(self, dest) -/// } -/// -/// fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { -/// Ok(self.fill_bytes(dest)) +/// fn fill_bytes(&mut self, dst: &mut [u8]) { +/// impls::fill_bytes_via_next(self, dst) /// } /// } /// ``` /// -/// [`rand`]: https://docs.rs/rand -/// [`try_fill_bytes`]: RngCore::try_fill_bytes +/// [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html /// [`fill_bytes`]: RngCore::fill_bytes /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 +/// [`Infallible`]: core::convert::Infallible pub trait RngCore { /// Return the next random `u32`. /// @@ -158,68 +147,227 @@ pub trait RngCore { /// /// RNGs must implement at least one method from this trait directly. In /// the case this method is not implemented directly, it can be implemented - /// via [`impls::fill_bytes_via_next`] or - /// via [`RngCore::try_fill_bytes`]; if this generator can - /// fail the implementation must choose how best to handle errors here - /// (e.g. panic with a descriptive message or log a warning and retry a few - /// times). + /// via [`impls::fill_bytes_via_next`]. /// /// This method should guarantee that `dest` is entirely filled /// with new data, and may panic if this is impossible /// (e.g. reading past the end of a file that is being used as the /// source of randomness). - fn fill_bytes(&mut self, dest: &mut [u8]); + fn fill_bytes(&mut self, dst: &mut [u8]); +} +impl RngCore for T +where + T::Target: RngCore, +{ + #[inline] + fn next_u32(&mut self) -> u32 { + self.deref_mut().next_u32() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.deref_mut().next_u64() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.deref_mut().fill_bytes(dst); + } +} + +/// A marker trait over [`RngCore`] for securely unpredictable RNGs +/// +/// This marker trait indicates that the implementing generator is intended, +/// when correctly seeded and protected from side-channel attacks such as a +/// leaking of state, to be a cryptographically secure generator. This trait is +/// provided as a tool to aid review of cryptographic code, but does not by +/// itself guarantee suitability for cryptographic applications. +/// +/// Implementors of `CryptoRng` automatically implement the [`TryCryptoRng`] +/// trait. +/// +/// Implementors of `CryptoRng` should only implement [`Default`] if the +/// `default()` instances are themselves secure generators: for example if the +/// implementing type is a stateless interface over a secure external generator +/// (like [`OsRng`]) or if the `default()` instance uses a strong, fresh seed. +/// +/// Formally, a CSPRNG (Cryptographically Secure Pseudo-Random Number Generator) +/// should satisfy an additional property over other generators: assuming that +/// the generator has been appropriately seeded and has unknown state, then +/// given the first *k* bits of an algorithm's output +/// sequence, it should not be possible using polynomial-time algorithms to +/// predict the next bit with probability significantly greater than 50%. +/// +/// An optional property of CSPRNGs is backtracking resistance: if the CSPRNG's +/// state is revealed, it will not be computationally-feasible to reconstruct +/// prior output values. This property is not required by `CryptoRng`. +pub trait CryptoRng: RngCore {} + +impl CryptoRng for T where T::Target: CryptoRng {} + +/// A potentially fallible variant of [`RngCore`] +/// +/// This trait is a generalization of [`RngCore`] to support potentially- +/// fallible IO-based generators such as [`OsRng`]. +/// +/// All implementations of [`RngCore`] automatically support this `TryRngCore` +/// trait, using [`Infallible`][core::convert::Infallible] as the associated +/// `Error` type. +/// +/// An implementation of this trait may be made compatible with code requiring +/// an [`RngCore`] through [`TryRngCore::unwrap_err`]. The resulting RNG will +/// panic in case the underlying fallible RNG yields an error. +pub trait TryRngCore { + /// The type returned in the event of a RNG error. + type Error: fmt::Debug + fmt::Display; + + /// Return the next random `u32`. + fn try_next_u32(&mut self) -> Result; + /// Return the next random `u64`. + fn try_next_u64(&mut self) -> Result; /// Fill `dest` entirely with random data. - /// - /// This is the only method which allows an RNG to report errors while - /// generating random data thus making this the primary method implemented - /// by external (true) RNGs (e.g. `OsRng`) which can fail. It may be used - /// directly to generate keys and to seed (infallible) PRNGs. - /// - /// Other than error handling, this method is identical to [`RngCore::fill_bytes`]; - /// thus this may be implemented using `Ok(self.fill_bytes(dest))` or - /// `fill_bytes` may be implemented with - /// `self.try_fill_bytes(dest).unwrap()` or more specific error handling. - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error>; + fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error>; + + /// Wrap RNG with the [`UnwrapErr`] wrapper. + fn unwrap_err(self) -> UnwrapErr + where + Self: Sized, + { + UnwrapErr(self) + } + + /// Wrap RNG with the [`UnwrapMut`] wrapper. + fn unwrap_mut(&mut self) -> UnwrapMut<'_, Self> { + UnwrapMut(self) + } /// Convert an [`RngCore`] to a [`RngReadAdapter`]. #[cfg(feature = "std")] fn read_adapter(&mut self) -> RngReadAdapter<'_, Self> - where Self: Sized { + where + Self: Sized, + { RngReadAdapter { inner: self } } } -/// A marker trait used to indicate that an [`RngCore`] implementation is -/// supposed to be cryptographically secure. -/// -/// *Cryptographically secure generators*, also known as *CSPRNGs*, should -/// satisfy an additional properties over other generators: given the first -/// *k* bits of an algorithm's output -/// sequence, it should not be possible using polynomial-time algorithms to -/// predict the next bit with probability significantly greater than 50%. -/// -/// Some generators may satisfy an additional property, however this is not -/// required by this trait: if the CSPRNG's state is revealed, it should not be -/// computationally-feasible to reconstruct output prior to this. Some other -/// generators allow backwards-computation and are considered *reversible*. +// Note that, unfortunately, this blanket impl prevents us from implementing +// `TryRngCore` for types which can be dereferenced to `TryRngCore`, i.e. `TryRngCore` +// will not be automatically implemented for `&mut R`, `Box`, etc. +impl TryRngCore for R { + type Error = core::convert::Infallible; + + #[inline] + fn try_next_u32(&mut self) -> Result { + Ok(self.next_u32()) + } + + #[inline] + fn try_next_u64(&mut self) -> Result { + Ok(self.next_u64()) + } + + #[inline] + fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> { + self.fill_bytes(dst); + Ok(()) + } +} + +/// A marker trait over [`TryRngCore`] for securely unpredictable RNGs /// -/// Note that this trait is provided for guidance only and cannot guarantee -/// suitability for cryptographic applications. In general it should only be -/// implemented for well-reviewed code implementing well-regarded algorithms. +/// This trait is like [`CryptoRng`] but for the trait [`TryRngCore`]. /// -/// Note also that use of a `CryptoRng` does not protect against other -/// weaknesses such as seeding from a weak entropy source or leaking state. +/// This marker trait indicates that the implementing generator is intended, +/// when correctly seeded and protected from side-channel attacks such as a +/// leaking of state, to be a cryptographically secure generator. This trait is +/// provided as a tool to aid review of cryptographic code, but does not by +/// itself guarantee suitability for cryptographic applications. /// -/// [`BlockRngCore`]: block::BlockRngCore -pub trait CryptoRng: RngCore {} +/// Implementors of `TryCryptoRng` should only implement [`Default`] if the +/// `default()` instances are themselves secure generators: for example if the +/// implementing type is a stateless interface over a secure external generator +/// (like [`OsRng`]) or if the `default()` instance uses a strong, fresh seed. +pub trait TryCryptoRng: TryRngCore {} + +impl TryCryptoRng for R {} + +/// Wrapper around [`TryRngCore`] implementation which implements [`RngCore`] +/// by panicking on potential errors. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)] +pub struct UnwrapErr(pub R); + +impl RngCore for UnwrapErr { + #[inline] + fn next_u32(&mut self) -> u32 { + self.0.try_next_u32().unwrap() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.0.try_next_u64().unwrap() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.try_fill_bytes(dst).unwrap() + } +} + +impl CryptoRng for UnwrapErr {} + +/// Wrapper around [`TryRngCore`] implementation which implements [`RngCore`] +/// by panicking on potential errors. +#[derive(Debug, Eq, PartialEq, Hash)] +pub struct UnwrapMut<'r, R: TryRngCore + ?Sized>(pub &'r mut R); + +impl<'r, R: TryRngCore + ?Sized> UnwrapMut<'r, R> { + /// Reborrow with a new lifetime + /// + /// Rust allows references like `&T` or `&mut T` to be "reborrowed" through + /// coercion: essentially, the pointer is copied under a new, shorter, lifetime. + /// Until rfcs#1403 lands, reborrows on user types require a method call. + #[inline(always)] + pub fn re<'b>(&'b mut self) -> UnwrapMut<'b, R> + where + 'r: 'b, + { + UnwrapMut(self.0) + } +} + +impl RngCore for UnwrapMut<'_, R> { + #[inline] + fn next_u32(&mut self) -> u32 { + self.0.try_next_u32().unwrap() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.0.try_next_u64().unwrap() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.try_fill_bytes(dst).unwrap() + } +} + +impl CryptoRng for UnwrapMut<'_, R> {} /// A random number generator that can be explicitly seeded. /// /// This trait encapsulates the low-level functionality common to all /// pseudo-random number generators (PRNGs, or algorithmic generators). /// +/// A generator implementing `SeedableRng` will usually be deterministic, but +/// beware that portability and reproducibility of results **is not implied**. +/// Refer to documentation of the generator, noting that generators named after +/// a specific algorithm are usually tested for reproducibility against a +/// reference vector, while `SmallRng` and `StdRng` specifically opt out of +/// reproducibility guarantees. +/// /// [`rand`]: https://docs.rs/rand pub trait SeedableRng: Sized { /// Seed type, which is restricted to types mutably-dereferenceable as `u8` @@ -234,16 +382,15 @@ pub trait SeedableRng: Sized { /// /// # Implementing `SeedableRng` for RNGs with large seeds /// - /// Note that the required traits `core::default::Default` and - /// `core::convert::AsMut` are not implemented for large arrays - /// `[u8; N]` with `N` > 32. To be able to implement the traits required by - /// `SeedableRng` for RNGs with such large seeds, the newtype pattern can be - /// used: + /// Note that [`Default`] is not implemented for large arrays `[u8; N]` with + /// `N` > 32. To be able to implement the traits required by `SeedableRng` + /// for RNGs with such large seeds, the newtype pattern can be used: /// /// ``` /// use rand_core::SeedableRng; /// /// const N: usize = 64; + /// #[derive(Clone)] /// pub struct MyRngSeed(pub [u8; N]); /// # #[allow(dead_code)] /// pub struct MyRng(MyRngSeed); @@ -254,6 +401,12 @@ pub trait SeedableRng: Sized { /// } /// } /// + /// impl AsRef<[u8]> for MyRngSeed { + /// fn as_ref(&self) -> &[u8] { + /// &self.0 + /// } + /// } + /// /// impl AsMut<[u8]> for MyRngSeed { /// fn as_mut(&mut self) -> &mut [u8] { /// &mut self.0 @@ -268,7 +421,7 @@ pub trait SeedableRng: Sized { /// } /// } /// ``` - type Seed: Sized + Default + AsMut<[u8]>; + type Seed: Clone + Default + AsRef<[u8]> + AsMut<[u8]>; /// Create a new PRNG using the given seed. /// @@ -342,7 +495,7 @@ pub trait SeedableRng: Sized { Self::from_seed(seed) } - /// Create a new PRNG seeded from another `Rng`. + /// Create a new PRNG seeded from an infallible `Rng`. /// /// This may be useful when needing to rapidly seed many PRNGs from a master /// PRNG, and to allow forking of PRNGs. It may be considered deterministic. @@ -366,7 +519,16 @@ pub trait SeedableRng: Sized { /// (in prior versions this was not required). /// /// [`rand`]: https://docs.rs/rand - fn from_rng(mut rng: R) -> Result { + fn from_rng(rng: &mut impl RngCore) -> Self { + let mut seed = Self::Seed::default(); + rng.fill_bytes(seed.as_mut()); + Self::from_seed(seed) + } + + /// Create a new PRNG seeded from a potentially fallible `Rng`. + /// + /// See [`from_rng`][SeedableRng::from_rng] docs for more information. + fn try_from_rng(rng: &mut R) -> Result { let mut seed = Self::Seed::default(); rng.try_fill_bytes(seed.as_mut())?; Ok(Self::from_seed(seed)) @@ -377,74 +539,41 @@ pub trait SeedableRng: Sized { /// This method is the recommended way to construct non-deterministic PRNGs /// since it is convenient and secure. /// + /// Note that this method may panic on (extremely unlikely) [`getrandom`] errors. + /// If it's not desirable, use the [`try_from_os_rng`] method instead. + /// /// In case the overhead of using [`getrandom`] to seed *many* PRNGs is an /// issue, one may prefer to seed from a local PRNG, e.g. - /// `from_rng(thread_rng()).unwrap()`. + /// `from_rng(rand::rng()).unwrap()`. /// /// # Panics /// /// If [`getrandom`] is unable to provide secure entropy this method will panic. /// /// [`getrandom`]: https://docs.rs/getrandom - #[cfg(feature = "getrandom")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] - fn from_entropy() -> Self { - let mut seed = Self::Seed::default(); - if let Err(err) = getrandom::getrandom(seed.as_mut()) { - panic!("from_entropy failed: {}", err); + /// [`try_from_os_rng`]: SeedableRng::try_from_os_rng + #[cfg(feature = "os_rng")] + fn from_os_rng() -> Self { + match Self::try_from_os_rng() { + Ok(res) => res, + Err(err) => panic!("from_os_rng failed: {}", err), } - Self::from_seed(seed) - } -} - -// Implement `RngCore` for references to an `RngCore`. -// Force inlining all functions, so that it is up to the `RngCore` -// implementation and the optimizer to decide on inlining. -impl<'a, R: RngCore + ?Sized> RngCore for &'a mut R { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - (**self).next_u32() } - #[inline(always)] - fn next_u64(&mut self) -> u64 { - (**self).next_u64() - } - - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - (**self).fill_bytes(dest) - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) - } -} - -// Implement `RngCore` for boxed references to an `RngCore`. -// Force inlining all functions, so that it is up to the `RngCore` -// implementation and the optimizer to decide on inlining. -#[cfg(feature = "alloc")] -impl RngCore for Box { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - (**self).next_u32() - } - - #[inline(always)] - fn next_u64(&mut self) -> u64 { - (**self).next_u64() - } - - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - (**self).fill_bytes(dest) - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) + /// Creates a new instance of the RNG seeded via [`getrandom`] without unwrapping + /// potential [`getrandom`] errors. + /// + /// In case the overhead of using [`getrandom`] to seed *many* PRNGs is an + /// issue, one may prefer to seed from a local PRNG, e.g. + /// `from_rng(&mut rand::rng()).unwrap()`. + /// + /// [`getrandom`]: https://docs.rs/getrandom + #[cfg(feature = "os_rng")] + fn try_from_os_rng() -> Result { + let mut seed = Self::Seed::default(); + getrandom::fill(seed.as_mut())?; + let res = Self::from_seed(seed); + Ok(res) } } @@ -455,37 +584,33 @@ impl RngCore for Box { /// ```no_run /// # use std::{io, io::Read}; /// # use std::fs::File; -/// # use rand_core::{OsRng, RngCore}; +/// # use rand_core::{OsRng, TryRngCore}; /// /// io::copy(&mut OsRng.read_adapter().take(100), &mut File::create("/tmp/random.bytes").unwrap()).unwrap(); /// ``` #[cfg(feature = "std")] -pub struct RngReadAdapter<'a, R: RngCore + ?Sized> { +pub struct RngReadAdapter<'a, R: TryRngCore + ?Sized> { inner: &'a mut R, } #[cfg(feature = "std")] -impl std::io::Read for RngReadAdapter<'_, R> { +impl std::io::Read for RngReadAdapter<'_, R> { + #[inline] fn read(&mut self, buf: &mut [u8]) -> Result { - self.inner.try_fill_bytes(buf)?; + self.inner.try_fill_bytes(buf).map_err(|err| { + std::io::Error::new(std::io::ErrorKind::Other, std::format!("RNG error: {err}")) + })?; Ok(buf.len()) } } #[cfg(feature = "std")] -impl std::fmt::Debug for RngReadAdapter<'_, R> { +impl std::fmt::Debug for RngReadAdapter<'_, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ReadAdapter").finish() } } -// Implement `CryptoRng` for references to a `CryptoRng`. -impl<'a, R: CryptoRng + ?Sized> CryptoRng for &'a mut R {} - -// Implement `CryptoRng` for boxed references to a `CryptoRng`. -#[cfg(feature = "alloc")] -impl CryptoRng for Box {} - #[cfg(test)] mod test { use super::*; @@ -530,4 +655,118 @@ mod test { // value-breakage test: assert_eq!(results[0], 5029875928683246316); } + + // A stub RNG. + struct SomeRng; + + impl RngCore for SomeRng { + fn next_u32(&mut self) -> u32 { + unimplemented!() + } + fn next_u64(&mut self) -> u64 { + unimplemented!() + } + fn fill_bytes(&mut self, _: &mut [u8]) { + unimplemented!() + } + } + + impl CryptoRng for SomeRng {} + + #[test] + fn dyn_rngcore_to_tryrngcore() { + // Illustrates the need for `+ ?Sized` bound in `impl TryRngCore for R`. + + // A method in another crate taking a fallible RNG + fn third_party_api(_rng: &mut (impl TryRngCore + ?Sized)) -> bool { + true + } + + // A method in our crate requiring an infallible RNG + fn my_api(rng: &mut dyn RngCore) -> bool { + // We want to call the method above + third_party_api(rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn dyn_cryptorng_to_trycryptorng() { + // Illustrates the need for `+ ?Sized` bound in `impl TryCryptoRng for R`. + + // A method in another crate taking a fallible RNG + fn third_party_api(_rng: &mut (impl TryCryptoRng + ?Sized)) -> bool { + true + } + + // A method in our crate requiring an infallible RNG + fn my_api(rng: &mut dyn CryptoRng) -> bool { + // We want to call the method above + third_party_api(rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn dyn_unwrap_mut_tryrngcore() { + // Illustrates the need for `+ ?Sized` bound in + // `impl RngCore for UnwrapMut<'_, R>`. + + fn third_party_api(_rng: &mut impl RngCore) -> bool { + true + } + + fn my_api(rng: &mut (impl TryRngCore + ?Sized)) -> bool { + let mut infallible_rng = rng.unwrap_mut(); + third_party_api(&mut infallible_rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn dyn_unwrap_mut_trycryptorng() { + // Illustrates the need for `+ ?Sized` bound in + // `impl CryptoRng for UnwrapMut<'_, R>`. + + fn third_party_api(_rng: &mut impl CryptoRng) -> bool { + true + } + + fn my_api(rng: &mut (impl TryCryptoRng + ?Sized)) -> bool { + let mut infallible_rng = rng.unwrap_mut(); + third_party_api(&mut infallible_rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn reborrow_unwrap_mut() { + struct FourRng; + + impl TryRngCore for FourRng { + type Error = core::convert::Infallible; + fn try_next_u32(&mut self) -> Result { + Ok(4) + } + fn try_next_u64(&mut self) -> Result { + unimplemented!() + } + fn try_fill_bytes(&mut self, _: &mut [u8]) -> Result<(), Self::Error> { + unimplemented!() + } + } + + let mut rng = FourRng; + let mut rng = rng.unwrap_mut(); + + assert_eq!(rng.next_u32(), 4); + let mut rng2 = rng.re(); + assert_eq!(rng2.next_u32(), 4); + drop(rng2); + assert_eq!(rng.next_u32(), 4); + } } diff --git a/rand_core/src/os.rs b/rand_core/src/os.rs index b43c9fdaf05..49111632d9f 100644 --- a/rand_core/src/os.rs +++ b/rand_core/src/os.rs @@ -8,19 +8,17 @@ //! Interface to the random number generator of the operating system. -use crate::{impls, CryptoRng, Error, RngCore}; -use getrandom::getrandom; +use crate::{TryCryptoRng, TryRngCore}; -/// A random number generator that retrieves randomness from the -/// operating system. +/// An interface over the operating-system's random data source /// -/// This is a zero-sized struct. It can be freely constructed with `OsRng`. +/// This is a zero-sized struct. It can be freely constructed with just `OsRng`. /// /// The implementation is provided by the [getrandom] crate. Refer to /// [getrandom] documentation for details. /// /// This struct is available as `rand_core::OsRng` and as `rand::rngs::OsRng`. -/// In both cases, this requires the crate feature `getrandom` or `std` +/// In both cases, this requires the crate feature `os_rng` or `std` /// (enabled by default in `rand` but not in `rand_core`). /// /// # Blocking and error handling @@ -32,55 +30,86 @@ use getrandom::getrandom; /// /// After the first successful call, it is highly unlikely that failures or /// significant delays will occur (although performance should be expected to -/// be much slower than a user-space PRNG). +/// be much slower than a user-space +/// [PRNG](https://rust-random.github.io/book/guide-gen.html#pseudo-random-number-generators)). /// /// # Usage example /// ``` -/// use rand_core::{RngCore, OsRng}; +/// use rand_core::{TryRngCore, OsRng}; /// /// let mut key = [0u8; 16]; -/// OsRng.fill_bytes(&mut key); -/// let random_u64 = OsRng.next_u64(); +/// OsRng.try_fill_bytes(&mut key).unwrap(); +/// let random_u64 = OsRng.try_next_u64().unwrap(); /// ``` /// /// [getrandom]: https://crates.io/crates/getrandom -#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] #[derive(Clone, Copy, Debug, Default)] pub struct OsRng; -impl CryptoRng for OsRng {} +/// Error type of [`OsRng`] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OsError(getrandom::Error); -impl RngCore for OsRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) +impl core::fmt::Display for OsError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.fmt(f) } +} - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) +// NOTE: this can use core::error::Error from rustc 1.81.0 +#[cfg(feature = "std")] +impl std::error::Error for OsError { + #[inline] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + std::error::Error::source(&self.0) } +} - fn fill_bytes(&mut self, dest: &mut [u8]) { - if let Err(e) = self.try_fill_bytes(dest) { - panic!("Error: {}", e); - } +impl OsError { + /// Extract the raw OS error code (if this error came from the OS) + /// + /// This method is identical to [`std::io::Error::raw_os_error()`][1], except + /// that it works in `no_std` contexts. If this method returns `None`, the + /// error value can still be formatted via the `Display` implementation. + /// + /// [1]: https://doc.rust-lang.org/std/io/struct.Error.html#method.raw_os_error + #[inline] + pub fn raw_os_error(self) -> Option { + self.0.raw_os_error() } +} - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - getrandom(dest)?; - Ok(()) +impl TryRngCore for OsRng { + type Error = OsError; + + #[inline] + fn try_next_u32(&mut self) -> Result { + getrandom::u32().map_err(OsError) + } + + #[inline] + fn try_next_u64(&mut self) -> Result { + getrandom::u64().map_err(OsError) + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + getrandom::fill(dest).map_err(OsError) } } +impl TryCryptoRng for OsRng {} + #[test] fn test_os_rng() { - let x = OsRng.next_u64(); - let y = OsRng.next_u64(); + let x = OsRng.try_next_u64().unwrap(); + let y = OsRng.try_next_u64().unwrap(); assert!(x != 0); assert!(x != y); } #[test] fn test_construction() { - let mut rng = OsRng::default(); - assert!(rng.next_u64() != 0); + assert!(OsRng.try_next_u64().unwrap() != 0); } diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md deleted file mode 100644 index 24b43648ebd..00000000000 --- a/rand_distr/CHANGELOG.md +++ /dev/null @@ -1,85 +0,0 @@ -# Changelog -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [0.5.0-alpha.0] - 2024-02-18 -This is a pre-release. To depend on this version, use `rand_distr = "=0.5.0-alpha.0"` to prevent automatic updates (which can be expected to include breaking changes). - -### Additions -- Make distributions comparable with `PartialEq` (#1218) -- Add `WeightedIndexTree` (#1372) - -### Changes -- Target `rand` version `0.9.0-alpha.0` -- Remove unused fields from `Gamma`, `NormalInverseGaussian` and `Zipf` distributions (#1184) - This breaks serialization compatibility with older versions. -- `Dirichlet` now uses `const` generics, which means that its size is required at compile time (#1292) -- The `Dirichlet::new_with_size` constructor was removed (#1292) - -### Fixes -- Fix Knuth's method so `Poisson` doesn't return -1.0 for small lambda (#1284) -- Fix `Poisson` distribution instantiation so it return an error if lambda is infinite (#1291) -- Fix Dirichlet sample for small alpha values to avoid NaN samples (#1209) -- Fix infinite loop in `Binomial` distribution (#1325) - -## [0.4.3] - 2021-12-30 -- Fix `no_std` build (#1208) - -## [0.4.2] - 2021-09-18 -- New `Zeta` and `Zipf` distributions (#1136) -- New `SkewNormal` distribution (#1149) -- New `Gumbel` and `Frechet` distributions (#1168, #1171) - -## [0.4.1] - 2021-06-15 -- Empirically test PDF of normal distribution (#1121) -- Correctly document `no_std` support (#1100) -- Add `std_math` feature to prefer `std` over `libm` for floating point math (#1100) -- Add mean and std_dev accessors to Normal (#1114) -- Make sure all distributions and their error types implement `Error`, `Display`, `Clone`, - `Copy`, `PartialEq` and `Eq` as appropriate (#1126) -- Port benchmarks to use Criterion crate (#1116) -- Support serde for distributions (#1141) - -## [0.4.0] - 2020-12-18 -- Bump `rand` to v0.8.0 -- New `Geometric`, `StandardGeometric` and `Hypergeometric` distributions (#1062) -- New `Beta` sampling algorithm for improved performance and accuracy (#1000) -- `Normal` and `LogNormal` now support `from_mean_cv` and `from_zscore` (#1044) -- Variants of `NormalError` changed (#1044) - -## [0.3.0] - 2020-08-25 -- Move alias method for `WeightedIndex` from `rand` (#945) -- Rename `WeightedIndex` to `WeightedAliasIndex` (#1008) -- Replace custom `Float` trait with `num-traits::Float` (#987) -- Enable `no_std` support via `num-traits` math functions (#987) -- Remove `Distribution` impl for `Poisson` (#987) -- Tweak `Dirichlet` and `alias_method` to use boxed slice instead of `Vec` (#987) -- Use whitelist for package contents, reducing size by 5kb (#983) -- Add case `lambda = 0` in the parametrization of `Exp` (#972) -- Implement inverse Gaussian distribution (#954) -- Reformatting and use of `rustfmt::skip` (#926) -- All error types now implement `std::error::Error` (#919) -- Re-exported `rand::distributions::BernoulliError` (#919) -- Add value stability tests for distributions (#891) - -## [0.2.2] - 2019-09-10 -- Fix version requirement on rand lib (#847) -- Clippy fixes & suppression (#840) - -## [0.2.1] - 2019-06-29 -- Update dependency to support Rand 0.7 -- Doc link fixes - -## [0.2.0] - 2019-06-06 -- Remove `new` constructors for zero-sized types -- Add Pert distribution -- Fix undefined behavior in `Poisson` -- Make all distributions return `Result`s instead of panicking -- Implement `f32` support for most distributions -- Rename `UnitSphereSurface` to `UnitSphere` -- Implement `UnitBall` and `UnitDisc` - -## [0.1.0] - 2019-06-06 -Initial release. This is equivalent to the code in `rand` 0.6.5. diff --git a/rand_distr/COPYRIGHT b/rand_distr/COPYRIGHT deleted file mode 100644 index 468d907caf9..00000000000 --- a/rand_distr/COPYRIGHT +++ /dev/null @@ -1,12 +0,0 @@ -Copyrights in the Rand project are retained by their contributors. No -copyright assignment is required to contribute to the Rand project. - -For full authorship information, see the version control history. - -Except as otherwise noted (below and/or in individual files), Rand is -licensed under the Apache License, Version 2.0 or - or the MIT license - or , at your option. - -The Rand project includes code from the Rust project -published under these same licenses. diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml deleted file mode 100644 index 36533d46464..00000000000 --- a/rand_distr/Cargo.toml +++ /dev/null @@ -1,42 +0,0 @@ -[package] -name = "rand_distr" -version = "0.5.0-alpha.0" -authors = ["The Rand Project Developers"] -license = "MIT OR Apache-2.0" -readme = "README.md" -repository = "https://github.com/rust-random/rand" -documentation = "https://docs.rs/rand_distr" -homepage = "https://rust-random.github.io/book" -description = """ -Sampling from random number distributions -""" -keywords = ["random", "rng", "distribution", "probability"] -categories = ["algorithms", "no-std"] -edition = "2021" -rust-version = "1.60" -include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] - -[package.metadata.docs.rs] -rustdoc-args = ["--generate-link-to-definition"] - -[features] -default = ["std"] -std = ["alloc", "rand/std"] -alloc = ["rand/alloc"] -std_math = ["num-traits/std"] -serde1 = ["serde", "rand/serde1"] - -[dependencies] -rand = { path = "..", version = "=0.9.0-alpha.0", default-features = false } -num-traits = { version = "0.2", default-features = false, features = ["libm"] } -serde = { version = "1.0.103", features = ["derive"], optional = true } -serde_with = { version = "3.6.1", optional = true } - -[dev-dependencies] -rand_pcg = { version = "=0.9.0-alpha.0", path = "../rand_pcg" } -# For inline examples -rand = { path = "..", version = "=0.9.0-alpha.0", features = ["small_rng"] } -# Histogram implementation for testing uniformity -average = { version = "0.13", features = [ "std" ] } -# Special functions for testing distributions -special = "0.10.3" diff --git a/rand_distr/LICENSE-APACHE b/rand_distr/LICENSE-APACHE deleted file mode 100644 index 455787c2334..00000000000 --- a/rand_distr/LICENSE-APACHE +++ /dev/null @@ -1,187 +0,0 @@ - Apache License - Version 2.0, January 2004 - https://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. diff --git a/rand_distr/LICENSE-MIT b/rand_distr/LICENSE-MIT deleted file mode 100644 index cf656074cbf..00000000000 --- a/rand_distr/LICENSE-MIT +++ /dev/null @@ -1,25 +0,0 @@ -Copyright 2018 Developers of the Rand project - -Permission is hereby granted, free of charge, to any -person obtaining a copy of this software and associated -documentation files (the "Software"), to deal in the -Software without restriction, including without -limitation the rights to use, copy, modify, merge, -publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software -is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice -shall be included in all copies or substantial portions -of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/rand_distr/README.md b/rand_distr/README.md deleted file mode 100644 index 016e8981d85..00000000000 --- a/rand_distr/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# rand_distr - -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) -[![Latest version](https://img.shields.io/crates/v/rand_distr.svg)](https://crates.io/crates/rand_distr) -[![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) -[![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_distr) -[![API](https://docs.rs/rand_distr/badge.svg)](https://docs.rs/rand_distr) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.60+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) - -Implements a full suite of random number distribution sampling routines. - -This crate is a superset of the [rand::distributions] module, including support -for sampling from Beta, Binomial, Cauchy, ChiSquared, Dirichlet, Exponential, -FisherF, Gamma, Geometric, Hypergeometric, InverseGaussian, LogNormal, Normal, -Pareto, PERT, Poisson, StudentT, Triangular and Weibull distributions. Sampling -from the unit ball, unit circle, unit disc and unit sphere surfaces is also -supported. - -It is worth mentioning the [statrs] crate which provides similar functionality -along with various support functions, including PDF and CDF computation. In -contrast, this `rand_distr` crate focuses on sampling from distributions. - -## Portability and libm - -The floating point functions from `num_traits` and `libm` are used to support -`no_std` environments and ensure reproducibility. If the floating point -functions from `std` are preferred, which may provide better accuracy and -performance but may produce different random values, the `std_math` feature -can be enabled. - -## Crate features - -- `std` (enabled by default): `rand_distr` implements the `Error` trait for - its error types. Implies `alloc` and `rand/std`. -- `alloc` (enabled by default): required for some distributions when not using - `std` (in particular, `Dirichlet` and `WeightedAliasIndex`). -- `std_math`: see above on portability and libm -- `serde1`: implement (de)seriaialization using `serde` - -## Links - -- [API documentation (master)](https://rust-random.github.io/rand/rand_distr) -- [API documentation (docs.rs)](https://docs.rs/rand_distr) -- [Changelog](CHANGELOG.md) -- [The Rand project](https://github.com/rust-random/rand) - - -[statrs]: https://github.com/boxtown/statrs -[rand::distributions]: https://rust-random.github.io/rand/rand/distributions/index.html - -## License - -`rand_distr` is distributed under the terms of both the MIT license and the -Apache License (Version 2.0). - -See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and -[COPYRIGHT](COPYRIGHT) for details. diff --git a/rand_distr/benches/Cargo.toml b/rand_distr/benches/Cargo.toml deleted file mode 100644 index 2dd82c7973a..00000000000 --- a/rand_distr/benches/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "benches" -version = "0.0.0" -authors = ["The Rand Project Developers"] -license = "MIT OR Apache-2.0" -description = "Criterion benchmarks of the rand_distr crate" -edition = "2021" -rust-version = "1.60" -publish = false - -[workspace] - -[dependencies] -criterion = { version = "0.3", features = ["html_reports"] } -criterion-cycles-per-byte = "0.1" -rand = { path = "../../" } -rand_distr = { path = "../" } -rand_pcg = { path = "../../rand_pcg/" } - -[[bench]] -name = "distributions" -path = "src/distributions.rs" -harness = false diff --git a/rand_distr/benches/src/distributions.rs b/rand_distr/benches/src/distributions.rs deleted file mode 100644 index 2677fca4812..00000000000 --- a/rand_distr/benches/src/distributions.rs +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2018-2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(custom_inner_attributes)] - -// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable -#![rustfmt::skip] - -const RAND_BENCH_N: u64 = 1000; - -use criterion::{criterion_group, criterion_main, Criterion, - Throughput}; -use criterion_cycles_per_byte::CyclesPerByte; - -use core::mem::size_of; - -use rand::prelude::*; -use rand_distr::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -macro_rules! distr_int { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x); - } - accum - }); - }); - }; -} - -macro_rules! distr_float { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum = 0.; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum += x; - } - accum - }); - }); - }; -} - -macro_rules! distr { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum: u32 = 0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x as u32); - } - accum - }); - }); - }; -} - -macro_rules! distr_arr { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.throughput(Throughput::Bytes( - size_of::<$ty>() as u64 * RAND_BENCH_N)); - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - c.iter(|| { - let mut accum: u32 = 0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x[0] as u32); - } - accum - }); - }); - }; -} - -macro_rules! sample_binomial { - ($group:ident, $name:expr, $n:expr, $p:expr) => { - distr_int!($group, $name, u64, Binomial::new($n, $p).unwrap()) - }; -} - -fn bench(c: &mut Criterion) { - { - let mut g = c.benchmark_group("exp"); - distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap()); - distr_float!(g, "exp1_specialized", f64, Exp1); - distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap()); - } - - { - let mut g = c.benchmark_group("normal"); - distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap()); - distr_float!(g, "standardnormal_specialized", f64, StandardNormal); - distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap()); - distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap()); - g.throughput(Throughput::Bytes(size_of::() as u64 * RAND_BENCH_N)); - g.bench_function("iter", |c| { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = Normal::new(-2.71828, 3.14159).unwrap(); - let mut iter = distr.sample_iter(&mut rng); - - c.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - accum += iter.next().unwrap(); - } - accum - }); - }); - } - - { - let mut g = c.benchmark_group("skew_normal"); - distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap()); - distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap()); - distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap()); - } - - { - let mut g = c.benchmark_group("gamma"); - distr_float!(g, "gamma_large_shape", f64, Gamma::new(10., 1.0).unwrap()); - distr_float!(g, "gamma_small_shape", f64, Gamma::new(0.1, 1.0).unwrap()); - distr_float!(g, "beta_small_param", f64, Beta::new(0.1, 0.1).unwrap()); - distr_float!(g, "beta_large_param_similar", f64, Beta::new(101., 95.).unwrap()); - distr_float!(g, "beta_large_param_different", f64, Beta::new(10., 1000.).unwrap()); - distr_float!(g, "beta_mixed_param", f64, Beta::new(0.5, 100.).unwrap()); - } - - { - let mut g = c.benchmark_group("cauchy"); - distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap()); - } - - { - let mut g = c.benchmark_group("triangular"); - distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap()); - } - - { - let mut g = c.benchmark_group("geometric"); - distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap()); - distr_int!(g, "standard_geometric", u64, StandardGeometric); - } - - { - let mut g = c.benchmark_group("weighted"); - distr_int!(g, "weighted_i8", usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_u32", usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_f64", usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); - distr_int!(g, "weighted_large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); - distr_int!(g, "weighted_alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "weighted_alias_method_f64", usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); - distr_int!(g, "weighted_alias_method_large_set", usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); - } - - { - let mut g = c.benchmark_group("binomial"); - sample_binomial!(g, "binomial", 20, 0.7); - sample_binomial!(g, "binomial_small", 1_000_000, 1e-30); - sample_binomial!(g, "binomial_1", 1, 0.9); - sample_binomial!(g, "binomial_10", 10, 0.9); - sample_binomial!(g, "binomial_100", 100, 0.99); - sample_binomial!(g, "binomial_1000", 1000, 0.01); - sample_binomial!(g, "binomial_1e12", 1000_000_000_000, 0.2); - } - - { - let mut g = c.benchmark_group("poisson"); - distr_float!(g, "poisson", f64, Poisson::new(4.0).unwrap()); - } - - { - let mut g = c.benchmark_group("zipf"); - distr_float!(g, "zipf", f64, Zipf::new(10, 1.5).unwrap()); - distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap()); - } - - { - let mut g = c.benchmark_group("bernoulli"); - distr!(g, "bernoulli", bool, Bernoulli::new(0.18).unwrap()); - } - - { - let mut g = c.benchmark_group("circle"); - distr_arr!(g, "circle", [f64; 2], UnitCircle); - } - - { - let mut g = c.benchmark_group("sphere"); - distr_arr!(g, "sphere", [f64; 3], UnitSphere); - } -} - -criterion_group!( - name = benches; - config = Criterion::default().with_measurement(CyclesPerByte); - targets = bench -); -criterion_main!(benches); diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs deleted file mode 100644 index 2d380c64688..00000000000 --- a/rand_distr/src/binomial.rs +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2016-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The binomial distribution. - -use crate::{Distribution, Uniform}; -use rand::Rng; -use core::fmt; -use core::cmp::Ordering; -#[allow(unused_imports)] -use num_traits::Float; - -/// The binomial distribution `Binomial(n, p)`. -/// -/// This distribution has density function: -/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Binomial, Distribution}; -/// -/// let bin = Binomial::new(20, 0.3).unwrap(); -/// let v = bin.sample(&mut rand::thread_rng()); -/// println!("{} is from a binomial distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Binomial { - /// Number of trials. - n: u64, - /// Probability of success. - p: f64, -} - -/// Error type returned from `Binomial::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `p < 0` or `nan`. - ProbabilityTooSmall, - /// `p > 1`. - ProbabilityTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution", - Error::ProbabilityTooLarge => "p > 1 in binomial distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Binomial { - /// Construct a new `Binomial` with the given shape parameters `n` (number - /// of trials) and `p` (probability of success). - pub fn new(n: u64, p: f64) -> Result { - if !(p >= 0.0) { - return Err(Error::ProbabilityTooSmall); - } - if !(p <= 1.0) { - return Err(Error::ProbabilityTooLarge); - } - Ok(Binomial { n, p }) - } -} - -/// Convert a `f64` to an `i64`, panicking on overflow. -fn f64_to_i64(x: f64) -> i64 { - assert!(x < (core::i64::MAX as f64)); - x as i64 -} - -impl Distribution for Binomial { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - // Handle these values directly. - if self.p == 0.0 { - return 0; - } else if self.p == 1.0 { - return self.n; - } - - // The binomial distribution is symmetrical with respect to p -> 1-p, - // k -> n-k switch p so that it is less than 0.5 - this allows for lower - // expected values we will just invert the result at the end - let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; - - let result; - let q = 1. - p; - - // For small n * min(p, 1 - p), the BINV algorithm based on the inverse - // transformation of the binomial distribution is efficient. Otherwise, - // the BTPE algorithm is used. - // - // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial - // random variate generation. Commun. ACM 31, 2 (February 1988), - // 216-222. http://dx.doi.org/10.1145/42372.42381 - - // Threshold for preferring the BINV algorithm. The paper suggests 10, - // Ranlib uses 30, and GSL uses 14. - const BINV_THRESHOLD: f64 = 10.; - - // Same value as in GSL. - // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. - // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. - // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. - const BINV_MAX_X : u64 = 110; - - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) { - // Use the BINV algorithm. - let s = p / q; - let a = ((self.n + 1) as f64) * s; - - result = 'outer: loop { - let mut r = q.powi(self.n as i32); - let mut u: f64 = rng.gen(); - let mut x = 0; - - while u > r { - u -= r; - x += 1; - if x > BINV_MAX_X { - continue 'outer; - } - r *= a / (x as f64) - s; - } - break x; - - } - } else { - // Use the BTPE algorithm. - - // Threshold for using the squeeze algorithm. This can be freely - // chosen based on performance. Ranlib and GSL use 20. - const SQUEEZE_THRESHOLD: i64 = 20; - - // Step 0: Calculate constants as functions of `n` and `p`. - let n = self.n as f64; - let np = n * p; - let npq = np * q; - let f_m = np + p; - let m = f64_to_i64(f_m); - // radius of triangle region, since height=1 also area of region - let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; - // tip of triangle - let x_m = (m as f64) + 0.5; - // left edge of triangle - let x_l = x_m - p1; - // right edge of triangle - let x_r = x_m + p1; - let c = 0.134 + 20.5 / (15.3 + (m as f64)); - // p1 + area of parallelogram region - let p2 = p1 * (1. + 2. * c); - - fn lambda(a: f64) -> f64 { - a * (1. + 0.5 * a) - } - - let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); - let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // p1 + area of left tail - let p3 = p2 + c / lambda_l; - // p1 + area of right tail - let p4 = p3 + c / lambda_r; - - // return value - let mut y: i64; - - let gen_u = Uniform::new(0., p4).unwrap(); - let gen_v = Uniform::new(0., 1.).unwrap(); - - loop { - // Step 1: Generate `u` for selecting the region. If region 1 is - // selected, generate a triangularly distributed variate. - let u = gen_u.sample(rng); - let mut v = gen_v.sample(rng); - if !(u > p1) { - y = f64_to_i64(x_m - p1 * v + u); - break; - } - - if !(u > p2) { - // Step 2: Region 2, parallelograms. Check if region 2 is - // used. If so, generate `y`. - let x = x_l + (u - p1) / c; - v = v * c + 1.0 - (x - x_m).abs() / p1; - if v > 1. { - continue; - } else { - y = f64_to_i64(x); - } - } else if !(u > p3) { - // Step 3: Region 3, left exponential tail. - y = f64_to_i64(x_l + v.ln() / lambda_l); - if y < 0 { - continue; - } else { - v *= (u - p2) * lambda_l; - } - } else { - // Step 4: Region 4, right exponential tail. - y = f64_to_i64(x_r - v.ln() / lambda_r); - if y > 0 && (y as u64) > self.n { - continue; - } else { - v *= (u - p3) * lambda_r; - } - } - - // Step 5: Acceptance/rejection comparison. - - // Step 5.0: Test for appropriate method of evaluating f(y). - let k = (y - m).abs(); - if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { - // Step 5.1: Evaluate f(y) via the recursive relationship. Start the - // search from the mode. - let s = p / q; - let a = s * (n + 1.); - let mut f = 1.0; - match m.cmp(&y) { - Ordering::Less => { - let mut i = m; - loop { - i += 1; - f *= a / (i as f64) - s; - if i == y { - break; - } - } - }, - Ordering::Greater => { - let mut i = y; - loop { - i += 1; - f /= a / (i as f64) - s; - if i == m { - break; - } - } - }, - Ordering::Equal => {}, - } - if v > f { - continue; - } else { - break; - } - } - - // Step 5.2: Squeezing. Check the value of ln(v) against upper and - // lower bound of ln(f(y)). - let k = k as f64; - let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); - let t = -0.5 * k * k / npq; - let alpha = v.ln(); - if alpha < t - rho { - break; - } - if alpha > t + rho { - continue; - } - - // Step 5.3: Final acceptance/rejection test. - let x1 = (y + 1) as f64; - let f1 = (m + 1) as f64; - let z = (f64_to_i64(n) + 1 - m) as f64; - let w = (f64_to_i64(n) - y + 1) as f64; - - fn stirling(a: f64) -> f64 { - let a2 = a * a; - (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. - } - - if alpha - > x_m * (f1 / x1).ln() - + (n - (m as f64) + 0.5) * (z / w).ln() - + ((y - m) as f64) * (w * p / (x1 * q)).ln() - // We use the signs from the GSL implementation, which are - // different than the ones in the reference. According to - // the GSL authors, the new signs were verified to be - // correct by one of the original designers of the - // algorithm. - + stirling(f1) - + stirling(z) - - stirling(x1) - - stirling(w) - { - continue; - } - - break; - } - assert!(y >= 0); - result = y as u64; - } - - // Invert the result for p < 0.5. - if p != self.p { - self.n - result - } else { - result - } - } -} - -#[cfg(test)] -mod test { - use super::Binomial; - use crate::Distribution; - use rand::Rng; - - fn test_binomial_mean_and_variance(n: u64, p: f64, rng: &mut R) { - let binomial = Binomial::new(n, p).unwrap(); - - let expected_mean = n as f64 * p; - let expected_variance = n as f64 * p * (1.0 - p); - - let mut results = [0.0; 1000]; - for i in results.iter_mut() { - *i = binomial.sample(rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 50.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn test_binomial() { - let mut rng = crate::test::rng(351); - test_binomial_mean_and_variance(150, 0.1, &mut rng); - test_binomial_mean_and_variance(70, 0.6, &mut rng); - test_binomial_mean_and_variance(40, 0.5, &mut rng); - test_binomial_mean_and_variance(20, 0.7, &mut rng); - test_binomial_mean_and_variance(20, 0.5, &mut rng); - } - - #[test] - fn test_binomial_end_points() { - let mut rng = crate::test::rng(352); - assert_eq!(rng.sample(Binomial::new(20, 0.0).unwrap()), 0); - assert_eq!(rng.sample(Binomial::new(20, 1.0).unwrap()), 20); - } - - #[test] - #[should_panic] - fn test_binomial_invalid_lambda_neg() { - Binomial::new(20, -10.0).unwrap(); - } - - #[test] - fn binomial_distributions_can_be_compared() { - assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0)); - } - - #[test] - fn binomial_avoid_infinite_loop() { - let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap(); - let mut sum: u64 = 0; - let mut rng = crate::test::rng(742); - for _ in 0..100_000 { - sum = sum.wrapping_add(dist.sample(&mut rng)); - } - assert_ne!(sum, 0); - } -} diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs deleted file mode 100644 index cd3e31b453f..00000000000 --- a/rand_distr/src/cauchy.rs +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2016-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Cauchy distribution. - -use num_traits::{Float, FloatConst}; -use crate::{Distribution, Standard}; -use rand::Rng; -use core::fmt; - -/// The Cauchy distribution `Cauchy(median, scale)`. -/// -/// This distribution has a density function: -/// `f(x) = 1 / (pi * scale * (1 + ((x - median) / scale)^2))` -/// -/// Note that at least for `f32`, results are not fully portable due to minor -/// differences in the target system's *tan* implementation, `tanf`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Cauchy, Distribution}; -/// -/// let cau = Cauchy::new(2.0, 5.0).unwrap(); -/// let v = cau.sample(&mut rand::thread_rng()); -/// println!("{} is from a Cauchy(2, 5) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Cauchy -where F: Float + FloatConst, Standard: Distribution -{ - median: F, - scale: F, -} - -/// Error type returned from `Cauchy::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `scale <= 0` or `nan`. - ScaleTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => "scale is not positive in Cauchy distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Cauchy -where F: Float + FloatConst, Standard: Distribution -{ - /// Construct a new `Cauchy` with the given shape parameters - /// `median` the peak location and `scale` the scale factor. - pub fn new(median: F, scale: F) -> Result, Error> { - if !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - Ok(Cauchy { median, scale }) - } -} - -impl Distribution for Cauchy -where F: Float + FloatConst, Standard: Distribution -{ - fn sample(&self, rng: &mut R) -> F { - // sample from [0, 1) - let x = Standard.sample(rng); - // get standard cauchy random number - // note that π/2 is not exactly representable, even if x=0.5 the result is finite - let comp_dev = (F::PI() * x).tan(); - // shift and scale according to parameters - self.median + self.scale * comp_dev - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn median(numbers: &mut [f64]) -> f64 { - sort(numbers); - let mid = numbers.len() / 2; - numbers[mid] - } - - fn sort(numbers: &mut [f64]) { - numbers.sort_by(|a, b| a.partial_cmp(b).unwrap()); - } - - #[test] - fn test_cauchy_averages() { - // NOTE: given that the variance and mean are undefined, - // this test does not have any rigorous statistical meaning. - let cauchy = Cauchy::new(10.0, 5.0).unwrap(); - let mut rng = crate::test::rng(123); - let mut numbers: [f64; 1000] = [0.0; 1000]; - let mut sum = 0.0; - for number in &mut numbers[..] { - *number = cauchy.sample(&mut rng); - sum += *number; - } - let median = median(&mut numbers); - #[cfg(feature = "std")] - std::println!("Cauchy median: {}", median); - assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough - let mean = sum / 1000.0; - #[cfg(feature = "std")] - std::println!("Cauchy mean: {}", mean); - // for a Cauchy distribution the mean should not converge - assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough - } - - #[test] - #[should_panic] - fn test_cauchy_invalid_scale_zero() { - Cauchy::new(0.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_cauchy_invalid_scale_neg() { - Cauchy::new(0.0, -10.0).unwrap(); - } - - #[test] - fn value_stability() { - fn gen_samples(m: F, s: F, buf: &mut [F]) - where Standard: Distribution { - let distr = Cauchy::new(m, s).unwrap(); - let mut rng = crate::test::rng(353); - for x in buf { - *x = rng.sample(distr); - } - } - - let mut buf = [0.0; 4]; - gen_samples(100f64, 10.0, &mut buf); - assert_eq!(&buf, &[ - 77.93369152808678, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925 - ]); - - // Unfortunately this test is not fully portable due to reliance on the - // system's implementation of tanf (see doc on Cauchy struct). - let mut buf = [0.0; 4]; - gen_samples(10f32, 7.0, &mut buf); - let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; - for (a, b) in buf.iter().zip(expected.iter()) { - assert_almost_eq!(*a, *b, 1e-5); - } - } - - #[test] - fn cauchy_distributions_can_be_compared() { - assert_eq!(Cauchy::new(1.0, 2.0), Cauchy::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs deleted file mode 100644 index 413c00476ab..00000000000 --- a/rand_distr/src/dirichlet.rs +++ /dev/null @@ -1,444 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The dirichlet distribution. -#![cfg(feature = "alloc")] -use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; -use core::fmt; -use num_traits::{Float, NumCast}; -use rand::Rng; -#[cfg(feature = "serde_with")] use serde_with::serde_as; - -use alloc::{boxed::Box, vec, vec::Vec}; - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde_with", serde_as)] -struct DirichletFromGamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - samplers: [Gamma; N], -} - -/// Error type returned from `DirchletFromGamma::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum DirichletFromGammaError { - /// Gamma::new(a, 1) failed. - GammmaNewFailed, - - /// gamma_dists.try_into() failed (in theory, this should not happen). - GammaArrayCreationFailed, -} - -impl DirichletFromGamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a new `DirichletFromGamma` with the given parameters `alpha`. - /// - /// This function is part of a private implementation detail. - /// It assumes that the input is correct, so no validation of alpha is done. - #[inline] - fn new(alpha: [F; N]) -> Result, DirichletFromGammaError> { - let mut gamma_dists = Vec::new(); - for a in alpha { - let dist = - Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; - gamma_dists.push(dist); - } - Ok(DirichletFromGamma { - samplers: gamma_dists - .try_into() - .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?, - }) - } -} - -impl Distribution<[F; N]> for DirichletFromGamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; - let mut sum = F::zero(); - - for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { - *s = g.sample(rng); - sum = sum + *s; - } - let invacc = F::one() / sum; - for s in samples.iter_mut() { - *s = *s * invacc; - } - samples - } -} - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -struct DirichletFromBeta -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - samplers: Box<[Beta]>, -} - -/// Error type returned from `DirchletFromBeta::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum DirichletFromBetaError { - /// Beta::new(a, b) failed. - BetaNewFailed, -} - -impl DirichletFromBeta -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a new `DirichletFromBeta` with the given parameters `alpha`. - /// - /// This function is part of a private implementation detail. - /// It assumes that the input is correct, so no validation of alpha is done. - #[inline] - fn new(alpha: [F; N]) -> Result, DirichletFromBetaError> { - // `alpha_rev_csum` is the reverse of the cumulative sum of the - // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then - // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`. - // Note that instances of DirichletFromBeta will always have N >= 2, - // so the subtractions of 1, 2 and 3 from N in the following are safe. - let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1]; - for k in 0..(N - 2) { - alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k]; - } - - // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example - // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples - // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`. - // Then pass each tuple to `Beta::new()` to create the `Beta` - // instances. - let mut beta_dists = Vec::new(); - for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) { - let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?; - beta_dists.push(dist); - } - Ok(DirichletFromBeta { - samplers: beta_dists.into_boxed_slice(), - }) - } -} - -impl Distribution<[F; N]> for DirichletFromBeta -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; - let mut acc = F::one(); - - for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { - let beta_sample = beta.sample(rng); - *s = acc * beta_sample; - acc = acc * (F::one() - beta_sample); - } - samples[N - 1] = acc; - samples - } -} - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde_with", serde_as)] -enum DirichletRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Dirichlet distribution that generates samples using the Gamma distribution. - FromGamma(DirichletFromGamma), - - /// Dirichlet distribution that generates samples using the Beta distribution. - FromBeta(DirichletFromBeta), -} - -/// The Dirichlet distribution `Dirichlet(alpha)`. -/// -/// The Dirichlet distribution is a family of continuous multivariate -/// probability distributions parameterized by a vector alpha of positive reals. -/// It is a multivariate generalization of the beta distribution. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Dirichlet; -/// -/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); -/// let samples = dirichlet.sample(&mut rand::thread_rng()); -/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); -/// ``` -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[cfg_attr(feature = "serde_with", serde_as)] -#[derive(Clone, Debug, PartialEq)] -pub struct Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: DirichletRepr, -} - -/// Error type returned from `Dirchlet::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `alpha.len() < 2`. - AlphaTooShort, - /// `alpha <= 0.0` or `nan`. - AlphaTooSmall, - /// `alpha` is subnormal. - /// Variate generation methods are not reliable with subnormal inputs. - AlphaSubnormal, - /// `alpha` is infinite. - AlphaInfinite, - /// Failed to create required Gamma distribution(s). - FailedToCreateGamma, - /// Failed to create required Beta distribition(s). - FailedToCreateBeta, - /// `size < 2`. - SizeTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::AlphaTooShort | Error::SizeTooSmall => { - "less than 2 dimensions in Dirichlet distribution" - } - Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution", - Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution", - Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution", - Error::FailedToCreateGamma => { - "failed to create required Gamma distribution for Dirichlet distribution" - } - Error::FailedToCreateBeta => { - "failed to create required Beta distribition for Dirichlet distribution" - } - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. - /// - /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive, - /// finite and not subnormal. - #[inline] - pub fn new(alpha: [F; N]) -> Result, Error> { - if N < 2 { - return Err(Error::AlphaTooShort); - } - for &ai in alpha.iter() { - if !(ai > F::zero()) { - // This also catches nan. - return Err(Error::AlphaTooSmall); - } - if ai.is_infinite() { - return Err(Error::AlphaInfinite); - } - if !ai.is_normal() { - return Err(Error::AlphaSubnormal); - } - } - - if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) { - // Use the Beta method when all the alphas are less than 0.1 This - // threshold provides a reasonable compromise between using the faster - // Gamma method for as wide a range as possible while ensuring that - // the probability of generating nans is negligibly small. - let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?; - Ok(Dirichlet { - repr: DirichletRepr::FromBeta(dist), - }) - } else { - let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?; - Ok(Dirichlet { - repr: DirichletRepr::FromGamma(dist), - }) - } - } -} - -impl Distribution<[F; N]> for Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - match &self.repr { - DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), - DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use alloc::vec::Vec; - - #[test] - fn test_dirichlet() { - let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); - let mut rng = crate::test::rng(221); - let samples = d.sample(&mut rng); - let _: Vec = samples - .into_iter() - .map(|x| { - assert!(x > 0.0); - x - }) - .collect(); - } - - #[test] - #[should_panic] - fn test_dirichlet_invalid_length() { - Dirichlet::new([0.5]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_zero() { - Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_negative() { - Dirichlet::new([0.1, -1.5, 0.3]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_nan() { - Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_subnormal() { - Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_inf() { - Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap(); - } - - #[test] - fn dirichlet_distributions_can_be_compared() { - assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); - } - - /// Check that the means of the components of n samples from - /// the Dirichlet distribution agree with the expected means - /// with a relative tolerance of rtol. - /// - /// This is a crude statistical test, but it will catch egregious - /// mistakes. It will also also fail if any samples contain nan. - fn check_dirichlet_means(alpha: [f64; N], n: i32, rtol: f64, seed: u64) { - let d = Dirichlet::new(alpha).unwrap(); - let mut rng = crate::test::rng(seed); - let mut sums = [0.0; N]; - for _ in 0..n { - let samples = d.sample(&mut rng); - for i in 0..N { - sums[i] += samples[i]; - } - } - let sample_mean = sums.map(|x| x / n as f64); - let alpha_sum: f64 = alpha.iter().sum(); - let expected_mean = alpha.map(|x| x / alpha_sum); - for i in 0..N { - assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); - } - } - - #[test] - fn test_dirichlet_means() { - // Check the means of 20000 samples for several different alphas. - let n = 20000; - let rtol = 2e-2; - let seed = 1317624576693539401; - check_dirichlet_means([0.5, 0.25], n, rtol, seed); - check_dirichlet_means([123.0, 75.0], n, rtol, seed); - check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed); - check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed); - } - - #[test] - fn test_dirichlet_means_very_small_alpha() { - // With values of alpha that are all 0.001, check that the means of the - // components of 10000 samples are within 1% of the expected means. - // With the sampling method based on gamma variates, this test would - // fail, with about 10% of the samples containing nan. - let alpha = [0.001; 3]; - let n = 10000; - let rtol = 1e-2; - let seed = 1317624576693539401; - check_dirichlet_means(alpha, n, rtol, seed); - } - - #[test] - fn test_dirichlet_means_small_alpha() { - // With values of alpha that are all less than 0.1, check that the - // means of the components of 150000 samples are within 0.1% of the - // expected means. - let alpha = [0.05, 0.025, 0.075, 0.05]; - let n = 150000; - let rtol = 1e-3; - let seed = 1317624576693539401; - check_dirichlet_means(alpha, n, rtol, seed); - } -} diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs deleted file mode 100644 index e3d2a8d1cf6..00000000000 --- a/rand_distr/src/exponential.rs +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The exponential distribution. - -use crate::utils::ziggurat; -use num_traits::Float; -use crate::{ziggurat_tables, Distribution}; -use rand::Rng; -use core::fmt; - -/// Samples floating-point numbers according to the exponential distribution, -/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or -/// sampling with `-rng.gen::().ln()`, but faster. -/// -/// See `Exp` for the general exponential distribution. -/// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. The exact -/// description in the paper was adjusted to use tables for the exponential -/// distribution rather than normal. -/// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Exp1; -/// -/// let val: f64 = thread_rng().sample(Exp1); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Exp1; - -impl Distribution for Exp1 { - #[inline] - fn sample(&self, rng: &mut R) -> f32 { - // TODO: use optimal 32-bit implementation - let x: f64 = self.sample(rng); - x as f32 - } -} - -// This could be done via `-rng.gen::().ln()` but that is slower. -impl Distribution for Exp1 { - #[inline] - fn sample(&self, rng: &mut R) -> f64 { - #[inline] - fn pdf(x: f64) -> f64 { - (-x).exp() - } - #[inline] - fn zero_case(rng: &mut R, _u: f64) -> f64 { - ziggurat_tables::ZIG_EXP_R - rng.gen::().ln() - } - - ziggurat( - rng, - false, - &ziggurat_tables::ZIG_EXP_X, - &ziggurat_tables::ZIG_EXP_F, - pdf, - zero_case, - ) - } -} - -/// The exponential distribution `Exp(lambda)`. -/// -/// This distribution has density function: `f(x) = lambda * exp(-lambda * x)` -/// for `x > 0`, when `lambda > 0`. For `lambda = 0`, all samples yield infinity. -/// -/// Note that [`Exp1`](crate::Exp1) is an optimised implementation for `lambda = 1`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Exp, Distribution}; -/// -/// let exp = Exp::new(2.0).unwrap(); -/// let v = exp.sample(&mut rand::thread_rng()); -/// println!("{} is from a Exp(2) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Exp -where F: Float, Exp1: Distribution -{ - /// `lambda` stored as `1/lambda`, since this is what we scale by. - lambda_inverse: F, -} - -/// Error type returned from `Exp::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `lambda < 0` or `nan`. - LambdaTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::LambdaTooSmall => "lambda is negative or NaN in exponential distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Exp -where F: Float, Exp1: Distribution -{ - /// Construct a new `Exp` with the given shape parameter - /// `lambda`. - /// - /// # Remarks - /// - /// For custom types `N` implementing the [`Float`] trait, - /// the case `lambda = 0` is handled as follows: each sample corresponds - /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types - /// yield infinity, since `1 / 0 = infinity`. - #[inline] - pub fn new(lambda: F) -> Result, Error> { - if !(lambda >= F::zero()) { - return Err(Error::LambdaTooSmall); - } - Ok(Exp { - lambda_inverse: F::one() / lambda, - }) - } -} - -impl Distribution for Exp -where F: Float, Exp1: Distribution -{ - fn sample(&self, rng: &mut R) -> F { - rng.sample(Exp1) * self.lambda_inverse - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_exp() { - let exp = Exp::new(10.0).unwrap(); - let mut rng = crate::test::rng(221); - for _ in 0..1000 { - assert!(exp.sample(&mut rng) >= 0.0); - } - } - #[test] - fn test_zero() { - let d = Exp::new(0.0).unwrap(); - assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity()); - } - #[test] - #[should_panic] - fn test_exp_invalid_lambda_neg() { - Exp::new(-10.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_exp_invalid_lambda_nan() { - Exp::new(f64::nan()).unwrap(); - } - - #[test] - fn exponential_distributions_can_be_compared() { - assert_eq!(Exp::new(1.0), Exp::new(1.0)); - } -} diff --git a/rand_distr/src/frechet.rs b/rand_distr/src/frechet.rs deleted file mode 100644 index 63205b40cbd..00000000000 --- a/rand_distr/src/frechet.rs +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Fréchet distribution. - -use crate::{Distribution, OpenClosed01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// Samples floating-point numbers according to the Fréchet distribution -/// -/// This distribution has density function: -/// `f(x) = [(x - μ) / σ]^(-1 - α) exp[-(x - μ) / σ]^(-α) α / σ`, -/// where `μ` is the location parameter, `σ` the scale parameter, and `α` the shape parameter. -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Frechet; -/// -/// let val: f64 = thread_rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Frechet -where - F: Float, - OpenClosed01: Distribution, -{ - location: F, - scale: F, - shape: F, -} - -/// Error type returned from `Frechet::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// location is infinite or NaN - LocationNotFinite, - /// scale is not finite positive number - ScaleNotPositive, - /// shape is not finite positive number - ShapeNotPositive, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::LocationNotFinite => "location is not finite in Frechet distribution", - Error::ScaleNotPositive => "scale is not positive and finite in Frechet distribution", - Error::ShapeNotPositive => "shape is not positive and finite in Frechet distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Frechet -where - F: Float, - OpenClosed01: Distribution, -{ - /// Construct a new `Frechet` distribution with given `location`, `scale`, and `shape`. - pub fn new(location: F, scale: F, shape: F) -> Result, Error> { - if scale <= F::zero() || scale.is_infinite() || scale.is_nan() { - return Err(Error::ScaleNotPositive); - } - if shape <= F::zero() || shape.is_infinite() || shape.is_nan() { - return Err(Error::ShapeNotPositive); - } - if location.is_infinite() || location.is_nan() { - return Err(Error::LocationNotFinite); - } - Ok(Frechet { - location, - scale, - shape, - }) - } -} - -impl Distribution for Frechet -where - F: Float, - OpenClosed01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let x: F = rng.sample(OpenClosed01); - self.location + self.scale * (-x.ln()).powf(-self.shape.recip()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic] - fn test_zero_scale() { - Frechet::new(0.0, 0.0, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_scale() { - Frechet::new(0.0, core::f64::INFINITY, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_scale() { - Frechet::new(0.0, core::f64::NAN, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_zero_shape() { - Frechet::new(0.0, 1.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_shape() { - Frechet::new(0.0, 1.0, core::f64::INFINITY).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_shape() { - Frechet::new(0.0, 1.0, core::f64::NAN).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_location() { - Frechet::new(core::f64::INFINITY, 1.0, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_location() { - Frechet::new(core::f64::NAN, 1.0, 1.0).unwrap(); - } - - #[test] - fn test_sample_against_cdf() { - fn quantile_function(x: f64) -> f64 { - (-x.ln()).recip() - } - let location = 0.0; - let scale = 1.0; - let shape = 1.0; - let iterations = 100_000; - let increment = 1.0 / iterations as f64; - let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; - let mut quantiles = [0.0; 9]; - for (i, p) in probabilities.iter().enumerate() { - quantiles[i] = quantile_function(*p); - } - let mut proportions = [0.0; 9]; - let d = Frechet::new(location, scale, shape).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..iterations { - let replicate = d.sample(&mut rng); - for (i, q) in quantiles.iter().enumerate() { - if replicate < *q { - proportions[i] += increment; - } - } - } - assert!(proportions - .iter() - .zip(&probabilities) - .all(|(p_hat, p)| (p_hat - p).abs() < 0.003)) - } - - #[test] - fn frechet_distributions_can_be_compared() { - assert_eq!(Frechet::new(1.0, 2.0, 3.0), Frechet::new(1.0, 2.0, 3.0)); - } -} diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs deleted file mode 100644 index 1a575bd6a9f..00000000000 --- a/rand_distr/src/gamma.rs +++ /dev/null @@ -1,838 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Gamma and derived distributions. - -// We use the variable names from the published reference, therefore this -// warning is not helpful. -#![allow(clippy::many_single_char_names)] - -use self::ChiSquaredRepr::*; -use self::GammaRepr::*; - -use crate::normal::StandardNormal; -use num_traits::Float; -use crate::{Distribution, Exp, Exp1, Open01}; -use rand::Rng; -use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// The Gamma distribution `Gamma(shape, scale)` distribution. -/// -/// The density function of this distribution is -/// -/// ```text -/// f(x) = x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k) -/// ``` -/// -/// where `Γ` is the Gamma function, `k` is the shape and `θ` is the -/// scale and both `k` and `θ` are strictly positive. -/// -/// The algorithm used is that described by Marsaglia & Tsang 2000[^1], -/// falling back to directly sampling from an Exponential for `shape -/// == 1`, and using the boosting technique described in that paper for -/// `shape < 1`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Gamma}; -/// -/// let gamma = Gamma::new(2.0, 5.0).unwrap(); -/// let v = gamma.sample(&mut rand::thread_rng()); -/// println!("{} is from a Gamma(2, 5) distribution", v); -/// ``` -/// -/// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for -/// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3 -/// (September 2000), 363-372. -/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Gamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: GammaRepr, -} - -/// Error type returned from `Gamma::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `shape <= 0` or `nan`. - ShapeTooSmall, - /// `scale <= 0` or `nan`. - ScaleTooSmall, - /// `1 / scale == 0`. - ScaleTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ShapeTooSmall => "shape is not positive in gamma distribution", - Error::ScaleTooSmall => "scale is not positive in gamma distribution", - Error::ScaleTooLarge => "scale is infinity in gamma distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum GammaRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - Large(GammaLargeShape), - One(Exp), - Small(GammaSmallShape), -} - -// These two helpers could be made public, but saving the -// match-on-Gamma-enum branch from using them directly (e.g. if one -// knows that the shape is always > 1) doesn't appear to be much -// faster. - -/// Gamma distribution where the shape parameter is less than 1. -/// -/// Note, samples from this require a compulsory floating-point `pow` -/// call, which makes it significantly slower than sampling from a -/// gamma distribution where the shape parameter is greater than or -/// equal to 1. -/// -/// See `Gamma` for sampling from a Gamma distribution with general -/// shape parameters. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct GammaSmallShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - inv_shape: F, - large_shape: GammaLargeShape, -} - -/// Gamma distribution where the shape parameter is larger than 1. -/// -/// See `Gamma` for sampling from a Gamma distribution with general -/// shape parameters. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct GammaLargeShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - scale: F, - c: F, - d: F, -} - -impl Gamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct an object representing the `Gamma(shape, scale)` - /// distribution. - #[inline] - pub fn new(shape: F, scale: F) -> Result, Error> { - if !(shape > F::zero()) { - return Err(Error::ShapeTooSmall); - } - if !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - - let repr = if shape == F::one() { - One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?) - } else if shape < F::one() { - Small(GammaSmallShape::new_raw(shape, scale)) - } else { - Large(GammaLargeShape::new_raw(shape, scale)) - }; - Ok(Gamma { repr }) - } -} - -impl GammaSmallShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn new_raw(shape: F, scale: F) -> GammaSmallShape { - GammaSmallShape { - inv_shape: F::one() / shape, - large_shape: GammaLargeShape::new_raw(shape + F::one(), scale), - } - } -} - -impl GammaLargeShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn new_raw(shape: F, scale: F) -> GammaLargeShape { - let d = shape - F::from(1. / 3.).unwrap(); - GammaLargeShape { - scale, - c: F::one() / (F::from(9.).unwrap() * d).sqrt(), - d, - } - } -} - -impl Distribution for Gamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - Small(ref g) => g.sample(rng), - One(ref g) => g.sample(rng), - Large(ref g) => g.sample(rng), - } - } -} -impl Distribution for GammaSmallShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let u: F = rng.sample(Open01); - - self.large_shape.sample(rng) * u.powf(self.inv_shape) - } -} -impl Distribution for GammaLargeShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - // Marsaglia & Tsang method, 2000 - loop { - let x: F = rng.sample(StandardNormal); - let v_cbrt = F::one() + self.c * x; - if v_cbrt <= F::zero() { - // a^3 <= 0 iff a <= 0 - continue; - } - - let v = v_cbrt * v_cbrt * v_cbrt; - let u: F = rng.sample(Open01); - - let x_sqr = x * x; - if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr - || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln()) - { - return self.d * v * self.scale; - } - } - } -} - -/// The chi-squared distribution `χ²(k)`, where `k` is the degrees of -/// freedom. -/// -/// For `k > 0` integral, this distribution is the sum of the squares -/// of `k` independent standard normal random variables. For other -/// `k`, this uses the equivalent characterisation -/// `χ²(k) = Gamma(k/2, 2)`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{ChiSquared, Distribution}; -/// -/// let chi = ChiSquared::new(11.0).unwrap(); -/// let v = chi.sample(&mut rand::thread_rng()); -/// println!("{} is from a χ²(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: ChiSquaredRepr, -} - -/// Error type returned from `ChiSquared::new` and `StudentT::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum ChiSquaredError { - /// `0.5 * k <= 0` or `nan`. - DoFTooSmall, -} - -impl fmt::Display for ChiSquaredError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ChiSquaredError::DoFTooSmall => { - "degrees-of-freedom k is not positive in chi-squared distribution" - } - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for ChiSquaredError {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum ChiSquaredRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, - // e.g. when alpha = 1/2 as it would be for this case, so special- - // casing and using the definition of N(0,1)^2 is faster. - DoFExactlyOne, - DoFAnythingElse(Gamma), -} - -impl ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new chi-squared distribution with degrees-of-freedom - /// `k`. - pub fn new(k: F) -> Result, ChiSquaredError> { - let repr = if k == F::one() { - DoFExactlyOne - } else { - if !(F::from(0.5).unwrap() * k > F::zero()) { - return Err(ChiSquaredError::DoFTooSmall); - } - DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) - }; - Ok(ChiSquared { repr }) - } -} -impl Distribution for ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - DoFExactlyOne => { - // k == 1 => N(0,1)^2 - let norm: F = rng.sample(StandardNormal); - norm * norm - } - DoFAnythingElse(ref g) => g.sample(rng), - } - } -} - -/// The Fisher F distribution `F(m, n)`. -/// -/// This distribution is equivalent to the ratio of two normalised -/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / -/// (χ²(n)/n)`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{FisherF, Distribution}; -/// -/// let f = FisherF::new(2.0, 32.0).unwrap(); -/// let v = f.sample(&mut rand::thread_rng()); -/// println!("{} is from an F(2, 32) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - numer: ChiSquared, - denom: ChiSquared, - // denom_dof / numer_dof so that this can just be a straight - // multiplication, rather than a division. - dof_ratio: F, -} - -/// Error type returned from `FisherF::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum FisherFError { - /// `m <= 0` or `nan`. - MTooSmall, - /// `n <= 0` or `nan`. - NTooSmall, -} - -impl fmt::Display for FisherFError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - FisherFError::MTooSmall => "m is not positive in Fisher F distribution", - FisherFError::NTooSmall => "n is not positive in Fisher F distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for FisherFError {} - -impl FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: F, n: F) -> Result, FisherFError> { - let zero = F::zero(); - if !(m > zero) { - return Err(FisherFError::MTooSmall); - } - if !(n > zero) { - return Err(FisherFError::NTooSmall); - } - - Ok(FisherF { - numer: ChiSquared::new(m).unwrap(), - denom: ChiSquared::new(n).unwrap(), - dof_ratio: n / m, - }) - } -} -impl Distribution for FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio - } -} - -/// The Student t distribution, `t(nu)`, where `nu` is the degrees of -/// freedom. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{StudentT, Distribution}; -/// -/// let t = StudentT::new(11.0).unwrap(); -/// let v = t.sample(&mut rand::thread_rng()); -/// println!("{} is from a t(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - chi: ChiSquared, - dof: F, -} - -impl StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new Student t distribution with `n` degrees of - /// freedom. - pub fn new(n: F) -> Result, ChiSquaredError> { - Ok(StudentT { - chi: ChiSquared::new(n)?, - dof: n, - }) - } -} -impl Distribution for StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let norm: F = rng.sample(StandardNormal); - norm * (self.dof / self.chi.sample(rng)).sqrt() - } -} - -/// The algorithm used for sampling the Beta distribution. -/// -/// Reference: -/// -/// R. C. H. Cheng (1978). -/// Generating beta variates with nonintegral shape parameters. -/// Communications of the ACM 21, 317-322. -/// https://doi.org/10.1145/359460.359482 -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum BetaAlgorithm { - BB(BB), - BC(BC), -} - -/// Algorithm BB for `min(alpha, beta) > 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct BB { - alpha: N, - beta: N, - gamma: N, -} - -/// Algorithm BC for `min(alpha, beta) <= 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct BC { - alpha: N, - beta: N, - kappa1: N, - kappa2: N, -} - -/// The Beta distribution with shape parameters `alpha` and `beta`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Beta}; -/// -/// let beta = Beta::new(2.0, 5.0).unwrap(); -/// let v = beta.sample(&mut rand::thread_rng()); -/// println!("{} is from a Beta(2, 5) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Beta -where - F: Float, - Open01: Distribution, -{ - a: F, b: F, switched_params: bool, - algorithm: BetaAlgorithm, -} - -/// Error type returned from `Beta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum BetaError { - /// `alpha <= 0` or `nan`. - AlphaTooSmall, - /// `beta <= 0` or `nan`. - BetaTooSmall, -} - -impl fmt::Display for BetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - BetaError::AlphaTooSmall => "alpha is not positive in beta distribution", - BetaError::BetaTooSmall => "beta is not positive in beta distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for BetaError {} - -impl Beta -where - F: Float, - Open01: Distribution, -{ - /// Construct an object representing the `Beta(alpha, beta)` - /// distribution. - pub fn new(alpha: F, beta: F) -> Result, BetaError> { - if !(alpha > F::zero()) { - return Err(BetaError::AlphaTooSmall); - } - if !(beta > F::zero()) { - return Err(BetaError::BetaTooSmall); - } - // From now on, we use the notation from the reference, - // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. - let (a0, b0) = (alpha, beta); - let (a, b, switched_params) = if a0 < b0 { - (a0, b0, false) - } else { - (b0, a0, true) - }; - if a > F::one() { - // Algorithm BB - let alpha = a + b; - let beta = ((alpha - F::from(2.).unwrap()) - / (F::from(2.).unwrap()*a*b - alpha)).sqrt(); - let gamma = a + F::one() / beta; - - Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BB(BB { - alpha, beta, gamma, - }) - }) - } else { - // Algorithm BC - // - // Here `a` is the maximum instead of the minimum. - let (a, b, switched_params) = (b, a, !switched_params); - let alpha = a + b; - let beta = F::one() / b; - let delta = F::one() + a - b; - let kappa1 = delta - * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b) - / (a*beta - F::from(14. / 18.).unwrap()); - let kappa2 = F::from(0.25).unwrap() - + (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b; - - Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BC(BC { - alpha, beta, kappa1, kappa2, - }) - }) - } - } -} - -impl Distribution for Beta -where - F: Float, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let mut w; - match self.algorithm { - BetaAlgorithm::BB(algo) => { - loop { - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - let z = u1*u1 * u2; - let r = algo.gamma * v - F::from(4.).unwrap().ln(); - let s = self.a + r - w; - // 2. - if s + F::one() + F::from(5.).unwrap().ln() - >= F::from(5.).unwrap() * z { - break; - } - // 3. - let t = z.ln(); - if s >= t { - break; - } - // 4. - if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { - break; - } - } - }, - BetaAlgorithm::BC(algo) => { - loop { - let z; - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - if u1 < F::from(0.5).unwrap() { - // 2. - let y = u1 * u2; - z = u1 * y; - if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { - continue; - } - } else { - // 3. - z = u1 * u1 * u2; - if z <= F::from(0.25).unwrap() { - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - break; - } - // 4. - if z >= algo.kappa2 { - continue; - } - } - // 5. - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) - - F::from(4.).unwrap().ln() < z.ln()) { - break; - }; - } - }, - }; - // 5. for BB, 6. for BC - if !self.switched_params { - if w == F::infinity() { - // Assuming `b` is finite, for large `w`: - return F::one(); - } - w / (self.b + w) - } else { - self.b / (self.b + w) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_chi_squared_one() { - let chi = ChiSquared::new(1.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_small() { - let chi = ChiSquared::new(0.5).unwrap(); - let mut rng = crate::test::rng(202); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_large() { - let chi = ChiSquared::new(30.0).unwrap(); - let mut rng = crate::test::rng(203); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - #[should_panic] - fn test_chi_squared_invalid_dof() { - ChiSquared::new(-1.0).unwrap(); - } - - #[test] - fn test_f() { - let f = FisherF::new(2.0, 32.0).unwrap(); - let mut rng = crate::test::rng(204); - for _ in 0..1000 { - f.sample(&mut rng); - } - } - - #[test] - fn test_t() { - let t = StudentT::new(11.0).unwrap(); - let mut rng = crate::test::rng(205); - for _ in 0..1000 { - t.sample(&mut rng); - } - } - - #[test] - fn test_beta() { - let beta = Beta::new(1.0, 2.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - beta.sample(&mut rng); - } - } - - #[test] - #[should_panic] - fn test_beta_invalid_dof() { - Beta::new(0., 0.).unwrap(); - } - - #[test] - fn test_beta_small_param() { - let beta = Beta::::new(1e-3, 1e-3).unwrap(); - let mut rng = crate::test::rng(206); - for i in 0..1000 { - assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); - } - } - - #[test] - fn gamma_distributions_can_be_compared() { - assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); - } - - #[test] - fn beta_distributions_can_be_compared() { - assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); - } - - #[test] - fn chi_squared_distributions_can_be_compared() { - assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0)); - } - - #[test] - fn fisher_f_distributions_can_be_compared() { - assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); - } - - #[test] - fn student_t_distributions_can_be_compared() { - assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); - } -} diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs deleted file mode 100644 index 6ee64a77d98..00000000000 --- a/rand_distr/src/geometric.rs +++ /dev/null @@ -1,244 +0,0 @@ -//! The geometric distribution. - -use crate::Distribution; -use rand::Rng; -use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; - -/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`. -/// -/// This is the probability distribution of the number of failures before the -/// first success in a series of Bernoulli trials. It has the density function -/// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success -/// on each trial. -/// -/// This is the discrete analogue of the [exponential distribution](crate::Exp). -/// -/// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised -/// implementation for `p = 0.5`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Geometric, Distribution}; -/// -/// let geo = Geometric::new(0.25).unwrap(); -/// let v = geo.sample(&mut rand::thread_rng()); -/// println!("{} is from a Geometric(0.25) distribution", v); -/// ``` -#[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Geometric -{ - p: f64, - pi: f64, - k: u64 -} - -/// Error type returned from `Geometric::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `p < 0 || p > 1` or `nan` - InvalidProbability, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Geometric { - /// Construct a new `Geometric` with the given shape parameter `p` - /// (probability of success on each trial). - pub fn new(p: f64) -> Result { - if !p.is_finite() || !(0.0..=1.0).contains(&p) { - Err(Error::InvalidProbability) - } else if p == 0.0 || p >= 2.0 / 3.0 { - Ok(Geometric { p, pi: p, k: 0 }) - } else { - let (pi, k) = { - // choose smallest k such that pi = (1 - p)^(2^k) <= 0.5 - let mut k = 1; - let mut pi = (1.0 - p).powi(2); - while pi > 0.5 { - k += 1; - pi = pi * pi; - } - (pi, k) - }; - - Ok(Geometric { p, pi, k }) - } - } -} - -impl Distribution for Geometric -{ - fn sample(&self, rng: &mut R) -> u64 { - if self.p >= 2.0 / 3.0 { - // use the trivial algorithm: - let mut failures = 0; - loop { - let u = rng.gen::(); - if u <= self.p { break; } - failures += 1; - } - return failures; - } - - if self.p == 0.0 { return core::u64::MAX; } - - let Geometric { p, pi, k } = *self; - - // Based on the algorithm presented in section 3 of - // Karl Bringmann and Tobias Friedrich (July 2013) - Exact and Efficient - // Generation of Geometric Random Variates and Random Graphs, published - // in International Colloquium on Automata, Languages and Programming - // (pp.267-278) - // https://people.mpi-inf.mpg.de/~kbringma/paper/2013ICALP-1.pdf - - // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k: - let d = { - let mut failures = 0; - while rng.gen::() < pi { - failures += 1; - } - failures - }; - - // Use rejection sampling for the remainder M from Geo(p) % 2^k: - // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M - // NOTE: The paper suggests using bitwise sampling here, which is - // currently unsupported, but should improve performance by requiring - // fewer iterations on average. ~ October 28, 2020 - let m = loop { - let m = rng.gen::() & ((1 << k) - 1); - let p_reject = if m <= core::i32::MAX as u64 { - (1.0 - p).powi(m as i32) - } else { - (1.0 - p).powf(m as f64) - }; - - let u = rng.gen::(); - if u < p_reject { - break m; - } - }; - - (d << k) + m - } -} - -/// Samples integers according to the geometric distribution with success -/// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`, -/// but faster. -/// -/// See [`Geometric`](crate::Geometric) for the general geometric distribution. -/// -/// Implemented via iterated -/// [`Rng::gen::().leading_zeros()`](Rng::gen::().leading_zeros()). -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::StandardGeometric; -/// -/// let v = StandardGeometric.sample(&mut thread_rng()); -/// println!("{} is from a Geometric(0.5) distribution", v); -/// ``` -#[derive(Copy, Clone, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct StandardGeometric; - -impl Distribution for StandardGeometric { - fn sample(&self, rng: &mut R) -> u64 { - let mut result = 0; - loop { - let x = rng.gen::().leading_zeros() as u64; - result += x; - if x < 64 { break; } - } - result - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_geo_invalid_p() { - assert!(Geometric::new(core::f64::NAN).is_err()); - assert!(Geometric::new(core::f64::INFINITY).is_err()); - assert!(Geometric::new(core::f64::NEG_INFINITY).is_err()); - - assert!(Geometric::new(-0.5).is_err()); - assert!(Geometric::new(0.0).is_ok()); - assert!(Geometric::new(1.0).is_ok()); - assert!(Geometric::new(2.0).is_err()); - } - - fn test_geo_mean_and_variance(p: f64, rng: &mut R) { - let distr = Geometric::new(p).unwrap(); - - let expected_mean = (1.0 - p) / p; - let expected_variance = (1.0 - p) / (p * p); - - let mut results = [0.0; 10000]; - for i in results.iter_mut() { - *i = distr.sample(rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 40.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn test_geometric() { - let mut rng = crate::test::rng(12345); - - test_geo_mean_and_variance(0.10, &mut rng); - test_geo_mean_and_variance(0.25, &mut rng); - test_geo_mean_and_variance(0.50, &mut rng); - test_geo_mean_and_variance(0.75, &mut rng); - test_geo_mean_and_variance(0.90, &mut rng); - } - - #[test] - fn test_standard_geometric() { - let mut rng = crate::test::rng(654321); - - let distr = StandardGeometric; - let expected_mean = 1.0; - let expected_variance = 2.0; - - let mut results = [0.0; 1000]; - for i in results.iter_mut() { - *i = distr.sample(&mut rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 50.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn geometric_distributions_can_be_compared() { - assert_eq!(Geometric::new(1.0), Geometric::new(1.0)); - } -} diff --git a/rand_distr/src/gumbel.rs b/rand_distr/src/gumbel.rs deleted file mode 100644 index b254919f3b8..00000000000 --- a/rand_distr/src/gumbel.rs +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Gumbel distribution. - -use crate::{Distribution, OpenClosed01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// Samples floating-point numbers according to the Gumbel distribution -/// -/// This distribution has density function: -/// `f(x) = exp(-(z + exp(-z))) / σ`, where `z = (x - μ) / σ`, -/// `μ` is the location parameter, and `σ` the scale parameter. -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Gumbel; -/// -/// let val: f64 = thread_rng().sample(Gumbel::new(0.0, 1.0).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Gumbel -where - F: Float, - OpenClosed01: Distribution, -{ - location: F, - scale: F, -} - -/// Error type returned from `Gumbel::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// location is infinite or NaN - LocationNotFinite, - /// scale is not finite positive number - ScaleNotPositive, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleNotPositive => "scale is not positive and finite in Gumbel distribution", - Error::LocationNotFinite => "location is not finite in Gumbel distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Gumbel -where - F: Float, - OpenClosed01: Distribution, -{ - /// Construct a new `Gumbel` distribution with given `location` and `scale`. - pub fn new(location: F, scale: F) -> Result, Error> { - if scale <= F::zero() || scale.is_infinite() || scale.is_nan() { - return Err(Error::ScaleNotPositive); - } - if location.is_infinite() || location.is_nan() { - return Err(Error::LocationNotFinite); - } - Ok(Gumbel { location, scale }) - } -} - -impl Distribution for Gumbel -where - F: Float, - OpenClosed01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let x: F = rng.sample(OpenClosed01); - self.location - self.scale * (-x.ln()).ln() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic] - fn test_zero_scale() { - Gumbel::new(0.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_scale() { - Gumbel::new(0.0, core::f64::INFINITY).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_scale() { - Gumbel::new(0.0, core::f64::NAN).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_location() { - Gumbel::new(core::f64::INFINITY, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_location() { - Gumbel::new(core::f64::NAN, 1.0).unwrap(); - } - - #[test] - fn test_sample_against_cdf() { - fn neg_log_log(x: f64) -> f64 { - -(-x.ln()).ln() - } - let location = 0.0; - let scale = 1.0; - let iterations = 100_000; - let increment = 1.0 / iterations as f64; - let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; - let mut quantiles = [0.0; 9]; - for (i, p) in probabilities.iter().enumerate() { - quantiles[i] = neg_log_log(*p); - } - let mut proportions = [0.0; 9]; - let d = Gumbel::new(location, scale).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..iterations { - let replicate = d.sample(&mut rng); - for (i, q) in quantiles.iter().enumerate() { - if replicate < *q { - proportions[i] += increment; - } - } - } - assert!(proportions - .iter() - .zip(&probabilities) - .all(|(p_hat, p)| (p_hat - p).abs() < 0.003)) - } - - #[test] - fn gumbel_distributions_can_be_compared() { - assert_eq!(Gumbel::new(1.0, 2.0), Gumbel::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs deleted file mode 100644 index 73a8e91c75e..00000000000 --- a/rand_distr/src/hypergeometric.rs +++ /dev/null @@ -1,427 +0,0 @@ -//! The hypergeometric distribution. - -use crate::Distribution; -use rand::Rng; -use rand::distributions::uniform::Uniform; -use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -enum SamplingMethod { - InverseTransform{ initial_p: f64, initial_x: i64 }, - RejectionAcceptance{ - m: f64, - a: f64, - lambda_l: f64, - lambda_r: f64, - x_l: f64, - x_r: f64, - p1: f64, - p2: f64, - p3: f64 - }, -} - -/// The hypergeometric distribution `Hypergeometric(N, K, n)`. -/// -/// This is the distribution of successes in samples of size `n` drawn without -/// replacement from a population of size `N` containing `K` success states. -/// It has the density function: -/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, -/// where `binomial(a, b) = a! / (b! * (a - b)!)`. -/// -/// The [binomial distribution](crate::Binomial) is the analogous distribution -/// for sampling with replacement. It is a good approximation when the population -/// size is much larger than the sample size. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Hypergeometric}; -/// -/// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap(); -/// let v = hypergeo.sample(&mut rand::thread_rng()); -/// println!("{} is from a hypergeometric distribution", v); -/// ``` -#[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Hypergeometric { - n1: u64, - n2: u64, - k: u64, - offset_x: i64, - sign_x: i64, - sampling_method: SamplingMethod, -} - -/// Error type returned from `Hypergeometric::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `total_population_size` is too large, causing floating point underflow. - PopulationTooLarge, - /// `population_with_feature > total_population_size`. - ProbabilityTooLarge, - /// `sample_size > total_population_size`. - SampleSizeTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution", - Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution", - Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -// evaluate fact(numerator.0)*fact(numerator.1) / fact(denominator.0)*fact(denominator.1) -fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 { - let min_top = u64::min(numerator.0, numerator.1); - let min_bottom = u64::min(denominator.0, denominator.1); - // the factorial of this will cancel out: - let min_all = u64::min(min_top, min_bottom); - - let max_top = u64::max(numerator.0, numerator.1); - let max_bottom = u64::max(denominator.0, denominator.1); - let max_all = u64::max(max_top, max_bottom); - - let mut result = 1.0; - for i in (min_all + 1)..=max_all { - if i <= min_top { - result *= i as f64; - } - - if i <= min_bottom { - result /= i as f64; - } - - if i <= max_top { - result *= i as f64; - } - - if i <= max_bottom { - result /= i as f64; - } - } - - result -} - -fn ln_of_factorial(v: f64) -> f64 { - // the paper calls for ln(v!), but also wants to pass in fractions, - // so we need to use Stirling's approximation to fill in the gaps: - v * v.ln() - v -} - -impl Hypergeometric { - /// Constructs a new `Hypergeometric` with the shape parameters - /// `N = total_population_size`, - /// `K = population_with_feature`, - /// `n = sample_size`. - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result { - if population_with_feature > total_population_size { - return Err(Error::ProbabilityTooLarge); - } - - if sample_size > total_population_size { - return Err(Error::SampleSizeTooLarge); - } - - // set-up constants as function of original parameters - let n = total_population_size; - let (mut sign_x, mut offset_x) = (1, 0); - let (n1, n2) = { - // switch around success and failure states if necessary to ensure n1 <= n2 - let population_without_feature = n - population_with_feature; - if population_with_feature > population_without_feature { - sign_x = -1; - offset_x = sample_size as i64; - (population_without_feature, population_with_feature) - } else { - (population_with_feature, population_without_feature) - } - }; - // when sampling more than half the total population, take the smaller - // group as sampled instead (we can then return n1-x instead). - // - // Note: the boundary condition given in the paper is `sample_size < n / 2`; - // we're deviating here, because when n is even, it doesn't matter whether - // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching - // when `k == n/2`, we'd actually be taking the _larger_ group as sampled. - let k = if sample_size <= n / 2 { - sample_size - } else { - offset_x += n1 as i64 * sign_x; - sign_x *= -1; - n - sample_size - }; - - // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`, - // where `M` is the mode of the distribution. - // Use algorithm HIN for the remaining parameter space. - // - // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer - // generation of hypergeometric random variates. - // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145 - // https://www.researchgate.net/publication/233212638 - const HIN_THRESHOLD: f64 = 10.0; - let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor(); - let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD { - let (initial_p, initial_x) = if k < n2 { - (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0) - } else { - (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64) - }; - - if initial_p <= 0.0 || !initial_p.is_finite() { - return Err(Error::PopulationTooLarge); - } - - SamplingMethod::InverseTransform { initial_p, initial_x } - } else { - let a = ln_of_factorial(m) + - ln_of_factorial(n1 as f64 - m) + - ln_of_factorial(k as f64 - m) + - ln_of_factorial((n2 - k) as f64 + m); - - let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64; - let denominator = (n - 1) as f64 * n as f64 * n as f64; - let d = 1.5 * (numerator / denominator).sqrt() + 0.5; - - let x_l = m - d + 0.5; - let x_r = m + d + 0.5; - - let k_l = f64::exp(a - - ln_of_factorial(x_l) - - ln_of_factorial(n1 as f64 - x_l) - - ln_of_factorial(k as f64 - x_l) - - ln_of_factorial((n2 - k) as f64 + x_l)); - let k_r = f64::exp(a - - ln_of_factorial(x_r - 1.0) - - ln_of_factorial(n1 as f64 - x_r + 1.0) - - ln_of_factorial(k as f64 - x_r + 1.0) - - ln_of_factorial((n2 - k) as f64 + x_r - 1.0)); - - let numerator = x_l * ((n2 - k) as f64 + x_l); - let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0); - let lambda_l = -((numerator / denominator).ln()); - - let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0); - let denominator = x_r * ((n2 - k) as f64 + x_r); - let lambda_r = -((numerator / denominator).ln()); - - // the paper literally gives `p2 + kL/lambdaL` where it (probably) - // should have been `p2 <- p1 + kL/lambdaL`; another print error?! - let p1 = 2.0 * d; - let p2 = p1 + k_l / lambda_l; - let p3 = p2 + k_r / lambda_r; - - SamplingMethod::RejectionAcceptance { - m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 - } - }; - - Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method }) - } -} - -impl Distribution for Hypergeometric { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - use SamplingMethod::*; - - let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self; - let x = match sampling_method { - InverseTransform { initial_p: mut p, initial_x: mut x } => { - let mut u = rng.gen::(); - while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense - u -= p; - p *= ((n1 as i64 - x) * (k as i64 - x)) as f64; - p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64; - x += 1; - } - x - }, - RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => { - let distr_region_select = Uniform::new(0.0, p3).unwrap(); - loop { - let (y, v) = loop { - let u = distr_region_select.sample(rng); - let v = rng.gen::(); // for the accept/reject decision - - if u <= p1 { - // Region 1, central bell - let y = (x_l + u).floor(); - break (y, v); - } else if u <= p2 { - // Region 2, left exponential tail - let y = (x_l + v.ln() / lambda_l).floor(); - if y as i64 >= i64::max(0, k as i64 - n2 as i64) { - let v = v * (u - p1) * lambda_l; - break (y, v); - } - } else { - // Region 3, right exponential tail - let y = (x_r - v.ln() / lambda_r).floor(); - if y as u64 <= u64::min(n1, k) { - let v = v * (u - p2) * lambda_r; - break (y, v); - } - } - }; - - // Step 4: Acceptance/Rejection Comparison - if m < 100.0 || y <= 50.0 { - // Step 4.1: evaluate f(y) via recursive relationship - let mut f = 1.0; - if m < y { - for i in (m as u64 + 1)..=(y as u64) { - f *= (n1 - i + 1) as f64 * (k - i + 1) as f64; - f /= i as f64 * (n2 - k + i) as f64; - } - } else { - for i in (y as u64 + 1)..=(m as u64) { - f *= i as f64 * (n2 - k + i) as f64; - f /= (n1 - i) as f64 * (k - i) as f64; - } - } - - if v <= f { break y as i64; } - } else { - // Step 4.2: Squeezing - let y1 = y + 1.0; - let ym = y - m; - let yn = n1 as f64 - y + 1.0; - let yk = k as f64 - y + 1.0; - let nk = n2 as f64 - k as f64 + y1; - let r = -ym / y1; - let s = ym / yn; - let t = ym / yk; - let e = -ym / nk; - let g = yn * yk / (y1 * nk) - 1.0; - let dg = if g < 0.0 { - 1.0 + g - } else { - 1.0 - }; - let gu = g * (1.0 + g * (-0.5 + g / 3.0)); - let gl = gu - g.powi(4) / (4.0 * dg); - let xm = m + 0.5; - let xn = n1 as f64 - m + 0.5; - let xk = k as f64 - m + 0.5; - let nm = n2 as f64 - k as f64 + xm; - let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + - xn * s * (1.0 + s * (-0.5 + s / 3.0)) + - xk * t * (1.0 + t * (-0.5 + t / 3.0)) + - nm * e * (1.0 + e * (-0.5 + e / 3.0)) + - y * gu - m * gl + 0.0034; - let av = v.ln(); - if av > ub { continue; } - let dr = if r < 0.0 { - xm * r.powi(4) / (1.0 + r) - } else { - xm * r.powi(4) - }; - let ds = if s < 0.0 { - xn * s.powi(4) / (1.0 + s) - } else { - xn * s.powi(4) - }; - let dt = if t < 0.0 { - xk * t.powi(4) / (1.0 + t) - } else { - xk * t.powi(4) - }; - let de = if e < 0.0 { - nm * e.powi(4) / (1.0 + e) - } else { - nm * e.powi(4) - }; - - if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 { - break y as i64; - } - - // Step 4.3: Final Acceptance/Rejection Test - let av_critical = a - - ln_of_factorial(y) - - ln_of_factorial(n1 as f64 - y) - - ln_of_factorial(k as f64 - y) - - ln_of_factorial((n2 - k) as f64 + y); - if v.ln() <= av_critical { - break y as i64; - } - } - } - } - }; - - (offset_x + sign_x * x) as u64 - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_hypergeometric_invalid_params() { - assert!(Hypergeometric::new(100, 101, 5).is_err()); - assert!(Hypergeometric::new(100, 10, 101).is_err()); - assert!(Hypergeometric::new(100, 101, 101).is_err()); - assert!(Hypergeometric::new(100, 10, 5).is_ok()); - } - - fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) - { - let distr = Hypergeometric::new(n, k, s).unwrap(); - - let expected_mean = s as f64 * k as f64 / n as f64; - let expected_variance = { - let numerator = (s * k * (n - k) * (n - s)) as f64; - let denominator = (n * n * (n - 1)) as f64; - numerator / denominator - }; - - let mut results = [0.0; 1000]; - for i in results.iter_mut() { - *i = distr.sample(rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 50.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn test_hypergeometric() { - let mut rng = crate::test::rng(737); - - // exercise algorithm HIN: - test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng); - test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng); - test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng); - test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng); - - // exercise algorithm H2PE - test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng); - test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng); - test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng); - } - - #[test] - fn hypergeometric_distributions_can_be_compared() { - assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3)); - } -} diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs deleted file mode 100644 index ba845fd1505..00000000000 --- a/rand_distr/src/inverse_gaussian.rs +++ /dev/null @@ -1,117 +0,0 @@ -use crate::{Distribution, Standard, StandardNormal}; -use num_traits::Float; -use rand::Rng; -use core::fmt; - -/// Error type returned from `InverseGaussian::new` -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Error { - /// `mean <= 0` or `nan`. - MeanNegativeOrNull, - /// `shape <= 0` or `nan`. - ShapeNegativeOrNull, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::MeanNegativeOrNull => "mean <= 0 or is NaN in inverse Gaussian distribution", - Error::ShapeNegativeOrNull => "shape <= 0 or is NaN in inverse Gaussian distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) -#[derive(Debug, Clone, Copy, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct InverseGaussian -where - F: Float, - StandardNormal: Distribution, - Standard: Distribution, -{ - mean: F, - shape: F, -} - -impl InverseGaussian -where - F: Float, - StandardNormal: Distribution, - Standard: Distribution, -{ - /// Construct a new `InverseGaussian` distribution with the given mean and - /// shape. - pub fn new(mean: F, shape: F) -> Result, Error> { - let zero = F::zero(); - if !(mean > zero) { - return Err(Error::MeanNegativeOrNull); - } - - if !(shape > zero) { - return Err(Error::ShapeNegativeOrNull); - } - - Ok(Self { mean, shape }) - } -} - -impl Distribution for InverseGaussian -where - F: Float, - StandardNormal: Distribution, - Standard: Distribution, -{ - #[allow(clippy::many_single_char_names)] - fn sample(&self, rng: &mut R) -> F - where R: Rng + ?Sized { - let mu = self.mean; - let l = self.shape; - - let v: F = rng.sample(StandardNormal); - let y = mu * v * v; - - let mu_2l = mu / (F::from(2.).unwrap() * l); - - let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt()); - - let u: F = rng.gen(); - - if u <= mu / (mu + x) { - return x; - } - - mu * mu / x - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_inverse_gaussian() { - let inv_gauss = InverseGaussian::new(1.0, 1.0).unwrap(); - let mut rng = crate::test::rng(210); - for _ in 0..1000 { - inv_gauss.sample(&mut rng); - } - } - - #[test] - fn test_inverse_gaussian_invalid_param() { - assert!(InverseGaussian::new(-1.0, 1.0).is_err()); - assert!(InverseGaussian::new(-1.0, -1.0).is_err()); - assert!(InverseGaussian::new(1.0, -1.0).is_err()); - assert!(InverseGaussian::new(1.0, 1.0).is_ok()); - } - - #[test] - fn inverse_gaussian_distributions_can_be_compared() { - assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs deleted file mode 100644 index dc155bb5d5d..00000000000 --- a/rand_distr/src/lib.rs +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![doc( - html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", - html_favicon_url = "https://www.rust-lang.org/favicon.ico", - html_root_url = "https://rust-random.github.io/rand/" -)] -#![forbid(unsafe_code)] -#![deny(missing_docs)] -#![deny(missing_debug_implementations)] -#![allow( - clippy::excessive_precision, - clippy::float_cmp, - clippy::unreadable_literal -)] -#![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose -#![no_std] -#![cfg_attr(doc_cfg, feature(doc_cfg))] - -//! Generating random samples from probability distributions. -//! -//! ## Re-exports -//! -//! This crate is a super-set of the [`rand::distributions`] module. See the -//! [`rand::distributions`] module documentation for an overview of the core -//! [`Distribution`] trait and implementations. -//! -//! The following are re-exported: -//! -//! - The [`Distribution`] trait and [`DistIter`] helper type -//! - The [`Standard`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], -//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions -//! -//! ## Distributions -//! -//! This crate provides the following probability distributions: -//! -//! - Related to real-valued quantities that grow linearly -//! (e.g. errors, offsets): -//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive -//! - [`SkewNormal`] distribution -//! - [`Cauchy`] distribution -//! - Related to Bernoulli trials (yes/no events, with a given probability): -//! - [`Binomial`] distribution -//! - [`Geometric`] distribution -//! - [`Hypergeometric`] distribution -//! - Related to positive real-valued quantities that grow exponentially -//! (e.g. prices, incomes, populations): -//! - [`LogNormal`] distribution -//! - Related to the occurrence of independent events at a given rate: -//! - [`Pareto`] distribution -//! - [`Poisson`] distribution -//! - [`Exp`]onential distribution, and [`Exp1`] as a primitive -//! - [`Weibull`] distribution -//! - [`Gumbel`] distribution -//! - [`Frechet`] distribution -//! - [`Zeta`] distribution -//! - [`Zipf`] distribution -//! - Gamma and derived distributions: -//! - [`Gamma`] distribution -//! - [`ChiSquared`] distribution -//! - [`StudentT`] distribution -//! - [`FisherF`] distribution -//! - Triangular distribution: -//! - [`Beta`] distribution -//! - [`Triangular`] distribution -//! - Multivariate probability distributions -//! - [`Dirichlet`] distribution -//! - [`UnitSphere`] distribution -//! - [`UnitBall`] distribution -//! - [`UnitCircle`] distribution -//! - [`UnitDisc`] distribution -//! - Alternative implementations for weighted index sampling -//! - [`WeightedAliasIndex`] distribution -//! - [`WeightedTreeIndex`] distribution -//! - Misc. distributions -//! - [`InverseGaussian`] distribution -//! - [`NormalInverseGaussian`] distribution - -#[cfg(feature = "alloc")] -extern crate alloc; - -#[cfg(feature = "std")] -extern crate std; - -// This is used for doc links: -#[allow(unused)] -use rand::Rng; - -pub use rand::distributions::{ - uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, - Standard, Uniform, -}; - -pub use self::binomial::{Binomial, Error as BinomialError}; -pub use self::cauchy::{Cauchy, Error as CauchyError}; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use self::dirichlet::{Dirichlet, Error as DirichletError}; -pub use self::exponential::{Error as ExpError, Exp, Exp1}; -pub use self::frechet::{Error as FrechetError, Frechet}; -pub use self::gamma::{ - Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError, - Gamma, StudentT, -}; -pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric}; -pub use self::gumbel::{Error as GumbelError, Gumbel}; -pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric}; -pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian}; -pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; -pub use self::normal_inverse_gaussian::{ - Error as NormalInverseGaussianError, NormalInverseGaussian, -}; -pub use self::pareto::{Error as ParetoError, Pareto}; -pub use self::pert::{Pert, PertError}; -pub use self::poisson::{Error as PoissonError, Poisson}; -pub use self::skew_normal::{Error as SkewNormalError, SkewNormal}; -pub use self::triangular::{Triangular, TriangularError}; -pub use self::unit_ball::UnitBall; -pub use self::unit_circle::UnitCircle; -pub use self::unit_disc::UnitDisc; -pub use self::unit_sphere::UnitSphere; -pub use self::weibull::{Error as WeibullError, Weibull}; -pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError}; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use rand::distributions::{WeightError, WeightedIndex}; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use weighted_alias::WeightedAliasIndex; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use weighted_tree::WeightedTreeIndex; - -pub use num_traits; - -#[cfg(test)] -#[macro_use] -mod test { - // Notes on testing - // - // Testing random number distributions correctly is hard. The following - // testing is desired: - // - // - Construction: test initialisation with a few valid parameter sets. - // - Erroneous usage: test that incorrect usage generates an error. - // - Vector: test that usage with fixed inputs (including RNG) generates a - // fixed output sequence on all platforms. - // - Correctness at fixed points (optional): using a specific mock RNG, - // check that specific values are sampled (e.g. end-points and median of - // distribution). - // - Correctness of PDF (extra): generate a histogram of samples within a - // certain range, and check this approximates the PDF. These tests are - // expected to be expensive, and should be behind a feature-gate. - // - // TODO: Vector and correctness tests are largely absent so far. - // NOTE: Some distributions have tests checking only that samples can be - // generated. This is redundant with vector and correctness tests. - - /// Construct a deterministic RNG with the given seed - pub fn rng(seed: u64) -> impl rand::RngCore { - // For tests, we want a statistically good, fast, reproducible RNG. - // PCG32 will do fine, and will be easy to embed if we ever need to. - const INC: u64 = 11634580027462260723; - rand_pcg::Pcg32::new(seed, INC) - } - - /// Assert that two numbers are almost equal to each other. - /// - /// On panic, this macro will print the values of the expressions with their - /// debug representations. - macro_rules! assert_almost_eq { - ($a:expr, $b:expr, $prec:expr) => { - let diff = ($a - $b).abs(); - assert!(diff <= $prec, - "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ - (left: `{}`, right: `{}`)", - diff, $prec, $a, $b - ); - }; - } -} - -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted_alias; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted_tree; - -mod binomial; -mod cauchy; -mod dirichlet; -mod exponential; -mod frechet; -mod gamma; -mod geometric; -mod gumbel; -mod hypergeometric; -mod inverse_gaussian; -mod normal; -mod normal_inverse_gaussian; -mod pareto; -mod pert; -mod poisson; -mod skew_normal; -mod triangular; -mod unit_ball; -mod unit_circle; -mod unit_disc; -mod unit_sphere; -mod utils; -mod weibull; -mod ziggurat_tables; -mod zipf; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs deleted file mode 100644 index b3b801dfed9..00000000000 --- a/rand_distr/src/normal.rs +++ /dev/null @@ -1,381 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The normal and derived distributions. - -use crate::utils::ziggurat; -use num_traits::Float; -use crate::{ziggurat_tables, Distribution, Open01}; -use rand::Rng; -use core::fmt; - -/// Samples floating-point numbers according to the normal distribution -/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to -/// `Normal::new(0.0, 1.0)` but faster. -/// -/// See `Normal` for the general normal distribution. -/// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. -/// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::StandardNormal; -/// -/// let val: f64 = thread_rng().sample(StandardNormal); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct StandardNormal; - -impl Distribution for StandardNormal { - #[inline] - fn sample(&self, rng: &mut R) -> f32 { - // TODO: use optimal 32-bit implementation - let x: f64 = self.sample(rng); - x as f32 - } -} - -impl Distribution for StandardNormal { - fn sample(&self, rng: &mut R) -> f64 { - #[inline] - fn pdf(x: f64) -> f64 { - (-x * x / 2.0).exp() - } - #[inline] - fn zero_case(rng: &mut R, u: f64) -> f64 { - // compute a random number in the tail by hand - - // strange initial conditions, because the loop is not - // do-while, so the condition should be true on the first - // run, they get overwritten anyway (0 < 1, so these are - // good). - let mut x = 1.0f64; - let mut y = 0.0f64; - - while -2.0 * y < x * x { - let x_: f64 = rng.sample(Open01); - let y_: f64 = rng.sample(Open01); - - x = x_.ln() / ziggurat_tables::ZIG_NORM_R; - y = y_.ln(); - } - - if u < 0.0 { - x - ziggurat_tables::ZIG_NORM_R - } else { - ziggurat_tables::ZIG_NORM_R - x - } - } - - ziggurat( - rng, - true, // this is symmetric - &ziggurat_tables::ZIG_NORM_X, - &ziggurat_tables::ZIG_NORM_F, - pdf, - zero_case, - ) - } -} - -/// The normal distribution `N(mean, std_dev**2)`. -/// -/// This uses the ZIGNOR variant of the Ziggurat method, see [`StandardNormal`] -/// for more details. -/// -/// Note that [`StandardNormal`] is an optimised implementation for mean 0, and -/// standard deviation 1. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Normal, Distribution}; -/// -/// // mean 2, standard deviation 3 -/// let normal = Normal::new(2.0, 3.0).unwrap(); -/// let v = normal.sample(&mut rand::thread_rng()); -/// println!("{} is from a N(2, 9) distribution", v) -/// ``` -/// -/// [`StandardNormal`]: crate::StandardNormal -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Normal -where F: Float, StandardNormal: Distribution -{ - mean: F, - std_dev: F, -} - -/// Error type returned from `Normal::new` and `LogNormal::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// The mean value is too small (log-normal samples must be positive) - MeanTooSmall, - /// The standard deviation or other dispersion parameter is not finite. - BadVariance, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution", - Error::BadVariance => "variation parameter is non-finite in (log)normal distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Normal -where F: Float, StandardNormal: Distribution -{ - /// Construct, from mean and standard deviation - /// - /// Parameters: - /// - /// - mean (`μ`, unrestricted) - /// - standard deviation (`σ`, must be finite) - #[inline] - pub fn new(mean: F, std_dev: F) -> Result, Error> { - if !std_dev.is_finite() { - return Err(Error::BadVariance); - } - Ok(Normal { mean, std_dev }) - } - - /// Construct, from mean and coefficient of variation - /// - /// Parameters: - /// - /// - mean (`μ`, unrestricted) - /// - coefficient of variation (`cv = abs(σ / μ)`) - #[inline] - pub fn from_mean_cv(mean: F, cv: F) -> Result, Error> { - if !cv.is_finite() || cv < F::zero() { - return Err(Error::BadVariance); - } - let std_dev = cv * mean; - Ok(Normal { mean, std_dev }) - } - - /// Sample from a z-score - /// - /// This may be useful for generating correlated samples `x1` and `x2` - /// from two different distributions, as follows. - /// ``` - /// # use rand::prelude::*; - /// # use rand_distr::{Normal, StandardNormal}; - /// let mut rng = thread_rng(); - /// let z = StandardNormal.sample(&mut rng); - /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z); - /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z); - /// ``` - #[inline] - pub fn from_zscore(&self, zscore: F) -> F { - self.mean + self.std_dev * zscore - } - - /// Returns the mean (`μ`) of the distribution. - pub fn mean(&self) -> F { - self.mean - } - - /// Returns the standard deviation (`σ`) of the distribution. - pub fn std_dev(&self) -> F { - self.std_dev - } -} - -impl Distribution for Normal -where F: Float, StandardNormal: Distribution -{ - fn sample(&self, rng: &mut R) -> F { - self.from_zscore(rng.sample(StandardNormal)) - } -} - - -/// The log-normal distribution `ln N(mean, std_dev**2)`. -/// -/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)` -/// distributed. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{LogNormal, Distribution}; -/// -/// // mean 2, standard deviation 3 -/// let log_normal = LogNormal::new(2.0, 3.0).unwrap(); -/// let v = log_normal.sample(&mut rand::thread_rng()); -/// println!("{} is from an ln N(2, 9) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct LogNormal -where F: Float, StandardNormal: Distribution -{ - norm: Normal, -} - -impl LogNormal -where F: Float, StandardNormal: Distribution -{ - /// Construct, from (log-space) mean and standard deviation - /// - /// Parameters are the "standard" log-space measures (these are the mean - /// and standard deviation of the logarithm of samples): - /// - /// - `mu` (`μ`, unrestricted) is the mean of the underlying distribution - /// - `sigma` (`σ`, must be finite) is the standard deviation of the - /// underlying Normal distribution - #[inline] - pub fn new(mu: F, sigma: F) -> Result, Error> { - let norm = Normal::new(mu, sigma)?; - Ok(LogNormal { norm }) - } - - /// Construct, from (linear-space) mean and coefficient of variation - /// - /// Parameters are linear-space measures: - /// - /// - mean (`μ > 0`) is the (real) mean of the distribution - /// - coefficient of variation (`cv = σ / μ`, requiring `cv ≥ 0`) is a - /// standardized measure of dispersion - /// - /// As a special exception, `μ = 0, cv = 0` is allowed (samples are `-inf`). - #[inline] - pub fn from_mean_cv(mean: F, cv: F) -> Result, Error> { - if cv == F::zero() { - let mu = mean.ln(); - let norm = Normal::new(mu, F::zero()).unwrap(); - return Ok(LogNormal { norm }); - } - if !(mean > F::zero()) { - return Err(Error::MeanTooSmall); - } - if !(cv >= F::zero()) { - return Err(Error::BadVariance); - } - - // Using X ~ lognormal(μ, σ), CV² = Var(X) / E(X)² - // E(X) = exp(μ + σ² / 2) = exp(μ) × exp(σ² / 2) - // Var(X) = exp(2μ + σ²)(exp(σ²) - 1) = E(X)² × (exp(σ²) - 1) - // but Var(X) = (CV × E(X))² so CV² = exp(σ²) - 1 - // thus σ² = log(CV² + 1) - // and exp(μ) = E(X) / exp(σ² / 2) = E(X) / sqrt(CV² + 1) - let a = F::one() + cv * cv; // e - let mu = F::from(0.5).unwrap() * (mean * mean / a).ln(); - let sigma = a.ln().sqrt(); - let norm = Normal::new(mu, sigma)?; - Ok(LogNormal { norm }) - } - - /// Sample from a z-score - /// - /// This may be useful for generating correlated samples `x1` and `x2` - /// from two different distributions, as follows. - /// ``` - /// # use rand::prelude::*; - /// # use rand_distr::{LogNormal, StandardNormal}; - /// let mut rng = thread_rng(); - /// let z = StandardNormal.sample(&mut rng); - /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z); - /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z); - /// ``` - #[inline] - pub fn from_zscore(&self, zscore: F) -> F { - self.norm.from_zscore(zscore).exp() - } -} - -impl Distribution for LogNormal -where F: Float, StandardNormal: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - self.norm.sample(rng).exp() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_normal() { - let norm = Normal::new(10.0, 10.0).unwrap(); - let mut rng = crate::test::rng(210); - for _ in 0..1000 { - norm.sample(&mut rng); - } - } - #[test] - fn test_normal_cv() { - let norm = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap(); - assert_eq!((norm.mean, norm.std_dev), (1024.0, 4.0)); - } - #[test] - fn test_normal_invalid_sd() { - assert!(Normal::from_mean_cv(10.0, -1.0).is_err()); - } - - #[test] - fn test_log_normal() { - let lnorm = LogNormal::new(10.0, 10.0).unwrap(); - let mut rng = crate::test::rng(211); - for _ in 0..1000 { - lnorm.sample(&mut rng); - } - } - #[test] - fn test_log_normal_cv() { - let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap(); - assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (-core::f64::INFINITY, 0.0)); - - let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap(); - assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0)); - - let e = core::f64::consts::E; - let lnorm = LogNormal::from_mean_cv(e.sqrt(), (e - 1.0).sqrt()).unwrap(); - assert_almost_eq!(lnorm.norm.mean, 0.0, 2e-16); - assert_almost_eq!(lnorm.norm.std_dev, 1.0, 2e-16); - - let lnorm = LogNormal::from_mean_cv(e.powf(1.5), (e - 1.0).sqrt()).unwrap(); - assert_almost_eq!(lnorm.norm.mean, 1.0, 1e-15); - assert_eq!(lnorm.norm.std_dev, 1.0); - } - #[test] - fn test_log_normal_invalid_sd() { - assert!(LogNormal::from_mean_cv(-1.0, 1.0).is_err()); - assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err()); - assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err()); - } - - #[test] - fn normal_distributions_can_be_compared() { - assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0)); - } - - #[test] - fn log_normal_distributions_can_be_compared() { - assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs deleted file mode 100644 index 7c5ad971710..00000000000 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ /dev/null @@ -1,110 +0,0 @@ -use crate::{Distribution, InverseGaussian, Standard, StandardNormal}; -use num_traits::Float; -use rand::Rng; -use core::fmt; - -/// Error type returned from `NormalInverseGaussian::new` -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Error { - /// `alpha <= 0` or `nan`. - AlphaNegativeOrNull, - /// `|beta| >= alpha` or `nan`. - AbsoluteBetaNotLessThanAlpha, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::AlphaNegativeOrNull => "alpha <= 0 or is NaN in normal inverse Gaussian distribution", - Error::AbsoluteBetaNotLessThanAlpha => "|beta| >= alpha or is NaN in normal inverse Gaussian distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) -#[derive(Debug, Clone, Copy, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct NormalInverseGaussian -where - F: Float, - StandardNormal: Distribution, - Standard: Distribution, -{ - beta: F, - inverse_gaussian: InverseGaussian, -} - -impl NormalInverseGaussian -where - F: Float, - StandardNormal: Distribution, - Standard: Distribution, -{ - /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and - /// beta (asymmetry) parameters. - pub fn new(alpha: F, beta: F) -> Result, Error> { - if !(alpha > F::zero()) { - return Err(Error::AlphaNegativeOrNull); - } - - if !(beta.abs() < alpha) { - return Err(Error::AbsoluteBetaNotLessThanAlpha); - } - - let gamma = (alpha * alpha - beta * beta).sqrt(); - - let mu = F::one() / gamma; - - let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap(); - - Ok(Self { - beta, - inverse_gaussian, - }) - } -} - -impl Distribution for NormalInverseGaussian -where - F: Float, - StandardNormal: Distribution, - Standard: Distribution, -{ - fn sample(&self, rng: &mut R) -> F - where R: Rng + ?Sized { - let inv_gauss = rng.sample(self.inverse_gaussian); - - self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_normal_inverse_gaussian() { - let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); - let mut rng = crate::test::rng(210); - for _ in 0..1000 { - norm_inv_gauss.sample(&mut rng); - } - } - - #[test] - fn test_normal_inverse_gaussian_invalid_param() { - assert!(NormalInverseGaussian::new(-1.0, 1.0).is_err()); - assert!(NormalInverseGaussian::new(-1.0, -1.0).is_err()); - assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); - assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); - } - - #[test] - fn normal_inverse_gaussian_distributions_can_be_compared() { - assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs deleted file mode 100644 index 25c8e0537dd..00000000000 --- a/rand_distr/src/pareto.rs +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Pareto distribution. - -use num_traits::Float; -use crate::{Distribution, OpenClosed01}; -use rand::Rng; -use core::fmt; - -/// Samples floating-point numbers according to the Pareto distribution -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Pareto; -/// -/// let val: f64 = thread_rng().sample(Pareto::new(1., 2.).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Pareto -where F: Float, OpenClosed01: Distribution -{ - scale: F, - inv_neg_shape: F, -} - -/// Error type returned from `Pareto::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `scale <= 0` or `nan`. - ScaleTooSmall, - /// `shape <= 0` or `nan`. - ShapeTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => "scale is not positive in Pareto distribution", - Error::ShapeTooSmall => "shape is not positive in Pareto distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Pareto -where F: Float, OpenClosed01: Distribution -{ - /// Construct a new Pareto distribution with given `scale` and `shape`. - /// - /// In the literature, `scale` is commonly written as xm or k and - /// `shape` is often written as α. - pub fn new(scale: F, shape: F) -> Result, Error> { - let zero = F::zero(); - - if !(scale > zero) { - return Err(Error::ScaleTooSmall); - } - if !(shape > zero) { - return Err(Error::ShapeTooSmall); - } - Ok(Pareto { - scale, - inv_neg_shape: F::from(-1.0).unwrap() / shape, - }) - } -} - -impl Distribution for Pareto -where F: Float, OpenClosed01: Distribution -{ - fn sample(&self, rng: &mut R) -> F { - let u: F = OpenClosed01.sample(rng); - self.scale * u.powf(self.inv_neg_shape) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use core::fmt::{Debug, Display, LowerExp}; - - #[test] - #[should_panic] - fn invalid() { - Pareto::new(0., 0.).unwrap(); - } - - #[test] - fn sample() { - let scale = 1.0; - let shape = 2.0; - let d = Pareto::new(scale, shape).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= scale); - } - } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, thresh: F, expected: &[F], - ) { - let mut rng = crate::test::rng(213); - for v in expected { - let x = rng.sample(&distr); - assert_almost_eq!(x, *v, thresh); - } - } - - test_samples(Pareto::new(1f32, 1.0).unwrap(), 1e-6, &[ - 1.0423688, 2.1235929, 4.132709, 1.4679428, - ]); - test_samples(Pareto::new(2.0, 0.5).unwrap(), 1e-14, &[ - 9.019295276219136, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ]); - } - - #[test] - fn pareto_distributions_can_be_compared() { - assert_eq!(Pareto::new(1.0, 2.0), Pareto::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs deleted file mode 100644 index 9ed79bf28ff..00000000000 --- a/rand_distr/src/pert.rs +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. -//! The PERT distribution. - -use num_traits::Float; -use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; -use rand::Rng; -use core::fmt; - -/// The PERT distribution. -/// -/// Similar to the [`Triangular`] distribution, the PERT distribution is -/// parameterised by a range and a mode within that range. Unlike the -/// [`Triangular`] distribution, the probability density function of the PERT -/// distribution is smooth, with a configurable weighting around the mode. -/// -/// # Example -/// -/// ```rust -/// use rand_distr::{Pert, Distribution}; -/// -/// let d = Pert::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); -/// println!("{} is from a PERT distribution", v); -/// ``` -/// -/// [`Triangular`]: crate::Triangular -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Pert -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - min: F, - range: F, - beta: Beta, -} - -/// Error type returned from [`Pert`] constructors. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum PertError { - /// `max < min` or `min` or `max` is NaN. - RangeTooSmall, - /// `mode < min` or `mode > max` or `mode` is NaN. - ModeRange, - /// `shape < 0` or `shape` is NaN - ShapeTooSmall, -} - -impl fmt::Display for PertError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - PertError::RangeTooSmall => "requirement min < max is not met in PERT distribution", - PertError::ModeRange => "mode is outside [min, max] in PERT distribution", - PertError::ShapeTooSmall => "shape < 0 or is NaN in PERT distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for PertError {} - -impl Pert -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Set up the PERT distribution with defined `min`, `max` and `mode`. - /// - /// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`. - #[inline] - pub fn new(min: F, max: F, mode: F) -> Result, PertError> { - Pert::new_with_shape(min, max, mode, F::from(4.).unwrap()) - } - - /// Set up the PERT distribution with defined `min`, `max`, `mode` and - /// `shape`. - pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result, PertError> { - if !(max > min) { - return Err(PertError::RangeTooSmall); - } - if !(mode >= min && max >= mode) { - return Err(PertError::ModeRange); - } - if !(shape >= F::from(0.).unwrap()) { - return Err(PertError::ShapeTooSmall); - } - - let range = max - min; - let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap()); - let v = if mu == mode { - shape * F::from(0.5).unwrap() + F::from(1.).unwrap() - } else { - (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min)) - }; - let w = v * (max - mu) / (mu - min); - let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?; - Ok(Pert { min, range, beta }) - } -} - -impl Distribution for Pert -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - self.beta.sample(rng) * self.range + self.min - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_pert() { - for &(min, max, mode) in &[ - (-1., 1., 0.), - (1., 2., 1.), - (5., 25., 25.), - ] { - let _distr = Pert::new(min, max, mode).unwrap(); - // TODO: test correctness - } - - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { - assert!(Pert::new(min, max, mode).is_err()); - } - } - - #[test] - fn pert_distributions_can_be_compared() { - assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0)); - } -} diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs deleted file mode 100644 index 50d74298356..00000000000 --- a/rand_distr/src/poisson.rs +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2016-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Poisson distribution. - -use num_traits::{Float, FloatConst}; -use crate::{Cauchy, Distribution, Standard}; -use rand::Rng; -use core::fmt; - -/// The Poisson distribution `Poisson(lambda)`. -/// -/// This distribution has a density function: -/// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Poisson, Distribution}; -/// -/// let poi = Poisson::new(2.0).unwrap(); -/// let v = poi.sample(&mut rand::thread_rng()); -/// println!("{} is from a Poisson(2) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Poisson -where F: Float + FloatConst, Standard: Distribution -{ - lambda: F, - // precalculated values - exp_lambda: F, - log_lambda: F, - sqrt_2lambda: F, - magic_val: F, -} - -/// Error type returned from `Poisson::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `lambda <= 0` - ShapeTooSmall, - /// `lambda = ∞` or `lambda = nan` - NonFinite, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ShapeTooSmall => "lambda is not positive in Poisson distribution", - Error::NonFinite => "lambda is infinite or nan in Poisson distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Poisson -where F: Float + FloatConst, Standard: Distribution -{ - /// Construct a new `Poisson` with the given shape parameter - /// `lambda`. - pub fn new(lambda: F) -> Result, Error> { - if !lambda.is_finite() { - return Err(Error::NonFinite); - } - if !(lambda > F::zero()) { - return Err(Error::ShapeTooSmall); - } - let log_lambda = lambda.ln(); - Ok(Poisson { - lambda, - exp_lambda: (-lambda).exp(), - log_lambda, - sqrt_2lambda: (F::from(2.0).unwrap() * lambda).sqrt(), - magic_val: lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda), - }) - } -} - -impl Distribution for Poisson -where F: Float + FloatConst, Standard: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - // using the algorithm from Numerical Recipes in C - - // for low expected values use the Knuth method - if self.lambda < F::from(12.0).unwrap() { - let mut result = F::one(); - let mut p = rng.gen::(); - while p > self.exp_lambda { - p = p*rng.gen::(); - result = result + F::one(); - } - result - F::one() - } - // high expected values - rejection method - else { - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); - let mut result; - - loop { - let mut comp_dev; - - loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison distribution - result = self.sqrt_2lambda * comp_dev + self.lambda; - // repeat the drawing until we are in the range of possible values - if result >= F::zero() { - break; - } - } - // now the result is a random variable greater than 0 with Cauchy distribution - // the result should be an integer value - result = result.floor(); - - // this is the ratio of the Poisson distribution to the comparison distribution - // the magic value scales the distribution function to a range of approximately 0-1 - // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 - // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = F::from(0.9).unwrap() - * (F::one() + comp_dev * comp_dev) - * (result * self.log_lambda - - crate::utils::log_gamma(F::one() + result) - - self.magic_val) - .exp(); - - // check with uniform random value - if below the threshold, we are within the target distribution - if rng.gen::() <= check { - break; - } - } - result - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn test_poisson_avg_gen(lambda: F, tol: F) - where Standard: Distribution - { - let poisson = Poisson::new(lambda).unwrap(); - let mut rng = crate::test::rng(123); - let mut sum = F::zero(); - for _ in 0..1000 { - sum = sum + poisson.sample(&mut rng); - } - let avg = sum / F::from(1000.0).unwrap(); - assert!((avg - lambda).abs() < tol); - } - - #[test] - fn test_poisson_avg() { - test_poisson_avg_gen::(10.0, 0.1); - test_poisson_avg_gen::(15.0, 0.1); - - test_poisson_avg_gen::(10.0, 0.1); - test_poisson_avg_gen::(15.0, 0.1); - - //Small lambda will use Knuth's method with exp_lambda == 1.0 - test_poisson_avg_gen::(0.00000000000000005, 0.1); - test_poisson_avg_gen::(0.00000000000000005, 0.1); - } - - #[test] - #[should_panic] - fn test_poisson_invalid_lambda_zero() { - Poisson::new(0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_poisson_invalid_lambda_infinity() { - Poisson::new(f64::INFINITY).unwrap(); - } - - #[test] - #[should_panic] - fn test_poisson_invalid_lambda_neg() { - Poisson::new(-10.0).unwrap(); - } - - #[test] - fn poisson_distributions_can_be_compared() { - assert_eq!(Poisson::new(1.0), Poisson::new(1.0)); - } -} diff --git a/rand_distr/src/skew_normal.rs b/rand_distr/src/skew_normal.rs deleted file mode 100644 index 29ba413a0ac..00000000000 --- a/rand_distr/src/skew_normal.rs +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Skew Normal distribution. - -use crate::{Distribution, StandardNormal}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [skew normal distribution] `SN(location, scale, shape)`. -/// -/// The skew normal distribution is a generalization of the -/// [`Normal`] distribution to allow for non-zero skewness. -/// -/// It has the density function, for `scale > 0`, -/// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)` -/// where `phi` and `Phi` are the density and distribution of a standard normal variable. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{SkewNormal, Distribution}; -/// -/// // location 2, scale 3, shape 1 -/// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap(); -/// let v = skew_normal.sample(&mut rand::thread_rng()); -/// println!("{} is from a SN(2, 3, 1) distribution", v) -/// ``` -/// -/// # Implementation details -/// -/// We are using the algorithm from [A Method to Simulate the Skew Normal Distribution]. -/// -/// [skew normal distribution]: https://en.wikipedia.org/wiki/Skew_normal_distribution -/// [`Normal`]: struct.Normal.html -/// [A Method to Simulate the Skew Normal Distribution]: https://dx.doi.org/10.4236/am.2014.513201 -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct SkewNormal -where - F: Float, - StandardNormal: Distribution, -{ - location: F, - scale: F, - shape: F, -} - -/// Error type returned from `SkewNormal::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// The scale parameter is not finite or it is less or equal to zero. - ScaleTooSmall, - /// The shape parameter is not finite. - BadShape, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => { - "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution" - } - Error::BadShape => "shape parameter is non-finite in skew normal distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl SkewNormal -where - F: Float, - StandardNormal: Distribution, -{ - /// Construct, from location, scale and shape. - /// - /// Parameters: - /// - /// - location (unrestricted) - /// - scale (must be finite and larger than zero) - /// - shape (must be finite) - #[inline] - pub fn new(location: F, scale: F, shape: F) -> Result, Error> { - if !scale.is_finite() || !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - if !shape.is_finite() { - return Err(Error::BadShape); - } - Ok(SkewNormal { - location, - scale, - shape, - }) - } - - /// Returns the location of the distribution. - pub fn location(&self) -> F { - self.location - } - - /// Returns the scale of the distribution. - pub fn scale(&self) -> F { - self.scale - } - - /// Returns the shape of the distribution. - pub fn shape(&self) -> F { - self.shape - } -} - -impl Distribution for SkewNormal -where - F: Float, - StandardNormal: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let linear_map = |x: F| -> F { x * self.scale + self.location }; - let u_1: F = rng.sample(StandardNormal); - if self.shape == F::zero() { - linear_map(u_1) - } else { - let u_2 = rng.sample(StandardNormal); - let (u, v) = (u_1.max(u_2), u_1.min(u_2)); - if self.shape == -F::one() { - linear_map(v) - } else if self.shape == F::one() { - linear_map(u) - } else { - let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v) - / ((F::one() + self.shape * self.shape).sqrt() - * F::from(core::f64::consts::SQRT_2).unwrap()); - linear_map(normalized) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn test_samples>( - distr: D, zero: F, expected: &[F], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - #[test] - #[should_panic] - fn invalid_scale_nan() { - SkewNormal::new(0.0, core::f64::NAN, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_scale_zero() { - SkewNormal::new(0.0, 0.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_scale_negative() { - SkewNormal::new(0.0, -1.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_scale_infinite() { - SkewNormal::new(0.0, core::f64::INFINITY, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_shape_nan() { - SkewNormal::new(0.0, 1.0, core::f64::NAN).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_shape_infinite() { - SkewNormal::new(0.0, 1.0, core::f64::INFINITY).unwrap(); - } - - #[test] - fn valid_location_nan() { - SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap(); - } - - #[test] - fn skew_normal_value_stability() { - test_samples( - SkewNormal::new(0.0, 1.0, 0.0).unwrap(), - 0f32, - &[-0.11844189, 0.781378, 0.06563994, -1.1932899], - ); - test_samples( - SkewNormal::new(0.0, 1.0, 0.0).unwrap(), - 0f64, - &[ - -0.11844188827977231, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ], - ); - test_samples( - SkewNormal::new(core::f64::INFINITY, 1.0, 0.0).unwrap(), - 0f64, - &[ - core::f64::INFINITY, - core::f64::INFINITY, - core::f64::INFINITY, - core::f64::INFINITY, - ], - ); - test_samples( - SkewNormal::new(core::f64::NEG_INFINITY, 1.0, 0.0).unwrap(), - 0f64, - &[ - core::f64::NEG_INFINITY, - core::f64::NEG_INFINITY, - core::f64::NEG_INFINITY, - core::f64::NEG_INFINITY, - ], - ); - } - - #[test] - fn skew_normal_value_location_nan() { - let skew_normal = SkewNormal::new(core::f64::NAN, 1.0, 0.0).unwrap(); - let mut rng = crate::test::rng(213); - let mut buf = [0.0; 4]; - for x in &mut buf { - *x = rng.sample(skew_normal); - } - for value in buf.iter() { - assert!(value.is_nan()); - } - } - - #[test] - fn skew_normal_distributions_can_be_compared() { - assert_eq!(SkewNormal::new(1.0, 2.0, 3.0), SkewNormal::new(1.0, 2.0, 3.0)); - } -} diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs deleted file mode 100644 index eef7d190133..00000000000 --- a/rand_distr/src/triangular.rs +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. -//! The triangular distribution. - -use num_traits::Float; -use crate::{Distribution, Standard}; -use rand::Rng; -use core::fmt; - -/// The triangular distribution. -/// -/// A continuous probability distribution parameterised by a range, and a mode -/// (most likely value) within that range. -/// -/// The probability density function is triangular. For a similar distribution -/// with a smooth PDF, see the [`Pert`] distribution. -/// -/// # Example -/// -/// ```rust -/// use rand_distr::{Triangular, Distribution}; -/// -/// let d = Triangular::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); -/// println!("{} is from a triangular distribution", v); -/// ``` -/// -/// [`Pert`]: crate::Pert -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Triangular -where F: Float, Standard: Distribution -{ - min: F, - max: F, - mode: F, -} - -/// Error type returned from [`Triangular::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum TriangularError { - /// `max < min` or `min` or `max` is NaN. - RangeTooSmall, - /// `mode < min` or `mode > max` or `mode` is NaN. - ModeRange, -} - -impl fmt::Display for TriangularError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - TriangularError::RangeTooSmall => { - "requirement min <= max is not met in triangular distribution" - } - TriangularError::ModeRange => "mode is outside [min, max] in triangular distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for TriangularError {} - -impl Triangular -where F: Float, Standard: Distribution -{ - /// Set up the Triangular distribution with defined `min`, `max` and `mode`. - #[inline] - pub fn new(min: F, max: F, mode: F) -> Result, TriangularError> { - if !(max >= min) { - return Err(TriangularError::RangeTooSmall); - } - if !(mode >= min && max >= mode) { - return Err(TriangularError::ModeRange); - } - Ok(Triangular { min, max, mode }) - } -} - -impl Distribution for Triangular -where F: Float, Standard: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - let f: F = rng.sample(Standard); - let diff_mode_min = self.mode - self.min; - let range = self.max - self.min; - let f_range = f * range; - if f_range < diff_mode_min { - self.min + (f_range * diff_mode_min).sqrt() - } else { - self.max - ((range - f_range) * (self.max - self.mode)).sqrt() - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use rand::{rngs::mock, Rng}; - - #[test] - fn test_triangular() { - let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0); - assert_eq!(half_rng.gen::(), 0.5); - for &(min, max, mode, median) in &[ - (-1., 1., 0., 0.), - (1., 2., 1., 2. - 0.5f64.sqrt()), - (5., 25., 25., 5. + 200f64.sqrt()), - (1e-5, 1e5, 1e-3, 1e5 - 4999999949.5f64.sqrt()), - (0., 1., 0.9, 0.45f64.sqrt()), - (-4., -0.5, -2., -4.0 + 3.5f64.sqrt()), - ] { - #[cfg(feature = "std")] - std::println!("{} {} {} {}", min, max, mode, median); - let distr = Triangular::new(min, max, mode).unwrap(); - // Test correct value at median: - assert_eq!(distr.sample(&mut half_rng), median); - } - - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { - assert!(Triangular::new(min, max, mode).is_err()); - } - } - - #[test] - fn triangular_distributions_can_be_compared() { - assert_eq!(Triangular::new(1.0, 3.0, 2.0), Triangular::new(1.0, 3.0, 2.0)); - } -} diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs deleted file mode 100644 index 4d29612597f..00000000000 --- a/rand_distr/src/unit_ball.rs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use num_traits::Float; -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use rand::Rng; - -/// Samples uniformly from the unit ball (surface and interior) in three -/// dimensions. -/// -/// Implemented via rejection sampling. -/// -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitBall, Distribution}; -/// -/// let v: [f64; 3] = UnitBall.sample(&mut rand::thread_rng()); -/// println!("{:?} is from the unit ball.", v) -/// ``` -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitBall; - -impl Distribution<[F; 3]> for UnitBall { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - let mut x1; - let mut x2; - let mut x3; - loop { - x1 = uniform.sample(rng); - x2 = uniform.sample(rng); - x3 = uniform.sample(rng); - if x1 * x1 + x2 * x2 + x3 * x3 <= F::from(1.).unwrap() { - break; - } - } - [x1, x2, x3] - } -} diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs deleted file mode 100644 index f3dbe757aa9..00000000000 --- a/rand_distr/src/unit_circle.rs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use num_traits::Float; -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use rand::Rng; - -/// Samples uniformly from the edge of the unit circle in two dimensions. -/// -/// Implemented via a method by von Neumann[^1]. -/// -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitCircle, Distribution}; -/// -/// let v: [f64; 2] = UnitCircle.sample(&mut rand::thread_rng()); -/// println!("{:?} is from the unit circle.", v) -/// ``` -/// -/// [^1]: von Neumann, J. (1951) [*Various Techniques Used in Connection with -/// Random Digits.*](https://mcnp.lanl.gov/pdf_files/nbs_vonneumann.pdf) -/// NBS Appl. Math. Ser., No. 12. Washington, DC: U.S. Government Printing -/// Office, pp. 36-38. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitCircle; - -impl Distribution<[F; 2]> for UnitCircle { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - let mut x1; - let mut x2; - let mut sum; - loop { - x1 = uniform.sample(rng); - x2 = uniform.sample(rng); - sum = x1 * x1 + x2 * x2; - if sum < F::from(1.).unwrap() { - break; - } - } - let diff = x1 * x1 - x2 * x2; - [diff / sum, F::from(2.).unwrap() * x1 * x2 / sum] - } -} - -#[cfg(test)] -mod tests { - use super::UnitCircle; - use crate::Distribution; - - #[test] - fn norm() { - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let x: [f64; 2] = UnitCircle.sample(&mut rng); - assert_almost_eq!(x[0] * x[0] + x[1] * x[1], 1., 1e-15); - } - } -} diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs deleted file mode 100644 index 5004217d5b7..00000000000 --- a/rand_distr/src/unit_disc.rs +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use num_traits::Float; -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use rand::Rng; - -/// Samples uniformly from the unit disc in two dimensions. -/// -/// Implemented via rejection sampling. -/// -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitDisc, Distribution}; -/// -/// let v: [f64; 2] = UnitDisc.sample(&mut rand::thread_rng()); -/// println!("{:?} is from the unit Disc.", v) -/// ``` -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitDisc; - -impl Distribution<[F; 2]> for UnitDisc { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - let mut x1; - let mut x2; - loop { - x1 = uniform.sample(rng); - x2 = uniform.sample(rng); - if x1 * x1 + x2 * x2 <= F::from(1.).unwrap() { - break; - } - } - [x1, x2] - } -} diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs deleted file mode 100644 index 632275e3327..00000000000 --- a/rand_distr/src/unit_sphere.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2018-2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use num_traits::Float; -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use rand::Rng; - -/// Samples uniformly from the surface of the unit sphere in three dimensions. -/// -/// Implemented via a method by Marsaglia[^1]. -/// -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitSphere, Distribution}; -/// -/// let v: [f64; 3] = UnitSphere.sample(&mut rand::thread_rng()); -/// println!("{:?} is from the unit sphere surface.", v) -/// ``` -/// -/// [^1]: Marsaglia, George (1972). [*Choosing a Point from the Surface of a -/// Sphere.*](https://doi.org/10.1214/aoms/1177692644) -/// Ann. Math. Statist. 43, no. 2, 645--646. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitSphere; - -impl Distribution<[F; 3]> for UnitSphere { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - loop { - let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); - let sum = x1 * x1 + x2 * x2; - if sum >= F::from(1.).unwrap() { - continue; - } - let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); - return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum]; - } - } -} - -#[cfg(test)] -mod tests { - use super::UnitSphere; - use crate::Distribution; - - #[test] - fn norm() { - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let x: [f64; 3] = UnitSphere.sample(&mut rng); - assert_almost_eq!(x[0] * x[0] + x[1] * x[1] + x[2] * x[2], 1., 1e-15); - } - } -} diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs deleted file mode 100644 index 4638e3623d2..00000000000 --- a/rand_distr/src/utils.rs +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Math helper functions - -use crate::ziggurat_tables; -use rand::distributions::hidden_export::IntoFloat; -use rand::Rng; -use num_traits::Float; - -/// Calculates ln(gamma(x)) (natural logarithm of the gamma -/// function) using the Lanczos approximation. -/// -/// The approximation expresses the gamma function as: -/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)` -/// `g` is an arbitrary constant; we use the approximation with `g=5`. -/// -/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides: -/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)` -/// -/// `Ag(z)` is an infinite series with coefficients that can be calculated -/// ahead of time - we use just the first 6 terms, which is good enough -/// for most purposes. -pub(crate) fn log_gamma(x: F) -> F { - // precalculated 6 coefficients for the first 6 terms of the series - let coefficients: [F; 6] = [ - F::from(76.18009172947146).unwrap(), - F::from(-86.50532032941677).unwrap(), - F::from(24.01409824083091).unwrap(), - F::from(-1.231739572450155).unwrap(), - F::from(0.1208650973866179e-2).unwrap(), - F::from(-0.5395239384953e-5).unwrap(), - ]; - - // (x+0.5)*ln(x+g+0.5)-(x+g+0.5) - let tmp = x + F::from(5.5).unwrap(); - let log = (x + F::from(0.5).unwrap()) * tmp.ln() - tmp; - - // the first few terms of the series for Ag(x) - let mut a = F::from(1.000000000190015).unwrap(); - let mut denom = x; - for &coeff in &coefficients { - denom = denom + F::one(); - a = a + (coeff / denom); - } - - // get everything together - // a is Ag(x) - // 2.5066... is sqrt(2pi) - log + (F::from(2.5066282746310005).unwrap() * a / x).ln() -} - -/// Sample a random number using the Ziggurat method (specifically the -/// ZIGNOR variant from Doornik 2005). Most of the arguments are -/// directly from the paper: -/// -/// * `rng`: source of randomness -/// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0. -/// * `X`: the $x_i$ abscissae. -/// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$) -/// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$ -/// * `pdf`: the probability density function -/// * `zero_case`: manual sampling from the tail when we chose the -/// bottom box (i.e. i == 0) - -// the perf improvement (25-50%) is definitely worth the extra code -// size from force-inlining. -#[inline(always)] -pub(crate) fn ziggurat( - rng: &mut R, - symmetric: bool, - x_tab: ziggurat_tables::ZigTable, - f_tab: ziggurat_tables::ZigTable, - mut pdf: P, - mut zero_case: Z -) -> f64 -where - P: FnMut(f64) -> f64, - Z: FnMut(&mut R, f64) -> f64, -{ - loop { - // As an optimisation we re-implement the conversion to a f64. - // From the remaining 12 most significant bits we use 8 to construct `i`. - // This saves us generating a whole extra random number, while the added - // precision of using 64 bits for f64 does not buy us much. - let bits = rng.next_u64(); - let i = bits as usize & 0xff; - - let u = if symmetric { - // Convert to a value in the range [2,4) and subtract to get [-1,1) - // We can't convert to an open range directly, that would require - // subtracting `3.0 - EPSILON`, which is not representable. - // It is possible with an extra step, but an open range does not - // seem necessary for the ziggurat algorithm anyway. - (bits >> 12).into_float_with_exponent(1) - 3.0 - } else { - // Convert to a value in the range [1,2) and subtract to get (0,1) - (bits >> 12).into_float_with_exponent(0) - (1.0 - core::f64::EPSILON / 2.0) - }; - let x = u * x_tab[i]; - - let test_x = if symmetric { x.abs() } else { x }; - - // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i]) - if test_x < x_tab[i + 1] { - return x; - } - if i == 0 { - return zero_case(rng, u); - } - // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 - if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::() < pdf(x) { - return x; - } - } -} diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs deleted file mode 100644 index fe45eff6613..00000000000 --- a/rand_distr/src/weibull.rs +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Weibull distribution. - -use num_traits::Float; -use crate::{Distribution, OpenClosed01}; -use rand::Rng; -use core::fmt; - -/// Samples floating-point numbers according to the Weibull distribution -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Weibull; -/// -/// let val: f64 = thread_rng().sample(Weibull::new(1., 10.).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Weibull -where F: Float, OpenClosed01: Distribution -{ - inv_shape: F, - scale: F, -} - -/// Error type returned from `Weibull::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `scale <= 0` or `nan`. - ScaleTooSmall, - /// `shape <= 0` or `nan`. - ShapeTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => "scale is not positive in Weibull distribution", - Error::ShapeTooSmall => "shape is not positive in Weibull distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -impl Weibull -where F: Float, OpenClosed01: Distribution -{ - /// Construct a new `Weibull` distribution with given `scale` and `shape`. - pub fn new(scale: F, shape: F) -> Result, Error> { - if !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - if !(shape > F::zero()) { - return Err(Error::ShapeTooSmall); - } - Ok(Weibull { - inv_shape: F::from(1.).unwrap() / shape, - scale, - }) - } -} - -impl Distribution for Weibull -where F: Float, OpenClosed01: Distribution -{ - fn sample(&self, rng: &mut R) -> F { - let x: F = rng.sample(OpenClosed01); - self.scale * (-x.ln()).powf(self.inv_shape) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic] - fn invalid() { - Weibull::new(0., 0.).unwrap(); - } - - #[test] - fn sample() { - let scale = 1.0; - let shape = 2.0; - let d = Weibull::new(scale, shape).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 0.); - } - } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, zero: F, expected: &[F], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - test_samples(Weibull::new(1.0, 1.0).unwrap(), 0f32, &[ - 0.041495778, - 0.7531094, - 1.4189332, - 0.38386202, - ]); - test_samples(Weibull::new(2.0, 0.5).unwrap(), 0f64, &[ - 1.1343478702739669, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ]); - } - - #[test] - fn weibull_distributions_can_be_compared() { - assert_eq!(Weibull::new(1.0, 2.0), Weibull::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted_alias.rs deleted file mode 100644 index 236e2ad734b..00000000000 --- a/rand_distr/src/weighted_alias.rs +++ /dev/null @@ -1,521 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! This module contains an implementation of alias method for sampling random -//! indices with probabilities proportional to a collection of weights. - -use super::WeightError; -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use core::fmt; -use core::iter::Sum; -use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; -use rand::Rng; -use alloc::{boxed::Box, vec, vec::Vec}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// A distribution using weighted sampling to pick a discretely selected item. -/// -/// Sampling a [`WeightedAliasIndex`] distribution returns the index of a randomly -/// selected element from the vector used to create the [`WeightedAliasIndex`]. -/// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which a implementation of -/// [`AliasableWeight`] exists. -/// -/// # Performance -/// -/// Given that `n` is the number of items in the vector used to create an -/// [`WeightedAliasIndex`], it will require `O(n)` amount of memory. -/// More specifically it takes up some constant amount of memory plus -/// the vector used to create it and a [`Vec`] with capacity `n`. -/// -/// Time complexity for the creation of a [`WeightedAliasIndex`] is `O(n)`. -/// Sampling is `O(1)`, it makes a call to [`Uniform::sample`] and a call -/// to [`Uniform::sample`]. -/// -/// # Example -/// -/// ``` -/// use rand_distr::WeightedAliasIndex; -/// use rand::prelude::*; -/// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 1, 1]; -/// let dist = WeightedAliasIndex::new(weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`WeightedAliasIndex`]: WeightedAliasIndex -/// [`Vec`]: Vec -/// [`Uniform::sample`]: Distribution::sample -/// [`Uniform::sample`]: Distribution::sample -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))] -#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))] -pub struct WeightedAliasIndex { - aliases: Box<[u32]>, - no_alias_odds: Box<[W]>, - uniform_index: Uniform, - uniform_within_weight_sum: Uniform, -} - -impl WeightedAliasIndex { - /// Creates a new [`WeightedAliasIndex`]. - /// - /// Error cases: - /// - [`WeightError::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number, - /// negative or greater than `max = W::MAX / weights.len()`. - /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. - pub fn new(weights: Vec) -> Result { - let n = weights.len(); - if n == 0 || n > ::core::u32::MAX as usize { - return Err(WeightError::InvalidInput); - } - let n = n as u32; - - let max_weight_size = W::try_from_u32_lossy(n) - .map(|n| W::MAX / n) - .unwrap_or(W::ZERO); - if !weights - .iter() - .all(|&w| W::ZERO <= w && w <= max_weight_size) - { - return Err(WeightError::InvalidWeight); - } - - // The sum of weights will represent 100% of no alias odds. - let weight_sum = AliasableWeight::sum(weights.as_slice()); - // Prevent floating point overflow due to rounding errors. - let weight_sum = if weight_sum > W::MAX { - W::MAX - } else { - weight_sum - }; - if weight_sum == W::ZERO { - return Err(WeightError::InsufficientNonZero); - } - - // `weight_sum` would have been zero if `try_from_lossy` causes an error here. - let n_converted = W::try_from_u32_lossy(n).unwrap(); - - let mut no_alias_odds = weights.into_boxed_slice(); - for odds in no_alias_odds.iter_mut() { - *odds *= n_converted; - // Prevent floating point overflow due to rounding errors. - *odds = if *odds > W::MAX { W::MAX } else { *odds }; - } - - /// This struct is designed to contain three data structures at once, - /// sharing the same memory. More precisely it contains two linked lists - /// and an alias map, which will be the output of this method. To keep - /// the three data structures from getting in each other's way, it must - /// be ensured that a single index is only ever in one of them at the - /// same time. - struct Aliases { - aliases: Box<[u32]>, - smalls_head: u32, - bigs_head: u32, - } - - impl Aliases { - fn new(size: u32) -> Self { - Aliases { - aliases: vec![0; size as usize].into_boxed_slice(), - smalls_head: ::core::u32::MAX, - bigs_head: ::core::u32::MAX, - } - } - - fn push_small(&mut self, idx: u32) { - self.aliases[idx as usize] = self.smalls_head; - self.smalls_head = idx; - } - - fn push_big(&mut self, idx: u32) { - self.aliases[idx as usize] = self.bigs_head; - self.bigs_head = idx; - } - - fn pop_small(&mut self) -> u32 { - let popped = self.smalls_head; - self.smalls_head = self.aliases[popped as usize]; - popped - } - - fn pop_big(&mut self) -> u32 { - let popped = self.bigs_head; - self.bigs_head = self.aliases[popped as usize]; - popped - } - - fn smalls_is_empty(&self) -> bool { - self.smalls_head == ::core::u32::MAX - } - - fn bigs_is_empty(&self) -> bool { - self.bigs_head == ::core::u32::MAX - } - - fn set_alias(&mut self, idx: u32, alias: u32) { - self.aliases[idx as usize] = alias; - } - } - - let mut aliases = Aliases::new(n); - - // Split indices into those with small weights and those with big weights. - for (index, &odds) in no_alias_odds.iter().enumerate() { - if odds < weight_sum { - aliases.push_small(index as u32); - } else { - aliases.push_big(index as u32); - } - } - - // Build the alias map by finding an alias with big weight for each index with - // small weight. - while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { - let s = aliases.pop_small(); - let b = aliases.pop_big(); - - aliases.set_alias(s, b); - no_alias_odds[b as usize] = - no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize]; - - if no_alias_odds[b as usize] < weight_sum { - aliases.push_small(b); - } else { - aliases.push_big(b); - } - } - - // The remaining indices should have no alias odds of about 100%. This is due to - // numeric accuracy. Otherwise they would be exactly 100%. - while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small() as usize] = weight_sum; - } - while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big() as usize] = weight_sum; - } - - // Prepare distributions for sampling. Creating them beforehand improves - // sampling performance. - let uniform_index = Uniform::new(0, n).unwrap(); - let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap(); - - Ok(Self { - aliases: aliases.aliases, - no_alias_odds, - uniform_index, - uniform_within_weight_sum, - }) - } -} - -impl Distribution for WeightedAliasIndex { - fn sample(&self, rng: &mut R) -> usize { - let candidate = rng.sample(self.uniform_index); - if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] { - candidate as usize - } else { - self.aliases[candidate as usize] as usize - } - } -} - -impl fmt::Debug for WeightedAliasIndex -where - W: fmt::Debug, - Uniform: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("WeightedAliasIndex") - .field("aliases", &self.aliases) - .field("no_alias_odds", &self.no_alias_odds) - .field("uniform_index", &self.uniform_index) - .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) - .finish() - } -} - -impl Clone for WeightedAliasIndex -where Uniform: Clone -{ - fn clone(&self) -> Self { - Self { - aliases: self.aliases.clone(), - no_alias_odds: self.no_alias_odds.clone(), - uniform_index: self.uniform_index, - uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), - } - } -} - -/// Trait that must be implemented for weights, that are used with -/// [`WeightedAliasIndex`]. Currently no guarantees on the correctness of -/// [`WeightedAliasIndex`] are given for custom implementations of this trait. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub trait AliasableWeight: - Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Mul - + MulAssign - + Div - + DivAssign - + Sum -{ - /// Maximum number representable by `Self`. - const MAX: Self; - - /// Element of `Self` equivalent to 0. - const ZERO: Self; - - /// Produce an instance of `Self` from a `u32` value, or return `None` if - /// out of range. Loss of precision (where `Self` is a floating point type) - /// is acceptable. - fn try_from_u32_lossy(n: u32) -> Option; - - /// Sums all values in slice `values`. - fn sum(values: &[Self]) -> Self { - values.iter().copied().sum() - } -} - -macro_rules! impl_weight_for_float { - ($T: ident) => { - impl AliasableWeight for $T { - const MAX: Self = ::core::$T::MAX; - const ZERO: Self = 0.0; - - fn try_from_u32_lossy(n: u32) -> Option { - Some(n as $T) - } - - fn sum(values: &[Self]) -> Self { - pairwise_sum(values) - } - } - }; -} - -/// In comparison to naive accumulation, the pairwise sum algorithm reduces -/// rounding errors when there are many floating point values. -fn pairwise_sum(values: &[T]) -> T { - if values.len() <= 32 { - values.iter().copied().sum() - } else { - let mid = values.len() / 2; - let (a, b) = values.split_at(mid); - pairwise_sum(a) + pairwise_sum(b) - } -} - -macro_rules! impl_weight_for_int { - ($T: ident) => { - impl AliasableWeight for $T { - const MAX: Self = ::core::$T::MAX; - const ZERO: Self = 0; - - fn try_from_u32_lossy(n: u32) -> Option { - let n_converted = n as Self; - if n_converted >= Self::ZERO && n_converted as u32 == n { - Some(n_converted) - } else { - None - } - } - } - }; -} - -impl_weight_for_float!(f64); -impl_weight_for_float!(f32); -impl_weight_for_int!(usize); -impl_weight_for_int!(u128); -impl_weight_for_int!(u64); -impl_weight_for_int!(u32); -impl_weight_for_int!(u16); -impl_weight_for_int!(u8); -impl_weight_for_int!(isize); -impl_weight_for_int!(i128); -impl_weight_for_int!(i64); -impl_weight_for_int!(i32); -impl_weight_for_int!(i16); -impl_weight_for_int!(i8); - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_f32() { - test_weighted_index(f32::into); - - // Floating point special cases - assert_eq!( - WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(), - WeightError::InsufficientNonZero - ); - assert_eq!( - WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(), - WeightError::InvalidWeight - ); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_u128() { - test_weighted_index(|x: u128| x as f64); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_i128() { - test_weighted_index(|x: i128| x as f64); - - // Signed integer special cases - assert_eq!( - WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(), - WeightError::InvalidWeight - ); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_u8() { - test_weighted_index(u8::into); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_i8() { - test_weighted_index(i8::into); - - // Signed integer special cases - assert_eq!( - WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(), - WeightError::InvalidWeight - ); - } - - fn test_weighted_index f64>(w_to_f64: F) - where WeightedAliasIndex: fmt::Debug { - const NUM_WEIGHTS: u32 = 10; - const ZERO_WEIGHT_INDEX: u32 = 3; - const NUM_SAMPLES: u32 = 15000; - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - - let weights = { - let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); - let random_weight_distribution = Uniform::new_inclusive( - W::ZERO, - W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), - ).unwrap(); - for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample(&random_weight_distribution)); - } - weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; - weights - }; - let weight_sum = weights.iter().copied().sum::(); - let expected_counts = weights - .iter() - .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) - .collect::>(); - let weight_distribution = WeightedAliasIndex::new(weights).unwrap(); - - let mut counts = vec![0; NUM_WEIGHTS as usize]; - for _ in 0..NUM_SAMPLES { - counts[rng.sample(&weight_distribution)] += 1; - } - - assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0); - for (count, expected_count) in counts.into_iter().zip(expected_counts) { - let difference = (count as f64 - expected_count).abs(); - let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; - assert!(difference <= max_allowed_difference); - } - - assert_eq!( - WeightedAliasIndex::::new(vec![]).unwrap_err(), - WeightError::InvalidInput - ); - assert_eq!( - WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(), - WeightError::InsufficientNonZero - ); - assert_eq!( - WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - WeightError::InvalidWeight - ); - } - - #[test] - fn value_stability() { - fn test_samples(weights: Vec, buf: &mut [usize], expected: &[usize]) { - assert_eq!(buf.len(), expected.len()); - let distr = WeightedAliasIndex::new(weights).unwrap(); - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - for r in buf.iter_mut() { - *r = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - let mut buf = [0; 10]; - test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 6, 5, 7, 5, 8, 7, 6, 2, 3, 7, - ]); - test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 2, 0, 0, 0, 0, 0, 0, 0, 1, 3, - ]); - test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 1, 2, 3, 2, 1, 3, 2, 1, 1, - ]); - } -} diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs deleted file mode 100644 index d5b4ef467d8..00000000000 --- a/rand_distr/src/weighted_tree.rs +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2024 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! This module contains an implementation of a tree structure for sampling random -//! indices with probabilities proportional to a collection of weights. - -use core::ops::SubAssign; - -use super::WeightError; -use crate::Distribution; -use alloc::vec::Vec; -use rand::distributions::uniform::{SampleBorrow, SampleUniform}; -use rand::distributions::Weight; -use rand::Rng; -#[cfg(feature = "serde1")] -use serde::{Deserialize, Serialize}; - -/// A distribution using weighted sampling to pick a discretely selected item. -/// -/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly -/// selected element from the vector used to create the [`WeightedTreeIndex`]. -/// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which an implementation of -/// [`Weight`] exists. -/// -/// # Key differences -/// -/// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] -/// lies in the internal representation of weights. In [`WeightedTreeIndex`], -/// weights are structured as a tree, which is optimized for frequent updates of the weights. -/// -/// # Caution: Floating point types -/// -/// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), -/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types -/// are susceptible to numerical rounding errors. Since operations on floating point weights are -/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable -/// deviations from the expected behavior. -/// -/// Ideally, use fixed point or integer types whenever possible. -/// -/// # Performance -/// -/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. -/// -/// Time complexity for the operations of a [`WeightedTreeIndex`] are: -/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time. -/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. -/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. -/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. -/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. -/// -/// # Example -/// -/// ``` -/// use rand_distr::WeightedTreeIndex; -/// use rand::prelude::*; -/// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 0]; -/// let mut dist = WeightedTreeIndex::new(&weights).unwrap(); -/// dist.push(1).unwrap(); -/// dist.update(1, 1).unwrap(); -/// let mut rng = thread_rng(); -/// let mut samples = [0; 3]; -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// let i = dist.sample(&mut rng); -/// samples[i] += 1; -/// } -/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); -/// ``` -/// -/// [`WeightedTreeIndex`]: WeightedTreeIndex -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr( - feature = "serde1", - serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) -)] -#[cfg_attr( - feature = "serde1 ", - serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) -)] -#[derive(Clone, Default, Debug, PartialEq)] -pub struct WeightedTreeIndex< - W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign + Weight, -> { - subtotals: Vec, -} - -impl + Weight> - WeightedTreeIndex -{ - /// Creates a new [`WeightedTreeIndex`] from a slice of weights. - /// - /// Error cases: - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. - pub fn new(weights: I) -> Result - where - I: IntoIterator, - I::Item: SampleBorrow, - { - let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); - for weight in subtotals.iter() { - if !(*weight >= W::ZERO) { - return Err(WeightError::InvalidWeight); - } - } - let n = subtotals.len(); - for i in (1..n).rev() { - let w = subtotals[i].clone(); - let parent = (i - 1) / 2; - subtotals[parent] - .checked_add_assign(&w) - .map_err(|()| WeightError::Overflow)?; - } - Ok(Self { subtotals }) - } - - /// Returns `true` if the tree contains no weights. - pub fn is_empty(&self) -> bool { - self.subtotals.is_empty() - } - - /// Returns the number of weights. - pub fn len(&self) -> usize { - self.subtotals.len() - } - - /// Returns `true` if we can sample. - /// - /// This is the case if the total weight of the tree is greater than zero. - pub fn is_valid(&self) -> bool { - if let Some(weight) = self.subtotals.first() { - *weight > W::ZERO - } else { - false - } - } - - /// Gets the weight at an index. - pub fn get(&self, index: usize) -> W { - let left_index = 2 * index + 1; - let right_index = 2 * index + 2; - let mut w = self.subtotals[index].clone(); - w -= self.subtotal(left_index); - w -= self.subtotal(right_index); - w - } - - /// Removes the last weight and returns it, or [`None`] if it is empty. - pub fn pop(&mut self) -> Option { - self.subtotals.pop().map(|weight| { - let mut index = self.len(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] -= weight.clone(); - } - weight - }) - } - - /// Appends a new weight at the end. - /// - /// Error cases: - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. - pub fn push(&mut self, weight: W) -> Result<(), WeightError> { - if !(weight >= W::ZERO) { - return Err(WeightError::InvalidWeight); - } - if let Some(total) = self.subtotals.first() { - let mut total = total.clone(); - if total.checked_add_assign(&weight).is_err() { - return Err(WeightError::Overflow); - } - } - let mut index = self.len(); - self.subtotals.push(weight.clone()); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index].checked_add_assign(&weight).unwrap(); - } - Ok(()) - } - - /// Updates the weight at an index. - /// - /// Error cases: - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightError> { - if !(weight >= W::ZERO) { - return Err(WeightError::InvalidWeight); - } - let old_weight = self.get(index); - if weight > old_weight { - let mut difference = weight; - difference -= old_weight; - if let Some(total) = self.subtotals.first() { - let mut total = total.clone(); - if total.checked_add_assign(&difference).is_err() { - return Err(WeightError::Overflow); - } - } - self.subtotals[index] - .checked_add_assign(&difference) - .unwrap(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] - .checked_add_assign(&difference) - .unwrap(); - } - } else if weight < old_weight { - let mut difference = old_weight; - difference -= weight; - self.subtotals[index] -= difference.clone(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] -= difference.clone(); - } - } - Ok(()) - } - - fn subtotal(&self, index: usize) -> W { - if index < self.subtotals.len() { - self.subtotals[index].clone() - } else { - W::ZERO - } - } -} - -impl + Weight> - WeightedTreeIndex -{ - /// Samples a randomly selected index from the weighted distribution. - /// - /// Returns an error if there are no elements or all weights are zero. This - /// is unlike [`Distribution::sample`], which panics in those cases. - fn try_sample(&self, rng: &mut R) -> Result { - let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO); - if total_weight == W::ZERO { - return Err(WeightError::InsufficientNonZero); - } - let mut target_weight = rng.gen_range(W::ZERO..total_weight); - let mut index = 0; - loop { - // Maybe descend into the left sub tree. - let left_index = 2 * index + 1; - let left_subtotal = self.subtotal(left_index); - if target_weight < left_subtotal { - index = left_index; - continue; - } - target_weight -= left_subtotal; - - // Maybe descend into the right sub tree. - let right_index = 2 * index + 2; - let right_subtotal = self.subtotal(right_index); - if target_weight < right_subtotal { - index = right_index; - continue; - } - target_weight -= right_subtotal; - - // Otherwise we found the index with the target weight. - break; - } - assert!(target_weight >= W::ZERO); - assert!(target_weight < self.get(index)); - Ok(index) - } -} - -/// Samples a randomly selected index from the weighted distribution. -/// -/// Caution: This method panics if there are no elements or all weights are zero. However, -/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] -/// returns `true`. -impl + Weight> Distribution - for WeightedTreeIndex -{ - fn sample(&self, rng: &mut R) -> usize { - self.try_sample(rng).unwrap() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_no_item_error() { - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - let tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!( - tree.try_sample(&mut rng).unwrap_err(), - WeightError::InsufficientNonZero - ); - } - - #[test] - fn test_overflow_error() { - assert_eq!( - WeightedTreeIndex::new(&[i32::MAX, 2]), - Err(WeightError::Overflow) - ); - let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); - assert_eq!(tree.push(3), Err(WeightError::Overflow)); - assert_eq!(tree.update(1, 4), Err(WeightError::Overflow)); - tree.update(1, 2).unwrap(); - } - - #[test] - fn test_all_weights_zero_error() { - let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - assert_eq!( - tree.try_sample(&mut rng).unwrap_err(), - WeightError::InsufficientNonZero - ); - } - - #[test] - fn test_invalid_weight_error() { - assert_eq!( - WeightedTreeIndex::::new(&[1, -1]).unwrap_err(), - WeightError::InvalidWeight - ); - let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight); - tree.push(1).unwrap(); - assert_eq!( - tree.update(0, -1).unwrap_err(), - WeightError::InvalidWeight - ); - } - - #[test] - fn test_tree_modifications() { - let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap(); - tree.push(3).unwrap(); - tree.push(5).unwrap(); - tree.update(0, 0).unwrap(); - assert_eq!(tree.pop(), Some(5)); - let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap(); - assert_eq!(tree, expected); - } - - #[test] - fn test_sample_counts_match_probabilities() { - let start = 1; - let end = 3; - let samples = 20; - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); - let mut tree = WeightedTreeIndex::new(&weights).unwrap(); - let mut total_weight = 0.0; - let mut weights = alloc::vec![0.0; end]; - for i in 0..end { - tree.update(i, i as f64).unwrap(); - weights[i] = i as f64; - total_weight += i as f64; - } - for i in 0..start { - tree.update(i, 0.0).unwrap(); - weights[i] = 0.0; - total_weight -= i as f64; - } - let mut counts = alloc::vec![0_usize; end]; - for _ in 0..samples { - let i = tree.sample(&mut rng); - counts[i] += 1; - } - for i in 0..start { - assert_eq!(counts[i], 0); - } - for i in start..end { - let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; - assert!(diff.abs() < 0.05); - } - } -} diff --git a/rand_distr/src/ziggurat_tables.rs b/rand_distr/src/ziggurat_tables.rs deleted file mode 100644 index f830a601bdd..00000000000 --- a/rand_distr/src/ziggurat_tables.rs +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// Tables for distributions which are sampled using the ziggurat -// algorithm. Autogenerated by `ziggurat_tables.py`. - -pub type ZigTable = &'static [f64; 257]; -pub const ZIG_NORM_R: f64 = 3.654152885361008796; -#[rustfmt::skip] -pub static ZIG_NORM_X: [f64; 257] = - [3.910757959537090045, 3.654152885361008796, 3.449278298560964462, 3.320244733839166074, - 3.224575052047029100, 3.147889289517149969, 3.083526132001233044, 3.027837791768635434, - 2.978603279880844834, 2.934366867207854224, 2.894121053612348060, 2.857138730872132548, - 2.822877396825325125, 2.790921174000785765, 2.760944005278822555, 2.732685359042827056, - 2.705933656121858100, 2.680514643284522158, 2.656283037575502437, 2.633116393630324570, - 2.610910518487548515, 2.589575986706995181, 2.569035452680536569, 2.549221550323460761, - 2.530075232158516929, 2.511544441625342294, 2.493583041269680667, 2.476149939669143318, - 2.459208374333311298, 2.442725318198956774, 2.426670984935725972, 2.411018413899685520, - 2.395743119780480601, 2.380822795170626005, 2.366237056715818632, 2.351967227377659952, - 2.337996148795031370, 2.324308018869623016, 2.310888250599850036, 2.297723348901329565, - 2.284800802722946056, 2.272108990226823888, 2.259637095172217780, 2.247375032945807760, - 2.235313384928327984, 2.223443340090905718, 2.211756642882544366, 2.200245546609647995, - 2.188902771624720689, 2.177721467738641614, 2.166695180352645966, 2.155817819875063268, - 2.145083634046203613, 2.134487182844320152, 2.124023315687815661, 2.113687150684933957, - 2.103474055713146829, 2.093379631137050279, 2.083399693996551783, 2.073530263516978778, - 2.063767547809956415, 2.054107931648864849, 2.044547965215732788, 2.035084353727808715, - 2.025713947862032960, 2.016433734904371722, 2.007240830558684852, 1.998132471356564244, - 1.989106007615571325, 1.980158896898598364, 1.971288697931769640, 1.962493064942461896, - 1.953769742382734043, 1.945116560006753925, 1.936531428273758904, 1.928012334050718257, - 1.919557336591228847, 1.911164563769282232, 1.902832208548446369, 1.894558525668710081, - 1.886341828534776388, 1.878180486290977669, 1.870072921069236838, 1.862017605397632281, - 1.854013059758148119, 1.846057850283119750, 1.838150586580728607, 1.830289919680666566, - 1.822474540091783224, 1.814703175964167636, 1.806974591348693426, 1.799287584547580199, - 1.791640986550010028, 1.784033659547276329, 1.776464495522344977, 1.768932414909077933, - 1.761436365316706665, 1.753975320315455111, 1.746548278279492994, 1.739154261283669012, - 1.731792314050707216, 1.724461502945775715, 1.717160915015540690, 1.709889657069006086, - 1.702646854797613907, 1.695431651932238548, 1.688243209434858727, 1.681080704722823338, - 1.673943330923760353, 1.666830296159286684, 1.659740822855789499, 1.652674147080648526, - 1.645629517902360339, 1.638606196773111146, 1.631603456932422036, 1.624620582830568427, - 1.617656869570534228, 1.610711622367333673, 1.603784156023583041, 1.596873794420261339, - 1.589979870021648534, 1.583101723393471438, 1.576238702733332886, 1.569390163412534456, - 1.562555467528439657, 1.555733983466554893, 1.548925085471535512, 1.542128153226347553, - 1.535342571438843118, 1.528567729435024614, 1.521803020758293101, 1.515047842773992404, - 1.508301596278571965, 1.501563685112706548, 1.494833515777718391, 1.488110497054654369, - 1.481394039625375747, 1.474683555695025516, 1.467978458615230908, 1.461278162507407830, - 1.454582081885523293, 1.447889631277669675, 1.441200224845798017, 1.434513276002946425, - 1.427828197027290358, 1.421144398672323117, 1.414461289772464658, 1.407778276843371534, - 1.401094763676202559, 1.394410150925071257, 1.387723835686884621, 1.381035211072741964, - 1.374343665770030531, 1.367648583594317957, 1.360949343030101844, 1.354245316759430606, - 1.347535871177359290, 1.340820365893152122, 1.334098153216083604, 1.327368577624624679, - 1.320630975217730096, 1.313884673146868964, 1.307128989027353860, 1.300363230327433728, - 1.293586693733517645, 1.286798664489786415, 1.279998415710333237, 1.273185207661843732, - 1.266358287014688333, 1.259516886060144225, 1.252660221891297887, 1.245787495544997903, - 1.238897891102027415, 1.231990574742445110, 1.225064693752808020, 1.218119375481726552, - 1.211153726239911244, 1.204166830140560140, 1.197157747875585931, 1.190125515422801650, - 1.183069142678760732, 1.175987612011489825, 1.168879876726833800, 1.161744859441574240, - 1.154581450355851802, 1.147388505416733873, 1.140164844363995789, 1.132909248648336975, - 1.125620459211294389, 1.118297174115062909, 1.110938046009249502, 1.103541679420268151, - 1.096106627847603487, 1.088631390649514197, 1.081114409698889389, 1.073554065787871714, - 1.065948674757506653, 1.058296483326006454, 1.050595664586207123, 1.042844313139370538, - 1.035040439828605274, 1.027181966030751292, 1.019266717460529215, 1.011292417434978441, - 1.003256679539591412, 0.995156999629943084, 0.986990747093846266, 0.978755155288937750, - 0.970447311058864615, 0.962064143217605250, 0.953602409875572654, 0.945058684462571130, - 0.936429340280896860, 0.927710533396234771, 0.918898183643734989, 0.909987953490768997, - 0.900975224455174528, 0.891855070726792376, 0.882622229578910122, 0.873271068082494550, - 0.863795545546826915, 0.854189171001560554, 0.844444954902423661, 0.834555354079518752, - 0.824512208745288633, 0.814306670128064347, 0.803929116982664893, 0.793369058833152785, - 0.782615023299588763, 0.771654424216739354, 0.760473406422083165, 0.749056662009581653, - 0.737387211425838629, 0.725446140901303549, 0.713212285182022732, 0.700661841097584448, - 0.687767892786257717, 0.674499822827436479, 0.660822574234205984, 0.646695714884388928, - 0.632072236375024632, 0.616896989996235545, 0.601104617743940417, 0.584616766093722262, - 0.567338257040473026, 0.549151702313026790, 0.529909720646495108, 0.509423329585933393, - 0.487443966121754335, 0.463634336771763245, 0.437518402186662658, 0.408389134588000746, - 0.375121332850465727, 0.335737519180459465, 0.286174591747260509, 0.215241895913273806, - 0.000000000000000000]; -#[rustfmt::skip] -pub static ZIG_NORM_F: [f64; 257] = - [0.000477467764586655, 0.001260285930498598, 0.002609072746106363, 0.004037972593371872, - 0.005522403299264754, 0.007050875471392110, 0.008616582769422917, 0.010214971439731100, - 0.011842757857943104, 0.013497450601780807, 0.015177088307982072, 0.016880083152595839, - 0.018605121275783350, 0.020351096230109354, 0.022117062707379922, 0.023902203305873237, - 0.025705804008632656, 0.027527235669693315, 0.029365939758230111, 0.031221417192023690, - 0.033093219458688698, 0.034980941461833073, 0.036884215688691151, 0.038802707404656918, - 0.040736110656078753, 0.042684144916619378, 0.044646552251446536, 0.046623094902089664, - 0.048613553216035145, 0.050617723861121788, 0.052635418276973649, 0.054666461325077916, - 0.056710690106399467, 0.058767952921137984, 0.060838108349751806, 0.062921024437977854, - 0.065016577971470438, 0.067124653828023989, 0.069245144397250269, 0.071377949059141965, - 0.073522973714240991, 0.075680130359194964, 0.077849336702372207, 0.080030515814947509, - 0.082223595813495684, 0.084428509570654661, 0.086645194450867782, 0.088873592068594229, - 0.091113648066700734, 0.093365311913026619, 0.095628536713353335, 0.097903279039215627, - 0.100189498769172020, 0.102487158942306270, 0.104796225622867056, 0.107116667775072880, - 0.109448457147210021, 0.111791568164245583, 0.114145977828255210, 0.116511665626037014, - 0.118888613443345698, 0.121276805485235437, 0.123676228202051403, 0.126086870220650349, - 0.128508722280473636, 0.130941777174128166, 0.133386029692162844, 0.135841476571757352, - 0.138308116449064322, 0.140785949814968309, 0.143274978974047118, 0.145775208006537926, - 0.148286642733128721, 0.150809290682410169, 0.153343161060837674, 0.155888264725064563, - 0.158444614156520225, 0.161012223438117663, 0.163591108232982951, 0.166181285765110071, - 0.168782774801850333, 0.171395595638155623, 0.174019770082499359, 0.176655321444406654, - 0.179302274523530397, 0.181960655600216487, 0.184630492427504539, 0.187311814224516926, - 0.190004651671193070, 0.192709036904328807, 0.195425003514885592, 0.198152586546538112, - 0.200891822495431333, 0.203642749311121501, 0.206405406398679298, 0.209179834621935651, - 0.211966076307852941, 0.214764175252008499, 0.217574176725178370, 0.220396127481011589, - 0.223230075764789593, 0.226076071323264877, 0.228934165415577484, 0.231804410825248525, - 0.234686861873252689, 0.237581574432173676, 0.240488605941449107, 0.243408015423711988, - 0.246339863502238771, 0.249284212419516704, 0.252241126056943765, 0.255210669955677150, - 0.258192911338648023, 0.261187919133763713, 0.264195763998317568, 0.267216518344631837, - 0.270250256366959984, 0.273297054069675804, 0.276356989296781264, 0.279430141762765316, - 0.282516593084849388, 0.285616426816658109, 0.288729728483353931, 0.291856585618280984, - 0.294997087801162572, 0.298151326697901342, 0.301319396102034120, 0.304501391977896274, - 0.307697412505553769, 0.310907558127563710, 0.314131931597630143, 0.317370638031222396, - 0.320623784958230129, 0.323891482377732021, 0.327173842814958593, 0.330470981380537099, - 0.333783015832108509, 0.337110066638412809, 0.340452257045945450, 0.343809713148291340, - 0.347182563958251478, 0.350570941482881204, 0.353974980801569250, 0.357394820147290515, - 0.360830600991175754, 0.364282468130549597, 0.367750569780596226, 0.371235057669821344, - 0.374736087139491414, 0.378253817247238111, 0.381788410875031348, 0.385340034841733958, - 0.388908860020464597, 0.392495061461010764, 0.396098818517547080, 0.399720314981931668, - 0.403359739222868885, 0.407017284331247953, 0.410693148271983222, 0.414387534042706784, - 0.418100649839684591, 0.421832709231353298, 0.425583931339900579, 0.429354541031341519, - 0.433144769114574058, 0.436954852549929273, 0.440785034667769915, 0.444635565397727750, - 0.448506701509214067, 0.452398706863882505, 0.456311852680773566, 0.460246417814923481, - 0.464202689050278838, 0.468180961407822172, 0.472181538469883255, 0.476204732721683788, - 0.480250865911249714, 0.484320269428911598, 0.488413284707712059, 0.492530263646148658, - 0.496671569054796314, 0.500837575128482149, 0.505028667945828791, 0.509245245998136142, - 0.513487720749743026, 0.517756517232200619, 0.522052074674794864, 0.526374847174186700, - 0.530725304406193921, 0.535103932383019565, 0.539511234259544614, 0.543947731192649941, - 0.548413963257921133, 0.552910490428519918, 0.557437893621486324, 0.561996775817277916, - 0.566587763258951771, 0.571211506738074970, 0.575868682975210544, 0.580559996103683473, - 0.585286179266300333, 0.590047996335791969, 0.594846243770991268, 0.599681752622167719, - 0.604555390700549533, 0.609468064928895381, 0.614420723892076803, 0.619414360609039205, - 0.624450015550274240, 0.629528779928128279, 0.634651799290960050, 0.639820277456438991, - 0.645035480824251883, 0.650298743114294586, 0.655611470583224665, 0.660975147780241357, - 0.666391343912380640, 0.671861719900766374, 0.677388036222513090, 0.682972161648791376, - 0.688616083008527058, 0.694321916130032579, 0.700091918140490099, 0.705928501336797409, - 0.711834248882358467, 0.717811932634901395, 0.723864533472881599, 0.729995264565802437, - 0.736207598131266683, 0.742505296344636245, 0.748892447223726720, 0.755373506511754500, - 0.761953346841546475, 0.768637315803334831, 0.775431304986138326, 0.782341832659861902, - 0.789376143571198563, 0.796542330428254619, 0.803849483176389490, 0.811307874318219935, - 0.818929191609414797, 0.826726833952094231, 0.834716292992930375, 0.842915653118441077, - 0.851346258465123684, 0.860033621203008636, 0.869008688043793165, 0.878309655816146839, - 0.887984660763399880, 0.898095921906304051, 0.908726440060562912, 0.919991505048360247, - 0.932060075968990209, 0.945198953453078028, 0.959879091812415930, 0.977101701282731328, - 1.000000000000000000]; -pub const ZIG_EXP_R: f64 = 7.697117470131050077; -#[rustfmt::skip] -pub static ZIG_EXP_X: [f64; 257] = - [8.697117470131052741, 7.697117470131050077, 6.941033629377212577, 6.478378493832569696, - 6.144164665772472667, 5.882144315795399869, 5.666410167454033697, 5.482890627526062488, - 5.323090505754398016, 5.181487281301500047, 5.054288489981304089, 4.938777085901250530, - 4.832939741025112035, 4.735242996601741083, 4.644491885420085175, 4.559737061707351380, - 4.480211746528421912, 4.405287693473573185, 4.334443680317273007, 4.267242480277365857, - 4.203313713735184365, 4.142340865664051464, 4.084051310408297830, 4.028208544647936762, - 3.974606066673788796, 3.923062500135489739, 3.873417670399509127, 3.825529418522336744, - 3.779270992411667862, 3.734528894039797375, 3.691201090237418825, 3.649195515760853770, - 3.608428813128909507, 3.568825265648337020, 3.530315889129343354, 3.492837654774059608, - 3.456332821132760191, 3.420748357251119920, 3.386035442460300970, 3.352149030900109405, - 3.319047470970748037, 3.286692171599068679, 3.255047308570449882, 3.224079565286264160, - 3.193757903212240290, 3.164053358025972873, 3.134938858084440394, 3.106389062339824481, - 3.078380215254090224, 3.050890016615455114, 3.023897504455676621, 2.997382949516130601, - 2.971327759921089662, 2.945714394895045718, 2.920526286512740821, 2.895747768600141825, - 2.871364012015536371, 2.847360965635188812, 2.823725302450035279, 2.800444370250737780, - 2.777506146439756574, 2.754899196562344610, 2.732612636194700073, 2.710636095867928752, - 2.688959688741803689, 2.667573980773266573, 2.646469963151809157, 2.625639026797788489, - 2.605072938740835564, 2.584763820214140750, 2.564704126316905253, 2.544886627111869970, - 2.525304390037828028, 2.505950763528594027, 2.486819361740209455, 2.467904050297364815, - 2.449198932978249754, 2.430698339264419694, 2.412396812688870629, 2.394289099921457886, - 2.376370140536140596, 2.358635057409337321, 2.341079147703034380, 2.323697874390196372, - 2.306486858283579799, 2.289441870532269441, 2.272558825553154804, 2.255833774367219213, - 2.239262898312909034, 2.222842503111036816, 2.206569013257663858, 2.190438966723220027, - 2.174449009937774679, 2.158595893043885994, 2.142876465399842001, 2.127287671317368289, - 2.111826546019042183, 2.096490211801715020, 2.081275874393225145, 2.066180819490575526, - 2.051202409468584786, 2.036338080248769611, 2.021585338318926173, 2.006941757894518563, - 1.992404978213576650, 1.977972700957360441, 1.963642687789548313, 1.949412758007184943, - 1.935280786297051359, 1.921244700591528076, 1.907302480018387536, 1.893452152939308242, - 1.879691795072211180, 1.866019527692827973, 1.852433515911175554, 1.838931967018879954, - 1.825513128903519799, 1.812175288526390649, 1.798916770460290859, 1.785735935484126014, - 1.772631179231305643, 1.759600930889074766, 1.746643651946074405, 1.733757834985571566, - 1.720942002521935299, 1.708194705878057773, 1.695514524101537912, 1.682900062917553896, - 1.670349953716452118, 1.657862852574172763, 1.645437439303723659, 1.633072416535991334, - 1.620766508828257901, 1.608518461798858379, 1.596327041286483395, 1.584191032532688892, - 1.572109239386229707, 1.560080483527888084, 1.548103603714513499, 1.536177455041032092, - 1.524300908219226258, 1.512472848872117082, 1.500692176842816750, 1.488957805516746058, - 1.477268661156133867, 1.465623682245745352, 1.454021818848793446, 1.442462031972012504, - 1.430943292938879674, 1.419464582769983219, 1.408024891569535697, 1.396623217917042137, - 1.385258568263121992, 1.373929956328490576, 1.362636402505086775, 1.351376933258335189, - 1.340150580529504643, 1.328956381137116560, 1.317793376176324749, 1.306660610415174117, - 1.295557131686601027, 1.284481990275012642, 1.273434238296241139, 1.262412929069615330, - 1.251417116480852521, 1.240445854334406572, 1.229498195693849105, 1.218573192208790124, - 1.207669893426761121, 1.196787346088403092, 1.185924593404202199, 1.175080674310911677, - 1.164254622705678921, 1.153445466655774743, 1.142652227581672841, 1.131873919411078511, - 1.121109547701330200, 1.110358108727411031, 1.099618588532597308, 1.088889961938546813, - 1.078171191511372307, 1.067461226479967662, 1.056759001602551429, 1.046063435977044209, - 1.035373431790528542, 1.024687873002617211, 1.014005623957096480, 1.003325527915696735, - 0.992646405507275897, 0.981967053085062602, 0.971286240983903260, 0.960602711668666509, - 0.949915177764075969, 0.939222319955262286, 0.928522784747210395, 0.917815182070044311, - 0.907098082715690257, 0.896370015589889935, 0.885629464761751528, 0.874874866291025066, - 0.864104604811004484, 0.853317009842373353, 0.842510351810368485, 0.831682837734273206, - 0.820832606554411814, 0.809957724057418282, 0.799056177355487174, 0.788125868869492430, - 0.777164609759129710, 0.766170112735434672, 0.755139984181982249, 0.744071715500508102, - 0.732962673584365398, 0.721810090308756203, 0.710611050909655040, 0.699362481103231959, - 0.688061132773747808, 0.676703568029522584, 0.665286141392677943, 0.653804979847664947, - 0.642255960424536365, 0.630634684933490286, 0.618936451394876075, 0.607156221620300030, - 0.595288584291502887, 0.583327712748769489, 0.571267316532588332, 0.559100585511540626, - 0.546820125163310577, 0.534417881237165604, 0.521885051592135052, 0.509211982443654398, - 0.496388045518671162, 0.483401491653461857, 0.470239275082169006, 0.456886840931420235, - 0.443327866073552401, 0.429543940225410703, 0.415514169600356364, 0.401214678896277765, - 0.386617977941119573, 0.371692145329917234, 0.356399760258393816, 0.340696481064849122, - 0.324529117016909452, 0.307832954674932158, 0.290527955491230394, 0.272513185478464703, - 0.253658363385912022, 0.233790483059674731, 0.212671510630966620, 0.189958689622431842, - 0.165127622564187282, 0.137304980940012589, 0.104838507565818778, 0.063852163815001570, - 0.000000000000000000]; -#[rustfmt::skip] -pub static ZIG_EXP_F: [f64; 257] = - [0.000167066692307963, 0.000454134353841497, 0.000967269282327174, 0.001536299780301573, - 0.002145967743718907, 0.002788798793574076, 0.003460264777836904, 0.004157295120833797, - 0.004877655983542396, 0.005619642207205489, 0.006381905937319183, 0.007163353183634991, - 0.007963077438017043, 0.008780314985808977, 0.009614413642502212, 0.010464810181029981, - 0.011331013597834600, 0.012212592426255378, 0.013109164931254991, 0.014020391403181943, - 0.014945968011691148, 0.015885621839973156, 0.016839106826039941, 0.017806200410911355, - 0.018786700744696024, 0.019780424338009740, 0.020787204072578114, 0.021806887504283581, - 0.022839335406385240, 0.023884420511558174, 0.024942026419731787, 0.026012046645134221, - 0.027094383780955803, 0.028188948763978646, 0.029295660224637411, 0.030414443910466622, - 0.031545232172893622, 0.032687963508959555, 0.033842582150874358, 0.035009037697397431, - 0.036187284781931443, 0.037377282772959382, 0.038578995503074871, 0.039792391023374139, - 0.041017441380414840, 0.042254122413316254, 0.043502413568888197, 0.044762297732943289, - 0.046033761076175184, 0.047316792913181561, 0.048611385573379504, 0.049917534282706379, - 0.051235237055126281, 0.052564494593071685, 0.053905310196046080, 0.055257689676697030, - 0.056621641283742870, 0.057997175631200659, 0.059384305633420280, 0.060783046445479660, - 0.062193415408541036, 0.063615431999807376, 0.065049117786753805, 0.066494496385339816, - 0.067951593421936643, 0.069420436498728783, 0.070901055162371843, 0.072393480875708752, - 0.073897746992364746, 0.075413888734058410, 0.076941943170480517, 0.078481949201606435, - 0.080033947542319905, 0.081597980709237419, 0.083174093009632397, 0.084762330532368146, - 0.086362741140756927, 0.087975374467270231, 0.089600281910032886, 0.091237516631040197, - 0.092887133556043569, 0.094549189376055873, 0.096223742550432825, 0.097910853311492213, - 0.099610583670637132, 0.101322997425953631, 0.103048160171257702, 0.104786139306570145, - 0.106537004050001632, 0.108300825451033755, 0.110077676405185357, 0.111867631670056283, - 0.113670767882744286, 0.115487163578633506, 0.117316899211555525, 0.119160057175327641, - 0.121016721826674792, 0.122886979509545108, 0.124770918580830933, 0.126668629437510671, - 0.128580204545228199, 0.130505738468330773, 0.132445327901387494, 0.134399071702213602, - 0.136367070926428829, 0.138349428863580176, 0.140346251074862399, 0.142357645432472146, - 0.144383722160634720, 0.146424593878344889, 0.148480375643866735, 0.150551185001039839, - 0.152637142027442801, 0.154738369384468027, 0.156854992369365148, 0.158987138969314129, - 0.161134939917591952, 0.163298528751901734, 0.165478041874935922, 0.167673618617250081, - 0.169885401302527550, 0.172113535315319977, 0.174358169171353411, 0.176619454590494829, - 0.178897546572478278, 0.181192603475496261, 0.183504787097767436, 0.185834262762197083, - 0.188181199404254262, 0.190545769663195363, 0.192928149976771296, 0.195328520679563189, - 0.197747066105098818, 0.200183974691911210, 0.202639439093708962, 0.205113656293837654, - 0.207606827724221982, 0.210119159388988230, 0.212650861992978224, 0.215202151075378628, - 0.217773247148700472, 0.220364375843359439, 0.222975768058120111, 0.225607660116683956, - 0.228260293930716618, 0.230933917169627356, 0.233628783437433291, 0.236345152457059560, - 0.239083290262449094, 0.241843469398877131, 0.244625969131892024, 0.247431075665327543, - 0.250259082368862240, 0.253110290015629402, 0.255985007030415324, 0.258883549749016173, - 0.261806242689362922, 0.264753418835062149, 0.267725419932044739, 0.270722596799059967, - 0.273745309652802915, 0.276793928448517301, 0.279868833236972869, 0.282970414538780746, - 0.286099073737076826, 0.289255223489677693, 0.292439288161892630, 0.295651704281261252, - 0.298892921015581847, 0.302163400675693528, 0.305463619244590256, 0.308794066934560185, - 0.312155248774179606, 0.315547685227128949, 0.318971912844957239, 0.322428484956089223, - 0.325917972393556354, 0.329440964264136438, 0.332998068761809096, 0.336589914028677717, - 0.340217149066780189, 0.343880444704502575, 0.347580494621637148, 0.351318016437483449, - 0.355093752866787626, 0.358908472948750001, 0.362762973354817997, 0.366658079781514379, - 0.370594648435146223, 0.374573567615902381, 0.378595759409581067, 0.382662181496010056, - 0.386773829084137932, 0.390931736984797384, 0.395136981833290435, 0.399390684475231350, - 0.403694012530530555, 0.408048183152032673, 0.412454465997161457, 0.416914186433003209, - 0.421428728997616908, 0.425999541143034677, 0.430628137288459167, 0.435316103215636907, - 0.440065100842354173, 0.444876873414548846, 0.449753251162755330, 0.454696157474615836, - 0.459707615642138023, 0.464789756250426511, 0.469944825283960310, 0.475175193037377708, - 0.480483363930454543, 0.485871987341885248, 0.491343869594032867, 0.496901987241549881, - 0.502549501841348056, 0.508289776410643213, 0.514126393814748894, 0.520063177368233931, - 0.526104213983620062, 0.532253880263043655, 0.538516872002862246, 0.544898237672440056, - 0.551403416540641733, 0.558038282262587892, 0.564809192912400615, 0.571723048664826150, - 0.578787358602845359, 0.586010318477268366, 0.593400901691733762, 0.600968966365232560, - 0.608725382079622346, 0.616682180915207878, 0.624852738703666200, 0.633251994214366398, - 0.641896716427266423, 0.650805833414571433, 0.660000841079000145, 0.669506316731925177, - 0.679350572264765806, 0.689566496117078431, 0.700192655082788606, 0.711274760805076456, - 0.722867659593572465, 0.735038092431424039, 0.747868621985195658, 0.761463388849896838, - 0.775956852040116218, 0.791527636972496285, 0.808421651523009044, 0.826993296643051101, - 0.847785500623990496, 0.871704332381204705, 0.900469929925747703, 0.938143680862176477, - 1.000000000000000000]; diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs deleted file mode 100644 index e15b6cdd197..00000000000 --- a/rand_distr/src/zipf.rs +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Zeta and related distributions. - -use num_traits::Float; -use crate::{Distribution, Standard}; -use rand::{Rng, distributions::OpenClosed01}; -use core::fmt; - -/// Samples integers according to the [zeta distribution]. -/// -/// The zeta distribution is a limit of the [`Zipf`] distribution. Sometimes it -/// is called one of the following: discrete Pareto, Riemann-Zeta, Zipf, or -/// Zipf–Estoup distribution. -/// -/// It has the density function `f(k) = k^(-a) / C(a)` for `k >= 1`, where `a` -/// is the parameter and `C(a)` is the Riemann zeta function. -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Zeta; -/// -/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap()); -/// println!("{}", val); -/// ``` -/// -/// # Remarks -/// -/// The zeta distribution has no upper limit. Sampled values may be infinite. -/// In particular, a value of infinity might be returned for the following -/// reasons: -/// 1. it is the best representation in the type `F` of the actual sample. -/// 2. to prevent infinite loops for very small `a`. -/// -/// # Implementation details -/// -/// We are using the algorithm from [Non-Uniform Random Variate Generation], -/// Section 6.1, page 551. -/// -/// [zeta distribution]: https://en.wikipedia.org/wiki/Zeta_distribution -/// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8 -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution -{ - a_minus_1: F, - b: F, -} - -/// Error type returned from `Zeta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ZetaError { - /// `a <= 1` or `nan`. - ATooSmall, -} - -impl fmt::Display for ZetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ZetaError::ATooSmall => "a <= 1 or is NaN in Zeta distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for ZetaError {} - -impl Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution -{ - /// Construct a new `Zeta` distribution with given `a` parameter. - #[inline] - pub fn new(a: F) -> Result, ZetaError> { - if !(a > F::one()) { - return Err(ZetaError::ATooSmall); - } - let a_minus_1 = a - F::one(); - let two = F::one() + F::one(); - Ok(Zeta { - a_minus_1, - b: two.powf(a_minus_1), - }) - } -} - -impl Distribution for Zeta -where F: Float, Standard: Distribution, OpenClosed01: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - loop { - let u = rng.sample(OpenClosed01); - let x = u.powf(-F::one() / self.a_minus_1).floor(); - debug_assert!(x >= F::one()); - if x.is_infinite() { - // For sufficiently small `a`, `x` will always be infinite, - // which is rejected, resulting in an infinite loop. We avoid - // this by always returning infinity instead. - return x; - } - - let t = (F::one() + F::one() / x).powf(self.a_minus_1); - - let v = rng.sample(Standard); - if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { - return x; - } - } - } -} - -/// Samples integers according to the Zipf distribution. -/// -/// The samples follow Zipf's law: The frequency of each sample from a finite -/// set of size `n` is inversely proportional to a power of its frequency rank -/// (with exponent `s`). -/// -/// For large `n`, this converges to the [`Zeta`] distribution. -/// -/// For `s = 0`, this becomes a uniform distribution. -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Zipf; -/// -/// let val: f64 = thread_rng().sample(Zipf::new(10, 1.5).unwrap()); -/// println!("{}", val); -/// ``` -/// -/// # Implementation details -/// -/// Implemented via [rejection sampling](https://en.wikipedia.org/wiki/Rejection_sampling), -/// due to Jason Crease[1]. -/// -/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct Zipf -where F: Float, Standard: Distribution { - s: F, - t: F, - q: F, -} - -/// Error type returned from `Zipf::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ZipfError { - /// `s < 0` or `nan`. - STooSmall, - /// `n < 1`. - NTooSmall, -} - -impl fmt::Display for ZipfError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution", - ZipfError::NTooSmall => "n < 1 in Zipf distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for ZipfError {} - -impl Zipf -where F: Float, Standard: Distribution { - /// Construct a new `Zipf` distribution for a set with `n` elements and a - /// frequency rank exponent `s`. - /// - /// For large `n`, rounding may occur to fit the number into the float type. - #[inline] - pub fn new(n: u64, s: F) -> Result, ZipfError> { - if !(s >= F::zero()) { - return Err(ZipfError::STooSmall); - } - if n < 1 { - return Err(ZipfError::NTooSmall); - } - let n = F::from(n).unwrap(); // This does not fail. - let q = if s != F::one() { - // Make sure to calculate the division only once. - F::one() / (F::one() - s) - } else { - // This value is never used. - F::zero() - }; - let t = if s != F::one() { - (n.powf(F::one() - s) - s) * q - } else { - F::one() + n.ln() - }; - debug_assert!(t > F::zero()); - Ok(Zipf { - s, t, q - }) - } - - /// Inverse cumulative density function - #[inline] - fn inv_cdf(&self, p: F) -> F { - let one = F::one(); - let pt = p * self.t; - if pt <= one { - pt - } else if self.s != one { - (pt * (one - self.s) + self.s).powf(self.q) - } else { - (pt - one).exp() - } - } -} - -impl Distribution for Zipf -where F: Float, Standard: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - let one = F::one(); - loop { - let inv_b = self.inv_cdf(rng.sample(Standard)); - let x = (inv_b + one).floor(); - let mut ratio = x.powf(-self.s); - if x > one { - ratio = ratio * inv_b.powf(self.s) - }; - - let y = rng.sample(Standard); - if y < ratio { - return x; - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn test_samples>( - distr: D, zero: F, expected: &[F], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - #[test] - #[should_panic] - fn zeta_invalid() { - Zeta::new(1.).unwrap(); - } - - #[test] - #[should_panic] - fn zeta_nan() { - Zeta::new(core::f64::NAN).unwrap(); - } - - #[test] - fn zeta_sample() { - let a = 2.0; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_small_a() { - let a = 1. + 1e-15; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_value_stability() { - test_samples(Zeta::new(1.5).unwrap(), 0f32, &[ - 1.0, 2.0, 1.0, 1.0, - ]); - test_samples(Zeta::new(2.0).unwrap(), 0f64, &[ - 2.0, 1.0, 1.0, 1.0, - ]); - } - - #[test] - #[should_panic] - fn zipf_s_too_small() { - Zipf::new(10, -1.).unwrap(); - } - - #[test] - #[should_panic] - fn zipf_n_too_small() { - Zipf::new(0, 1.).unwrap(); - } - - #[test] - #[should_panic] - fn zipf_nan() { - Zipf::new(10, core::f64::NAN).unwrap(); - } - - #[test] - fn zipf_sample() { - let d = Zipf::new(10, 0.5).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zipf_sample_s_1() { - let d = Zipf::new(10, 1.).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zipf_sample_s_0() { - let d = Zipf::new(10, 0.).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - // TODO: verify that this is a uniform distribution - } - - #[test] - fn zipf_sample_large_n() { - let d = Zipf::new(core::u64::MAX, 1.5).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - // TODO: verify that this is a zeta distribution - } - - #[test] - fn zipf_value_stability() { - test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[ - 10.0, 2.0, 6.0, 7.0 - ]); - test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[ - 1.0, 2.0, 3.0, 2.0 - ]); - } - - #[test] - fn zipf_distributions_can_be_compared() { - assert_eq!(Zipf::new(1, 2.0), Zipf::new(1, 2.0)); - } - - #[test] - fn zeta_distributions_can_be_compared() { - assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); - } -} diff --git a/rand_distr/tests/pdf.rs b/rand_distr/tests/pdf.rs deleted file mode 100644 index b4fd7810926..00000000000 --- a/rand_distr/tests/pdf.rs +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![allow(clippy::float_cmp)] - -use average::Histogram; -use rand::Rng; -use rand_distr::{Normal, SkewNormal}; - -const HIST_LEN: usize = 100; -average::define_histogram!(hist, crate::HIST_LEN); -use hist::Histogram as Histogram100; - -mod sparkline; - -#[test] -fn normal() { - const N_SAMPLES: u64 = 1_000_000; - const MEAN: f64 = 2.; - const STD_DEV: f64 = 0.5; - const MIN_X: f64 = -1.; - const MAX_X: f64 = 5.; - - let dist = Normal::new(MEAN, STD_DEV).unwrap(); - let mut hist = Histogram100::with_const_width(MIN_X, MAX_X); - let mut rng = rand::rngs::SmallRng::seed_from_u64(1); - - for _ in 0..N_SAMPLES { - let _ = hist.add(rng.sample(dist)); // Ignore out-of-range values - } - - println!( - "Sampled normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - fn pdf(x: f64) -> f64 { - (-0.5 * ((x - MEAN) / STD_DEV).powi(2)).exp() - / (STD_DEV * (2. * core::f64::consts::PI).sqrt()) - } - - let mut bin_centers = hist.centers(); - let mut expected = [0.; HIST_LEN]; - for e in &mut expected[..] { - *e = pdf(bin_centers.next().unwrap()); - } - - println!( - "Expected normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - let mut diff = [0.; HIST_LEN]; - for (i, n) in hist.normalized_bins().enumerate() { - let bin = (n as f64) / (N_SAMPLES as f64); - diff[i] = (bin - expected[i]).abs(); - } - - println!( - "Difference:\n{}", - sparkline::render_f64_as_string(&diff[..]) - ); - println!( - "max diff: {:?}", - diff.iter().fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - - // Check that the differences are significantly smaller than the expected error. - let mut expected_error = [0.; HIST_LEN]; - // Calculate error from histogram - for (err, var) in expected_error.iter_mut().zip(hist.variances()) { - *err = var.sqrt() / (N_SAMPLES as f64); - } - // Normalize error by bin width - for (err, width) in expected_error.iter_mut().zip(hist.widths()) { - *err /= width; - } - // TODO: Calculate error from distribution cutoff / normalization - - println!( - "max expected_error: {:?}", - expected_error - .iter() - .fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - for (&d, &e) in diff.iter().zip(expected_error.iter()) { - // Difference larger than 4 standard deviations or cutoff - let tol = (4. * e).max(1e-4); - assert!(d <= tol, "Difference = {} * tol", d / tol); - } -} - -#[test] -fn skew_normal() { - const N_SAMPLES: u64 = 1_000_000; - const LOCATION: f64 = 2.; - const SCALE: f64 = 0.5; - const SHAPE: f64 = -3.0; - const MIN_X: f64 = -1.; - const MAX_X: f64 = 4.; - - let dist = SkewNormal::new(LOCATION, SCALE, SHAPE).unwrap(); - let mut hist = Histogram100::with_const_width(MIN_X, MAX_X); - let mut rng = rand::rngs::SmallRng::seed_from_u64(1); - - for _ in 0..N_SAMPLES { - let _ = hist.add(rng.sample(dist)); // Ignore out-of-range values - } - - println!( - "Sampled skew normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - use special::Error; - fn pdf(x: f64) -> f64 { - let x_normalized = (x - LOCATION) / SCALE; - let normal_density_x = - (-0.5 * (x_normalized).powi(2)).exp() / (2. * core::f64::consts::PI).sqrt(); - let normal_distribution_x = - 0.5 * (1.0 + (SHAPE * x_normalized / core::f64::consts::SQRT_2).error()); - 2.0 / SCALE * normal_density_x * normal_distribution_x - } - - let mut bin_centers = hist.centers(); - let mut expected = [0.; HIST_LEN]; - for e in &mut expected[..] { - *e = pdf(bin_centers.next().unwrap()); - } - - println!( - "Expected skew normal distribution:\n{}", - sparkline::render_u64_as_string(hist.bins()) - ); - - let mut diff = [0.; HIST_LEN]; - for (i, n) in hist.normalized_bins().enumerate() { - let bin = (n as f64) / (N_SAMPLES as f64); - diff[i] = (bin - expected[i]).abs(); - } - - println!( - "Difference:\n{}", - sparkline::render_f64_as_string(&diff[..]) - ); - println!( - "max diff: {:?}", - diff.iter().fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - - // Check that the differences are significantly smaller than the expected error. - let mut expected_error = [0.; HIST_LEN]; - // Calculate error from histogram - for (err, var) in expected_error.iter_mut().zip(hist.variances()) { - *err = var.sqrt() / (N_SAMPLES as f64); - } - // Normalize error by bin width - for (err, width) in expected_error.iter_mut().zip(hist.widths()) { - *err /= width; - } - // TODO: Calculate error from distribution cutoff / normalization - - println!( - "max expected_error: {:?}", - expected_error - .iter() - .fold(core::f64::NEG_INFINITY, |a, &b| a.max(b)) - ); - for (&d, &e) in diff.iter().zip(expected_error.iter()) { - // Difference larger than 4 standard deviations or cutoff - let tol = (4. * e).max(1e-4); - assert!(d <= tol, "Difference = {} * tol", d / tol); - } -} diff --git a/rand_distr/tests/sparkline.rs b/rand_distr/tests/sparkline.rs deleted file mode 100644 index 6ba48ba886e..00000000000 --- a/rand_distr/tests/sparkline.rs +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -/// Number of ticks. -const N: usize = 8; -/// Ticks used for the sparkline. -static TICKS: [char; N] = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█']; - -/// Render a sparkline of `data` into `buffer`. -pub fn render_u64(data: &[u64], buffer: &mut String) { - match data.len() { - 0 => { - return; - }, - 1 => { - if data[0] == 0 { - buffer.push(TICKS[0]); - } else { - buffer.push(TICKS[N - 1]); - } - return; - }, - _ => {}, - } - let max = data.iter().max().unwrap(); - let min = data.iter().min().unwrap(); - let scale = ((N - 1) as f64) / ((max - min) as f64); - for i in data { - let tick = (((i - min) as f64) * scale) as usize; - buffer.push(TICKS[tick]); - } -} - -/// Calculate the required capacity for the sparkline, given the length of the -/// input data. -pub fn required_capacity(len: usize) -> usize { - len * TICKS[0].len_utf8() -} - -/// Render a sparkline of `data` into a newly allocated string. -pub fn render_u64_as_string(data: &[u64]) -> String { - let cap = required_capacity(data.len()); - let mut s = String::with_capacity(cap); - render_u64(data, &mut s); - debug_assert_eq!(s.capacity(), cap); - s -} - -/// Render a sparkline of `data` into `buffer`. -pub fn render_f64(data: &[f64], buffer: &mut String) { - match data.len() { - 0 => { - return; - }, - 1 => { - if data[0] == 0. { - buffer.push(TICKS[0]); - } else { - buffer.push(TICKS[N - 1]); - } - return; - }, - _ => {}, - } - for x in data { - assert!(x.is_finite(), "can only render finite values"); - } - let max = data.iter().fold( - core::f64::NEG_INFINITY, |a, &b| a.max(b)); - let min = data.iter().fold( - core::f64::INFINITY, |a, &b| a.min(b)); - let scale = ((N - 1) as f64) / (max - min); - for x in data { - let tick = ((x - min) * scale) as usize; - buffer.push(TICKS[tick]); - } -} - -/// Render a sparkline of `data` into a newly allocated string. -pub fn render_f64_as_string(data: &[f64]) -> String { - let cap = required_capacity(data.len()); - let mut s = String::with_capacity(cap); - render_f64(data, &mut s); - debug_assert_eq!(s.capacity(), cap); - s -} - -#[cfg(test)] -mod tests { - #[test] - fn render_u64() { - let data = [2, 250, 670, 890, 2, 430, 11, 908, 123, 57]; - let mut s = String::with_capacity(super::required_capacity(data.len())); - super::render_u64(&data, &mut s); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } - - #[test] - fn render_u64_as_string() { - let data = [2, 250, 670, 890, 2, 430, 11, 908, 123, 57]; - let s = super::render_u64_as_string(&data); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } - - #[test] - fn render_f64() { - let data = [2., 250., 670., 890., 2., 430., 11., 908., 123., 57.]; - let mut s = String::with_capacity(super::required_capacity(data.len())); - super::render_f64(&data, &mut s); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } - - #[test] - fn render_f64_as_string() { - let data = [2., 250., 670., 890., 2., 430., 11., 908., 123., 57.]; - let s = super::render_f64_as_string(&data); - println!("{}", s); - assert_eq!("▁▂▆▇▁▄▁█▁▁", &s); - } -} diff --git a/rand_distr/tests/uniformity.rs b/rand_distr/tests/uniformity.rs deleted file mode 100644 index d37ef0a9d06..00000000000 --- a/rand_distr/tests/uniformity.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![allow(clippy::float_cmp)] - -use average::Histogram; -use rand::prelude::*; - -const N_BINS: usize = 100; -const N_SAMPLES: u32 = 1_000_000; -const TOL: f64 = 1e-3; -average::define_histogram!(hist, 100); -use hist::Histogram as Histogram100; - -#[test] -fn unit_sphere() { - const N_DIM: usize = 3; - let h = Histogram100::with_const_width(-1., 1.); - let mut histograms = [h.clone(), h.clone(), h]; - let dist = rand_distr::UnitSphere; - let mut rng = rand_pcg::Pcg32::from_entropy(); - for _ in 0..N_SAMPLES { - let v: [f64; 3] = dist.sample(&mut rng); - for i in 0..N_DIM { - histograms[i] - .add(v[i]) - .map_err(|e| { - println!("v: {}", v[i]); - e - }) - .unwrap(); - } - } - for h in &histograms { - let sum: u64 = h.bins().iter().sum(); - println!("{:?}", h); - for &b in h.bins() { - let p = (b as f64) / (sum as f64); - assert!((p - 1.0 / (N_BINS as f64)).abs() < TOL, "{}", p); - } - } -} - -#[test] -fn unit_circle() { - use core::f64::consts::PI; - let mut h = Histogram100::with_const_width(-PI, PI); - let dist = rand_distr::UnitCircle; - let mut rng = rand_pcg::Pcg32::from_entropy(); - for _ in 0..N_SAMPLES { - let v: [f64; 2] = dist.sample(&mut rng); - h.add(v[0].atan2(v[1])).unwrap(); - } - let sum: u64 = h.bins().iter().sum(); - println!("{:?}", h); - for &b in h.bins() { - let p = (b as f64) / (sum as f64); - assert!((p - 1.0 / (N_BINS as f64)).abs() < TOL, "{}", p); - } -} diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs deleted file mode 100644 index 88fe7d9ecab..00000000000 --- a/rand_distr/tests/value_stability.rs +++ /dev/null @@ -1,392 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use average::assert_almost_eq; -use core::fmt::Debug; -use rand::Rng; -use rand_distr::*; - -fn get_rng(seed: u64) -> impl rand::Rng { - // For tests, we want a statistically good, fast, reproducible RNG. - // PCG32 will do fine, and will be easy to embed if we ever need to. - const INC: u64 = 11634580027462260723; - rand_pcg::Pcg32::new(seed, INC) -} - -/// We only assert approximate equality since some platforms do not perform -/// identically (i686-unknown-linux-gnu and most notably x86_64-pc-windows-gnu). -trait ApproxEq { - fn assert_almost_eq(&self, rhs: &Self); -} - -impl ApproxEq for f32 { - fn assert_almost_eq(&self, rhs: &Self) { - assert_almost_eq!(self, rhs, 1e-6); - } -} -impl ApproxEq for f64 { - fn assert_almost_eq(&self, rhs: &Self) { - assert_almost_eq!(self, rhs, 1e-14); - } -} -impl ApproxEq for u64 { - fn assert_almost_eq(&self, rhs: &Self) { - assert_eq!(self, rhs); - } -} -impl ApproxEq for [T; 2] { - fn assert_almost_eq(&self, rhs: &Self) { - self[0].assert_almost_eq(&rhs[0]); - self[1].assert_almost_eq(&rhs[1]); - } -} -impl ApproxEq for [T; 3] { - fn assert_almost_eq(&self, rhs: &Self) { - self[0].assert_almost_eq(&rhs[0]); - self[1].assert_almost_eq(&rhs[1]); - self[2].assert_almost_eq(&rhs[2]); - } -} - -fn test_samples>( - seed: u64, distr: D, expected: &[F], -) { - let mut rng = get_rng(seed); - for val in expected { - let x = rng.sample(&distr); - x.assert_almost_eq(val); - } -} - -#[test] -fn binomial_stability() { - // We have multiple code paths: np < 10, p > 0.5 - test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); - test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); - test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]); -} - -#[test] -fn geometric_stability() { - test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]); - - test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]); - test_samples(464, Geometric::new(0.05).unwrap(), &[24, 51, 81, 67, 27, 11, 7, 6]); - test_samples(464, Geometric::new(0.95).unwrap(), &[0, 0, 0, 0, 1, 0, 0, 0]); - - // expect non-random behaviour for series of pre-determined trials - test_samples(464, Geometric::new(0.0).unwrap(), &[u64::max_value(); 100][..]); - test_samples(464, Geometric::new(1.0).unwrap(), &[0; 100][..]); -} - -#[test] -fn hypergeometric_stability() { - // We have multiple code paths based on the distribution's mode and sample_size - test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN - test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE -} - -#[test] -fn unit_ball_stability() { - test_samples(2, UnitBall, &[ - [0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706], - [0.10588569388223945, -0.4734350111375454, -0.7392104908825501], - [0.11060237642041049, -0.16065642822852677, -0.8444043930440075] - ]); -} - -#[test] -fn unit_circle_stability() { - test_samples(2, UnitCircle, &[ - [-0.9965658683520504f64, -0.08280380447614634], - [-0.9790853270389644, -0.20345004884984505], - [-0.8449189758898707, 0.5348943112253227], - ]); -} - -#[test] -fn unit_sphere_stability() { - test_samples(2, UnitSphere, &[ - [0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027], - [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], - [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], - ]); -} - -#[test] -fn unit_disc_stability() { - test_samples(2, UnitDisc, &[ - [0.018035709265959987f64, -0.4348771383120438], - [-0.07982762085055706, 0.7765329819820659], - [0.21450745997299503, 0.7398636984333291], - ]); -} - -#[test] -fn pareto_stability() { - test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[ - 1.0423688f32, 2.1235929, 4.132709, 1.4679428, - ]); - test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[ - 9.019295276219136f64, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ]); -} - -#[test] -fn poisson_stability() { - test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); - test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); - test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]); -} - - -#[test] -fn triangular_stability() { - test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[ - 5.74373257511361f64, - 7.890059162791258f64, - 4.7256280652553455f64, - 2.9474808121184077f64, - 3.058301946314053f64, - ]); -} - - -#[test] -fn normal_inverse_gaussian_stability() { - test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6568966f32, 1.3744819, 2.216063, 0.11488572, - ]); - test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6838707059642927f64, - 2.4447306460569784, - 0.2361045023235968, - 1.7774534624785319, - ]); -} - -#[test] -fn pert_stability() { - // mean = 4, var = 12/7 - test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[ - 4.908681667460367, - 4.014196196158352, - 2.6489397149197234, - 3.4569780580044727, - 4.242864311947118, - ]); -} - -#[test] -fn inverse_gaussian_stability() { - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[ - 0.9339157f32, 1.108113, 0.50864697, 0.39849377, - ]); - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ - 1.0707604954722476f64, - 0.9628140605340697, - 0.4069687656468226, - 0.660283852985818, - ]); -} - -#[test] -fn gamma_stability() { - // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 - test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[ - 5.398085f32, 9.162783, 0.2300583, 1.7235851, - ]); - test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[ - 0.5051203f32, 0.9048302, 3.095812, 1.8566116, - ]); - test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[ - 7.783878094584059f64, - 1.4939528171618057, - 8.638017638857592, - 3.0949337228829004, - ]); - - // ChiSquared has 2 cases: k == 1, k != 1 - test_samples(223, ChiSquared::new(1.0).unwrap(), &[ - 0.4893526200348249f64, - 1.635249736808788, - 0.5013580219361969, - 0.1457735613733489, - ]); - test_samples(223, ChiSquared::new(0.1).unwrap(), &[ - 0.014824404726978617f64, - 0.021602123937134326, - 0.0000003431429746851693, - 0.00000002291755769542258, - ]); - test_samples(223, ChiSquared::new(10.0).unwrap(), &[ - 12.693656f32, 6.812016, 11.082001, 12.436167, - ]); - - // FisherF has same special cases as ChiSquared on each param - test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[ - 0.32283646f32, 0.048049655, 0.0788893, 1.817178, - ]); - test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[ - 0.29925257f32, 3.4392934, 9.567652, 0.020074, - ]); - test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[ - 3.3196593155045124f64, - 0.3409169916262829, - 0.03377989856426519, - 0.00004041672861036937, - ]); - - // StudentT has same special cases as ChiSquared - test_samples(223, StudentT::new(1.0).unwrap(), &[ - 0.54703987f32, -1.8545331, 3.093162, -0.14168274, - ]); - test_samples(223, StudentT::new(1.1).unwrap(), &[ - 0.7729195887949754f64, - 1.2606210611616204, - -1.7553606501113175, - -2.377641221169782, - ]); - - // Beta has two special cases: - // - // 1. min(alpha, beta) <= 1 - // 2. min(alpha, beta) > 1 - test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[ - 0.8300703726659456, - 0.8134131062097899, - 0.47912589330631555, - 0.25323238071138526, - ]); - test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[ - 0.49563509121756827, - 0.9551305482256759, - 0.5151181353461637, - 0.7551732971235077, - ]); -} - -#[test] -fn exponential_stability() { - test_samples(223, Exp1, &[ - 1.079617f32, 1.8325565, 0.04601166, 0.34471703, - ]); - test_samples(223, Exp1, &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); - - test_samples(223, Exp::new(2.0).unwrap(), &[ - 0.5398085f32, 0.91627824, 0.02300583, 0.17235851, - ]); - test_samples(223, Exp::new(1.0).unwrap(), &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); -} - -#[test] -fn normal_stability() { - test_samples(213, StandardNormal, &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, - ]); - test_samples(213, StandardNormal, &[ - -0.11844188827977231f64, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ]); - - test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, - ]); - test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[ - 1.940779055860114f64, - 2.3906889818886174, - 2.0328199698479, - 1.4033550497906813, - ]); - - test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[ - 0.88830346f32, 2.1844804, 1.0678421, 0.30322206, - ]); - test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[ - 6.964174338639032f64, - 10.921015733601452, - 7.6355881556915906, - 4.068828213584092, - ]); -} - -#[test] -fn weibull_stability() { - test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[ - 0.041495778f32, 0.7531094, 1.4189332, 0.38386202, - ]); - test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[ - 1.1343478702739669f64, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ]); -} - -#[cfg(feature = "alloc")] -#[test] -fn dirichlet_stability() { - let mut rng = get_rng(223); - assert_eq!( - rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), - [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] - ); - assert_eq!(rng.sample(Dirichlet::new([8.0; 5]).unwrap()), [ - 0.17684200044809556, - 0.29915953935953055, - 0.1832858056608014, - 0.1425623503573967, - 0.19815030417417595 - ]); - // Test stability for the case where all alphas are less than 0.1. - assert_eq!( - rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), - [ - 0.00027580456855692104, - 2.296135759821706e-20, - 3.004118281150937e-9, - 0.9997241924273248 - ] - ); -} - -#[test] -fn cauchy_stability() { - test_samples(353, Cauchy::new(100f64, 10.0).unwrap(), &[ - 77.93369152808678f64, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925, - ]); - - // Unfortunately this test is not fully portable due to reliance on the - // system's implementation of tanf (see doc on Cauchy struct). - // We use a lower threshold of 1e-5 here. - let distr = Cauchy::new(10f32, 7.0).unwrap(); - let mut rng = get_rng(353); - let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; - for &a in expected.iter() { - let b = rng.sample(&distr); - assert_almost_eq!(a, b, 1e-5); - } -} diff --git a/rand_pcg/CHANGELOG.md b/rand_pcg/CHANGELOG.md index c60386cf0bc..bab1cd0e8c8 100644 --- a/rand_pcg/CHANGELOG.md +++ b/rand_pcg/CHANGELOG.md @@ -4,11 +4,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.9.0-alpha.0] - 2024-02-18 -This is a pre-release. To depend on this version, use `rand_pcg = "=0.9.0-alpha.0"` to prevent automatic updates (which can be expected to include breaking changes). +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Update to `rand_core` v0.9.0 (#1558) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) +### Other changes - Add `Lcg128CmDxsm64` generator compatible with NumPy's `PCG64DXSM` (#1202) -- Add examples for initializing the RNGs +- Add examples for initializing the RNGs (#1352) +- Revise crate docs (#1454) ## [0.3.1] - 2021-06-15 - Add `advance` methods to RNGs (#1111) diff --git a/rand_pcg/Cargo.toml b/rand_pcg/Cargo.toml index 1fa514344bf..74740950712 100644 --- a/rand_pcg/Cargo.toml +++ b/rand_pcg/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_pcg" -version = "0.9.0-alpha.0" +version = "0.9.0" authors = ["The Rand Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -13,16 +13,18 @@ Selected PCG random number generators keywords = ["random", "rng", "pcg"] categories = ["algorithms", "no-std"] edition = "2021" -rust-version = "1.60" +rust-version = "1.63" [package.metadata.docs.rs] +all-features = true rustdoc-args = ["--generate-link-to-definition"] [features] -serde1 = ["serde"] +serde = ["dep:serde"] +os_rng = ["rand_core/os_rng"] [dependencies] -rand_core = { path = "../rand_core", version = "=0.9.0-alpha.0" } +rand_core = { path = "../rand_core", version = "0.9.0" } serde = { version = "1", features = ["derive"], optional = true } [dev-dependencies] @@ -30,3 +32,4 @@ serde = { version = "1", features = ["derive"], optional = true } # deps yet, see: https://github.com/rust-lang/cargo/issues/1596 # Versions prior to 1.1.4 had incorrect minimal dependencies. bincode = { version = "1.1.4" } +rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } diff --git a/rand_pcg/README.md b/rand_pcg/README.md index da1a1beeffb..50e91e59795 100644 --- a/rand_pcg/README.md +++ b/rand_pcg/README.md @@ -1,11 +1,10 @@ # rand_pcg -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_pcg.svg)](https://crates.io/crates/rand_pcg) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_pcg) [![API](https://docs.rs/rand_pcg/badge.svg)](https://docs.rs/rand_pcg) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.60+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Implements a selection of PCG random number generators. @@ -30,7 +29,7 @@ Links: `rand_pcg` is `no_std` compatible by default. -The `serde1` feature includes implementations of `Serialize` and `Deserialize` +The `serde` feature includes implementations of `Serialize` and `Deserialize` for the included RNGs. ## License diff --git a/rand_pcg/src/lib.rs b/rand_pcg/src/lib.rs index e67728a90e6..6b9d9d833f0 100644 --- a/rand_pcg/src/lib.rs +++ b/rand_pcg/src/lib.rs @@ -8,48 +8,76 @@ //! The PCG random number generators. //! -//! This is a native Rust implementation of a small selection of PCG generators. +//! This is a native Rust implementation of a small selection of [PCG generators]. //! The primary goal of this crate is simple, minimal, well-tested code; in //! other words it is explicitly not a goal to re-implement all of PCG. //! +//! ## Generators +//! //! This crate provides: //! -//! - `Pcg32` aka `Lcg64Xsh32`, officially known as `pcg32`, a general +//! - [`Pcg32`] aka [`Lcg64Xsh32`], officially known as `pcg32`, a general //! purpose RNG. This is a good choice on both 32-bit and 64-bit CPUs //! (for 32-bit output). -//! - `Pcg64` aka `Lcg128Xsl64`, officially known as `pcg64`, a general +//! - [`Pcg64`] aka [`Lcg128Xsl64`], officially known as `pcg64`, a general //! purpose RNG. This is a good choice on 64-bit CPUs. -//! - `Pcg64Mcg` aka `Mcg128Xsl64`, officially known as `pcg64_fast`, +//! - [`Pcg64Mcg`] aka [`Mcg128Xsl64`], officially known as `pcg64_fast`, //! a general purpose RNG using 128-bit multiplications. This has poor //! performance on 32-bit CPUs but is a good choice on 64-bit CPUs for //! both 32-bit and 64-bit output. //! -//! Both of these use 16 bytes of state and 128-bit seeds, and are considered -//! value-stable (i.e. any change affecting the output given a fixed seed would -//! be considered a breaking change to the crate). +//! These generators are all deterministic and portable (see [Reproducibility] +//! in the book), with testing against reference vectors. +//! +//! ## Seeding (construction) +//! +//! Generators implement the [`SeedableRng`] trait. All methods are suitable for +//! seeding. Some suggestions: +//! +//! 1. To automatically seed with a unique seed, use [`SeedableRng::from_rng`] +//! with a master generator (here [`rand::rng()`](https://docs.rs/rand/latest/rand/fn.rng.html)): +//! ```ignore +//! use rand_core::SeedableRng; +//! use rand_pcg::Pcg64Mcg; +//! let rng = Pcg64Mcg::from_rng(&mut rand::rng()); +//! # let _: Pcg64Mcg = rng; +//! ``` +//! 2. Seed **from an integer** via `seed_from_u64`. This uses a hash function +//! internally to yield a (typically) good seed from any input. +//! ``` +//! # use {rand_core::SeedableRng, rand_pcg::Pcg64Mcg}; +//! let rng = Pcg64Mcg::seed_from_u64(1); +//! # let _: Pcg64Mcg = rng; +//! ``` //! -//! # Example +//! See also [Seeding RNGs] in the book. //! -//! To initialize a generator, use the [`SeedableRng`][rand_core::SeedableRng] trait: +//! ## Generation //! +//! Generators implement [`RngCore`], whose methods may be used directly to +//! generate unbounded integer or byte values. //! ``` //! use rand_core::{SeedableRng, RngCore}; //! use rand_pcg::Pcg64Mcg; //! //! let mut rng = Pcg64Mcg::seed_from_u64(0); -//! let x: u32 = rng.next_u32(); +//! let x = rng.next_u64(); +//! assert_eq!(x, 0x5603f242407deca2); //! ``` //! -//! The functionality of this crate is implemented using traits from the `rand_core` crate, but you may use the `rand` -//! crate for further functionality to initialize the generator from various sources and to generate random values: +//! It is often more convenient to use the [`rand::Rng`] trait, which provides +//! further functionality. See also the [Random Values] chapter in the book. //! -//! ```ignore -//! use rand::{Rng, SeedableRng}; -//! use rand_pcg::Pcg64Mcg; -//! -//! let mut rng = Pcg64Mcg::from_entropy(); -//! let x: f64 = rng.gen(); -//! ``` +//! [PCG generators]: https://www.pcg-random.org/ +//! [Reproducibility]: https://rust-random.github.io/book/crate-reprod.html +//! [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +//! [Random Values]: https://rust-random.github.io/book/guide-values.html +//! [`RngCore`]: rand_core::RngCore +//! [`SeedableRng`]: rand_core::SeedableRng +//! [`SeedableRng::from_rng`]: rand_core::SeedableRng#method.from_rng +//! [`rand::rng`]: https://docs.rs/rand/latest/rand/fn.rng.html +//! [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html +//! [`rand_chacha::ChaCha8Rng`]: https://docs.rs/rand_chacha/latest/rand_chacha/struct.ChaCha8Rng.html #![doc( html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", @@ -65,6 +93,8 @@ mod pcg128; mod pcg128cm; mod pcg64; +pub use rand_core; + pub use self::pcg128::{Lcg128Xsl64, Mcg128Xsl64, Pcg64, Pcg64Mcg}; pub use self::pcg128cm::{Lcg128CmDxsm64, Pcg64Dxsm}; pub use self::pcg64::{Lcg64Xsh32, Pcg32}; diff --git a/rand_pcg/src/pcg128.rs b/rand_pcg/src/pcg128.rs index df2025dc444..d2341425673 100644 --- a/rand_pcg/src/pcg128.rs +++ b/rand_pcg/src/pcg128.rs @@ -14,8 +14,9 @@ const MULTIPLIER: u128 = 0x2360_ED05_1FC6_5DA4_4385_DF64_9FCC_F645; use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A PCG random number generator (XSL RR 128/64 (LCG) variant). /// @@ -33,7 +34,7 @@ use rand_core::{impls, le, Error, RngCore, SeedableRng}; /// Note that two generators with different stream parameters may be closely /// correlated. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg128Xsl64 { state: u128, increment: u128, @@ -151,15 +152,8 @@ impl RngCore for Lcg128Xsl64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } - /// A PCG random number generator (XSL 128/64 (MCG) variant). /// /// Permuted Congruential Generator with 128-bit state, internal Multiplicative @@ -172,7 +166,7 @@ impl RngCore for Lcg128Xsl64 { /// output function), this RNG is faster, also has a long cycle, and still has /// good performance on statistical tests. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Mcg128Xsl64 { state: u128, } @@ -240,7 +234,7 @@ impl SeedableRng for Mcg128Xsl64 { // Read as if a little-endian u128 value: let mut seed_u64 = [0u64; 2]; le::read_u64_into(&seed, &mut seed_u64); - let state = u128::from(seed_u64[0]) | u128::from(seed_u64[1]) << 64; + let state = u128::from(seed_u64[0]) | (u128::from(seed_u64[1]) << 64); Mcg128Xsl64::new(state) } } @@ -261,12 +255,6 @@ impl RngCore for Mcg128Xsl64 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[inline(always)] diff --git a/rand_pcg/src/pcg128cm.rs b/rand_pcg/src/pcg128cm.rs index 7ac5187e4e0..a5a2b178795 100644 --- a/rand_pcg/src/pcg128cm.rs +++ b/rand_pcg/src/pcg128cm.rs @@ -14,8 +14,9 @@ const MULTIPLIER: u64 = 15750249268501108917; use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A PCG random number generator (CM DXSM 128/64 (LCG) variant). /// @@ -36,7 +37,7 @@ use rand_core::{impls, le, Error, RngCore, SeedableRng}; /// /// [upgrading-pcg64]: https://numpy.org/doc/stable/reference/random/upgrading-pcg64.html #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg128CmDxsm64 { state: u128, increment: u128, @@ -148,21 +149,15 @@ impl RngCore for Lcg128CmDxsm64 { #[inline] fn next_u64(&mut self) -> u64 { - let val = output_dxsm(self.state); + let res = output_dxsm(self.state); self.step(); - val + res } #[inline] fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } #[inline(always)] diff --git a/rand_pcg/src/pcg64.rs b/rand_pcg/src/pcg64.rs index 365f1c0b117..771a996d28f 100644 --- a/rand_pcg/src/pcg64.rs +++ b/rand_pcg/src/pcg64.rs @@ -11,8 +11,9 @@ //! PCG random number generators use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; // This is the default multiplier used by PCG for 64-bit state. const MULTIPLIER: u64 = 6364136223846793005; @@ -33,7 +34,7 @@ const MULTIPLIER: u64 = 6364136223846793005; /// Note that two generators with different stream parameter may be closely /// correlated. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg64Xsh32 { state: u64, increment: u64, @@ -160,10 +161,4 @@ impl RngCore for Lcg64Xsh32 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } diff --git a/rand_pcg/tests/lcg128cmdxsm64.rs b/rand_pcg/tests/lcg128cmdxsm64.rs index b254b4fac66..b5b37f582e0 100644 --- a/rand_pcg/tests/lcg128cmdxsm64.rs +++ b/rand_pcg/tests/lcg128cmdxsm64.rs @@ -23,7 +23,7 @@ fn test_lcg128cmdxsm64_construction() { let mut rng1 = Lcg128CmDxsm64::from_seed(seed); assert_eq!(rng1.next_u64(), 12201417210360370199); - let mut rng2 = Lcg128CmDxsm64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg128CmDxsm64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 11487972556150888383); let mut rng3 = Lcg128CmDxsm64::seed_from_u64(0); @@ -54,7 +54,7 @@ fn test_lcg128cmdxsm64_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg128cmdxsm64_serde() { use bincode; diff --git a/rand_pcg/tests/lcg128xsl64.rs b/rand_pcg/tests/lcg128xsl64.rs index 31eada442eb..07bd6137da9 100644 --- a/rand_pcg/tests/lcg128xsl64.rs +++ b/rand_pcg/tests/lcg128xsl64.rs @@ -23,7 +23,7 @@ fn test_lcg128xsl64_construction() { let mut rng1 = Lcg128Xsl64::from_seed(seed); assert_eq!(rng1.next_u64(), 8740028313290271629); - let mut rng2 = Lcg128Xsl64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg128Xsl64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 1922280315005786345); let mut rng3 = Lcg128Xsl64::seed_from_u64(0); @@ -54,7 +54,7 @@ fn test_lcg128xsl64_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg128xsl64_serde() { use bincode; diff --git a/rand_pcg/tests/lcg64xsh32.rs b/rand_pcg/tests/lcg64xsh32.rs index 9c181ee3a45..ea704a50f6f 100644 --- a/rand_pcg/tests/lcg64xsh32.rs +++ b/rand_pcg/tests/lcg64xsh32.rs @@ -21,7 +21,7 @@ fn test_lcg64xsh32_construction() { let mut rng1 = Lcg64Xsh32::from_seed(seed); assert_eq!(rng1.next_u64(), 1204678643940597513); - let mut rng2 = Lcg64Xsh32::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg64Xsh32::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 12384929573776311845); let mut rng3 = Lcg64Xsh32::seed_from_u64(0); @@ -47,7 +47,7 @@ fn test_lcg64xsh32_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg64xsh32_serde() { use bincode; diff --git a/rand_pcg/tests/mcg128xsl64.rs b/rand_pcg/tests/mcg128xsl64.rs index 1f352b6e879..6125f1998c2 100644 --- a/rand_pcg/tests/mcg128xsl64.rs +++ b/rand_pcg/tests/mcg128xsl64.rs @@ -21,7 +21,7 @@ fn test_mcg128xsl64_construction() { let mut rng1 = Mcg128Xsl64::from_seed(seed); assert_eq!(rng1.next_u64(), 7071994460355047496); - let mut rng2 = Mcg128Xsl64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Mcg128Xsl64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 12300796107712034932); let mut rng3 = Mcg128Xsl64::seed_from_u64(0); @@ -52,7 +52,7 @@ fn test_mcg128xsl64_reference() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_mcg128xsl64_serde() { use bincode; diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index ded1e7812fb..00000000000 --- a/rustfmt.toml +++ /dev/null @@ -1,32 +0,0 @@ -# This rustfmt file is added for configuration, but in practice much of our -# code is hand-formatted, frequently with more readable results. - -# Comments: -normalize_comments = true -wrap_comments = false -comment_width = 90 # small excess is okay but prefer 80 - -# Arguments: -use_small_heuristics = "Default" -# TODO: single line functions only where short, please? -# https://github.com/rust-lang/rustfmt/issues/3358 -fn_single_line = false -fn_args_layout = "Compressed" -overflow_delimited_expr = true -where_single_line = true - -# enum_discrim_align_threshold = 20 -# struct_field_align_threshold = 20 - -# Compatibility: -edition = "2021" - -# Misc: -inline_attribute_width = 80 -blank_lines_upper_bound = 2 -reorder_impl_items = true -# report_todo = "Unnumbered" -# report_fixme = "Unnumbered" - -# Ignored files: -ignore = [] diff --git a/src/distributions/bernoulli.rs b/src/distr/bernoulli.rs similarity index 80% rename from src/distributions/bernoulli.rs rename to src/distr/bernoulli.rs index 78bd724d789..6803518e376 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distr/bernoulli.rs @@ -6,26 +6,35 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Bernoulli distribution. +//! The Bernoulli distribution `Bernoulli(p)`. -use crate::distributions::Distribution; +use crate::distr::Distribution; use crate::Rng; -use core::{fmt, u64}; +use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -/// The Bernoulli distribution. +/// The [Bernoulli distribution](https://en.wikipedia.org/wiki/Bernoulli_distribution) `Bernoulli(p)`. /// -/// This is a special case of the Binomial distribution where `n = 1`. +/// This distribution describes a single boolean random variable, which is true +/// with probability `p` and false with probability `1 - p`. +/// It is a special case of the Binomial distribution with `n = 1`. +/// +/// # Plot +/// +/// The following plot shows the Bernoulli distribution with `p = 0.1`, +/// `p = 0.5`, and `p = 0.9`. +/// +/// ![Bernoulli distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/bernoulli.svg) /// /// # Example /// /// ```rust -/// use rand::distributions::{Bernoulli, Distribution}; +/// use rand::distr::{Bernoulli, Distribution}; /// /// let d = Bernoulli::new(0.3).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a Bernoulli distribution", v); /// ``` /// @@ -35,7 +44,7 @@ use serde::{Serialize, Deserialize}; /// so only probabilities that are multiples of 2-64 can be /// represented. #[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Bernoulli { /// Probability of success, relative to the maximal integer. p_int: u64, @@ -66,7 +75,7 @@ const ALWAYS_TRUE: u64 = u64::MAX; // in `no_std` mode. const SCALE: f64 = 2.0 * (1u64 << 63) as f64; -/// Error type returned from `Bernoulli::new`. +/// Error type returned from [`Bernoulli::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BernoulliError { /// `p < 0` or `p > 1`. @@ -82,7 +91,7 @@ impl fmt::Display for BernoulliError { } #[cfg(feature = "std")] -impl ::std::error::Error for BernoulliError {} +impl std::error::Error for BernoulliError {} impl Bernoulli { /// Construct a new `Bernoulli` with the given probability of success `p`. @@ -127,6 +136,18 @@ impl Bernoulli { let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64; Ok(Bernoulli { p_int }) } + + #[inline] + /// Returns the probability (`p`) of the distribution. + /// + /// This value may differ slightly from the input due to loss of precision. + pub fn p(&self) -> f64 { + if self.p_int == ALWAYS_TRUE { + 1.0 + } else { + (self.p_int as f64) / SCALE + } + } } impl Distribution for Bernoulli { @@ -136,7 +157,7 @@ impl Distribution for Bernoulli { if self.p_int == ALWAYS_TRUE { return true; } - let v: u64 = rng.gen(); + let v: u64 = rng.random(); v < self.p_int } } @@ -144,14 +165,15 @@ impl Distribution for Bernoulli { #[cfg(test)] mod test { use super::Bernoulli; - use crate::distributions::Distribution; + use crate::distr::Distribution; use crate::Rng; #[test] - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] fn test_serializing_deserializing_bernoulli() { let coin_flip = Bernoulli::new(0.5).unwrap(); - let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); + let de_coin_flip: Bernoulli = + bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); assert_eq!(coin_flip.p_int, de_coin_flip.p_int); } @@ -208,9 +230,10 @@ mod test { for x in &mut buf { *x = rng.sample(distr); } - assert_eq!(buf, [ - true, false, false, true, false, false, true, true, true, true - ]); + assert_eq!( + buf, + [true, false, false, true, false, false, true, true, true, true] + ); } #[test] diff --git a/src/distributions/distribution.rs b/src/distr/distribution.rs similarity index 72% rename from src/distributions/distribution.rs rename to src/distr/distribution.rs index 18ab30b8860..48598ec0fde 100644 --- a/src/distributions/distribution.rs +++ b/src/distr/distribution.rs @@ -10,9 +10,9 @@ //! Distribution trait and associates use crate::Rng; -use core::iter; #[cfg(feature = "alloc")] use alloc::string::String; +use core::iter; /// Types (distributions) that can be used to create a random instance of `T`. /// @@ -48,13 +48,12 @@ pub trait Distribution { /// # Example /// /// ``` - /// use rand::thread_rng; - /// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard}; + /// use rand::distr::{Distribution, Alphanumeric, Uniform, StandardUniform}; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Vec of 16 x f32: - /// let v: Vec = Standard.sample_iter(&mut rng).take(16).collect(); + /// let v: Vec = StandardUniform.sample_iter(&mut rng).take(16).collect(); /// /// // String: /// let s: String = Alphanumeric @@ -70,69 +69,66 @@ pub trait Distribution { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, rng: R) -> DistIter + fn sample_iter(self, rng: R) -> Iter where R: Rng, Self: Sized, { - DistIter { + Iter { distr: self, rng, - phantom: ::core::marker::PhantomData, + phantom: core::marker::PhantomData, } } - /// Create a distribution of values of 'S' by mapping the output of `Self` - /// through the closure `F` + /// Map sampled values to type `S` /// /// # Example /// /// ``` - /// use rand::thread_rng; - /// use rand::distributions::{Distribution, Uniform}; - /// - /// let mut rng = thread_rng(); + /// use rand::distr::{Distribution, Uniform}; /// /// let die = Uniform::new_inclusive(1, 6).unwrap(); /// let even_number = die.map(|num| num % 2 == 0); - /// while !even_number.sample(&mut rng) { + /// while !even_number.sample(&mut rand::rng()) { /// println!("Still odd; rolling again!"); /// } /// ``` - fn map(self, func: F) -> DistMap + fn map(self, func: F) -> Map where F: Fn(T) -> S, Self: Sized, { - DistMap { + Map { distr: self, func, - phantom: ::core::marker::PhantomData, + phantom: core::marker::PhantomData, } } } -impl<'a, T, D: Distribution + ?Sized> Distribution for &'a D { +impl + ?Sized> Distribution for &D { fn sample(&self, rng: &mut R) -> T { (*self).sample(rng) } } -/// An iterator that generates random values of `T` with distribution `D`, -/// using `R` as the source of randomness. +/// An iterator over a [`Distribution`] /// -/// This `struct` is created by the [`sample_iter`] method on [`Distribution`]. -/// See its documentation for more. +/// This iterator yields random values of type `T` with distribution `D` +/// from a random generator of type `R`. /// -/// [`sample_iter`]: Distribution::sample_iter +/// Construct this `struct` using [`Distribution::sample_iter`] or +/// [`Rng::sample_iter`]. It is also used by [`Rng::random_iter`] and +/// [`crate::random_iter`]. #[derive(Debug)] -pub struct DistIter { +pub struct Iter { distr: D, rng: R, - phantom: ::core::marker::PhantomData, + phantom: core::marker::PhantomData, } -impl Iterator for DistIter +impl Iterator for Iter where D: Distribution, R: Rng, @@ -148,30 +144,29 @@ where } fn size_hint(&self) -> (usize, Option) { - (usize::max_value(), None) + (usize::MAX, None) } } -impl iter::FusedIterator for DistIter +impl iter::FusedIterator for Iter where D: Distribution, R: Rng, { } -/// A distribution of values of type `S` derived from the distribution `D` -/// by mapping its output of type `T` through the closure `F`. +/// A [`Distribution`] which maps sampled values to type `S` /// /// This `struct` is created by the [`Distribution::map`] method. /// See its documentation for more. #[derive(Debug)] -pub struct DistMap { +pub struct Map { distr: D, func: F, - phantom: ::core::marker::PhantomData S>, + phantom: core::marker::PhantomData S>, } -impl Distribution for DistMap +impl Distribution for Map where D: Distribution, F: Fn(T) -> S, @@ -181,16 +176,23 @@ where } } -/// `String` sampler +/// Sample or extend a [`String`] /// -/// Sampling a `String` of random characters is not quite the same as collecting -/// a sequence of chars. This trait contains some helpers. +/// Helper methods to extend a [`String`] or sample a new [`String`]. #[cfg(feature = "alloc")] -pub trait DistString { +pub trait SampleString { /// Append `len` random chars to `string` + /// + /// Note: implementations may leave `string` with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. fn append_string(&self, rng: &mut R, string: &mut String, len: usize); - /// Generate a `String` of `len` random chars + /// Generate a [`String`] of `len` random chars + /// + /// Note: implementations may leave the string with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. #[inline] fn sample_string(&self, rng: &mut R, len: usize) -> String { let mut s = String::new(); @@ -201,12 +203,12 @@ pub trait DistString { #[cfg(test)] mod tests { - use crate::distributions::{Distribution, Uniform}; + use crate::distr::{Distribution, Uniform}; use crate::Rng; #[test] fn test_distributions_iter() { - use crate::distributions::Open01; + use crate::distr::Open01; let mut rng = crate::test::rng(210); let distr = Open01; let mut iter = Distribution::::sample_iter(distr, &mut rng); @@ -228,9 +230,7 @@ mod tests { #[test] fn test_make_an_iter() { - fn ten_dice_rolls_other_than_five( - rng: &mut R, - ) -> impl Iterator + '_ { + fn ten_dice_rolls_other_than_five(rng: &mut R) -> impl Iterator + '_ { Uniform::new_inclusive(1, 6) .unwrap() .sample_iter(rng) @@ -250,16 +250,20 @@ mod tests { #[test] #[cfg(feature = "alloc")] fn test_dist_string() { + use crate::distr::{Alphabetic, Alphanumeric, SampleString, StandardUniform}; use core::str; - use crate::distributions::{Alphanumeric, DistString, Standard}; let mut rng = crate::test::rng(213); let s1 = Alphanumeric.sample_string(&mut rng, 20); assert_eq!(s1.len(), 20); assert_eq!(str::from_utf8(s1.as_bytes()), Ok(s1.as_str())); - let s2 = Standard.sample_string(&mut rng, 20); + let s2 = StandardUniform.sample_string(&mut rng, 20); assert_eq!(s2.chars().count(), 20); assert_eq!(str::from_utf8(s2.as_bytes()), Ok(s2.as_str())); + + let s3 = Alphabetic.sample_string(&mut rng, 20); + assert_eq!(s3.len(), 20); + assert_eq!(str::from_utf8(s3.as_bytes()), Ok(s3.as_str())); } } diff --git a/src/distributions/float.rs b/src/distr/float.rs similarity index 58% rename from src/distributions/float.rs rename to src/distr/float.rs index 37b71612a1b..44c33e99268 100644 --- a/src/distributions/float.rs +++ b/src/distr/float.rs @@ -8,14 +8,15 @@ //! Basic floating-point number distributions -use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils}; -use crate::distributions::{Distribution, Standard}; +use crate::distr::utils::{FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; +use crate::distr::{Distribution, StandardUniform}; use crate::Rng; use core::mem; -#[cfg(feature = "simd_support")] use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A distribution to sample floating point numbers uniformly in the half-open /// interval `(0, 1]`, i.e. including 1 but not 0. @@ -25,24 +26,24 @@ use serde::{Serialize, Deserialize}; /// 53 most significant bits of a `u64` are used. The conversion uses the /// multiplicative method. /// -/// See also: [`Standard`] which samples from `[0, 1)`, [`Open01`] +/// See also: [`StandardUniform`] which samples from `[0, 1)`, [`Open01`] /// which samples from `(0, 1)` and [`Uniform`] which samples from arbitrary /// ranges. /// /// # Example /// ``` -/// use rand::{thread_rng, Rng}; -/// use rand::distributions::OpenClosed01; +/// use rand::Rng; +/// use rand::distr::OpenClosed01; /// -/// let val: f32 = thread_rng().sample(OpenClosed01); +/// let val: f32 = rand::rng().sample(OpenClosed01); /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: crate::distributions::Standard -/// [`Open01`]: crate::distributions::Open01 -/// [`Uniform`]: crate::distributions::uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`Open01`]: crate::distr::Open01 +/// [`Uniform`]: crate::distr::uniform::Uniform +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct OpenClosed01; /// A distribution to sample floating point numbers uniformly in the open @@ -52,27 +53,26 @@ pub struct OpenClosed01; /// the 23 most significant random bits of an `u32` are used, for `f64` 52 from /// an `u64`. The conversion uses a transmute-based method. /// -/// See also: [`Standard`] which samples from `[0, 1)`, [`OpenClosed01`] +/// See also: [`StandardUniform`] which samples from `[0, 1)`, [`OpenClosed01`] /// which samples from `(0, 1]` and [`Uniform`] which samples from arbitrary /// ranges. /// /// # Example /// ``` -/// use rand::{thread_rng, Rng}; -/// use rand::distributions::Open01; +/// use rand::Rng; +/// use rand::distr::Open01; /// -/// let val: f32 = thread_rng().sample(Open01); +/// let val: f32 = rand::rng().sample(Open01); /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: crate::distributions::Standard -/// [`OpenClosed01`]: crate::distributions::OpenClosed01 -/// [`Uniform`]: crate::distributions::uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`OpenClosed01`]: crate::distr::OpenClosed01 +/// [`Uniform`]: crate::distr::uniform::Uniform +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Open01; - // This trait is needed by both this lib and rand_distr hence is a hidden export #[doc(hidden)] pub trait IntoFloat { @@ -90,8 +90,9 @@ pub trait IntoFloat { } macro_rules! float_impls { - ($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, + ($($meta:meta)?, $ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, $fraction_bits:expr, $exponent_bias:expr) => { + $(#[cfg($meta)])? impl IntoFloat for $uty { type F = $ty; #[inline(always)] @@ -103,7 +104,8 @@ macro_rules! float_impls { } } - impl Distribution<$ty> for Standard { + $(#[cfg($meta)])? + impl Distribution<$ty> for StandardUniform { fn sample(&self, rng: &mut R) -> $ty { // Multiply-based method; 24/53 random bits; [0, 1) interval. // We use the most significant bits because for simple RNGs @@ -112,12 +114,13 @@ macro_rules! float_impls { let precision = $fraction_bits + 1; let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); - let value: $uty = rng.gen(); + let value: $uty = rng.random(); let value = value >> $uty::splat(float_size - precision); $ty::splat(scale) * $ty::cast_from_int(value) } } + $(#[cfg($meta)])? impl Distribution<$ty> for OpenClosed01 { fn sample(&self, rng: &mut R) -> $ty { // Multiply-based method; 24/53 random bits; (0, 1] interval. @@ -127,56 +130,55 @@ macro_rules! float_impls { let precision = $fraction_bits + 1; let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); - let value: $uty = rng.gen(); + let value: $uty = rng.random(); let value = value >> $uty::splat(float_size - precision); // Add 1 to shift up; will not overflow because of right-shift: $ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1)) } } + $(#[cfg($meta)])? impl Distribution<$ty> for Open01 { fn sample(&self, rng: &mut R) -> $ty { // Transmute-based method; 23/52 random bits; (0, 1) interval. // We use the most significant bits because for simple RNGs // those are usually more random. - use core::$f_scalar::EPSILON; let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; - let value: $uty = rng.gen(); + let value: $uty = rng.random(); let fraction = value >> $uty::splat(float_size - $fraction_bits); - fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0) + fraction.into_float_with_exponent(0) - $ty::splat(1.0 - $f_scalar::EPSILON / 2.0) } } } } -float_impls! { f32, u32, f32, u32, 23, 127 } -float_impls! { f64, u64, f64, u64, 52, 1023 } +float_impls! { , f32, u32, f32, u32, 23, 127 } +float_impls! { , f64, u64, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f32x2, u32x2, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x2, u32x2, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x4, u32x4, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x4, u32x4, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x8, u32x8, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x8, u32x8, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x16, u32x16, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x16, u32x16, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f64x2, u64x2, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x2, u64x2, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f64x4, u64x4, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x4, u64x4, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f64x8, u64x8, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x8, u64x8, f64, u64, 52, 1023 } #[cfg(test)] mod tests { use super::*; - use crate::distributions::utils::FloatAsSIMD; - use crate::rngs::mock::StepRng; + use crate::test::const_rng; - const EPSILON32: f32 = ::core::f32::EPSILON; - const EPSILON64: f64 = ::core::f64::EPSILON; + const EPSILON32: f32 = f32::EPSILON; + const EPSILON64: f64 = f64::EPSILON; macro_rules! test_f32 { ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { @@ -184,29 +186,35 @@ mod tests { fn $fnn() { let two = $ty::splat(2.0); - // Standard - let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.gen::<$ty>(), $ZERO); - let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); - assert_eq!(one.gen::<$ty>(), $EPSILON / two); - let mut max = StepRng::new(!0, 0); - assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two); + // StandardUniform + let mut zeros = const_rng(0); + assert_eq!(zeros.random::<$ty>(), $ZERO); + let mut one = const_rng(1 << 8 | 1 << (8 + 32)); + assert_eq!(one.random::<$ty>(), $EPSILON / two); + let mut max = const_rng(!0); + assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); + let mut one = const_rng(1 << 8 | 1 << (8 + 32)); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0)); - let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two); + let mut one = const_rng(1 << 9 | 1 << (9 + 32)); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); + let mut max = const_rng(!0); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -226,29 +234,35 @@ mod tests { fn $fnn() { let two = $ty::splat(2.0); - // Standard - let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.gen::<$ty>(), $ZERO); - let mut one = StepRng::new(1 << 11, 0); - assert_eq!(one.gen::<$ty>(), $EPSILON / two); - let mut max = StepRng::new(!0, 0); - assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two); + // StandardUniform + let mut zeros = const_rng(0); + assert_eq!(zeros.random::<$ty>(), $ZERO); + let mut one = const_rng(1 << 11); + assert_eq!(one.random::<$ty>(), $EPSILON / two); + let mut max = const_rng(!0); + assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 11, 0); + let mut one = const_rng(1 << 11); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 12, 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0)); - let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two); + let mut one = const_rng(1 << 12); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); + let mut max = const_rng(!0); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -263,7 +277,9 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: &D, zero: T, expected: &[T], + distr: &D, + zero: T, + expected: &[T], ) { let mut rng = crate::test::rng(0x6f44f5646c2a7334); let mut buf = [zero; 3]; @@ -273,26 +289,30 @@ mod tests { assert_eq!(&buf, expected); } - test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]); - test_samples(&Standard, 0f64, &[ - 0.7346051961657583, - 0.20298547462974248, - 0.8166436635290655, - ]); + test_samples( + &StandardUniform, + 0f32, + &[0.0035963655, 0.7346052, 0.09778172], + ); + test_samples( + &StandardUniform, + 0f64, + &[0.7346051961657583, 0.20298547462974248, 0.8166436635290655], + ); test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]); - test_samples(&OpenClosed01, 0f64, &[ - 0.7346051961657584, - 0.2029854746297426, - 0.8166436635290656, - ]); + test_samples( + &OpenClosed01, + 0f64, + &[0.7346051961657584, 0.2029854746297426, 0.8166436635290656], + ); test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]); - test_samples(&Open01, 0f64, &[ - 0.7346051961657584, - 0.20298547462974248, - 0.8166436635290656, - ]); + test_samples( + &Open01, + 0f64, + &[0.7346051961657584, 0.20298547462974248, 0.8166436635290656], + ); #[cfg(feature = "simd_support")] { @@ -300,17 +320,25 @@ mod tests { // non-SIMD types; we assume this pattern continues across all // SIMD types. - test_samples(&Standard, f32x2::from([0.0, 0.0]), &[ - f32x2::from([0.0035963655, 0.7346052]), - f32x2::from([0.09778172, 0.20298547]), - f32x2::from([0.34296435, 0.81664366]), - ]); - - test_samples(&Standard, f64x2::from([0.0, 0.0]), &[ - f64x2::from([0.7346051961657583, 0.20298547462974248]), - f64x2::from([0.8166436635290655, 0.7423708925400552]), - f64x2::from([0.16387782224016323, 0.9087068770169618]), - ]); + test_samples( + &StandardUniform, + f32x2::from([0.0, 0.0]), + &[ + f32x2::from([0.0035963655, 0.7346052]), + f32x2::from([0.09778172, 0.20298547]), + f32x2::from([0.34296435, 0.81664366]), + ], + ); + + test_samples( + &StandardUniform, + f64x2::from([0.0, 0.0]), + &[ + f64x2::from([0.7346051961657583, 0.20298547462974248]), + f64x2::from([0.8166436635290655, 0.7423708925400552]), + f64x2::from([0.16387782224016323, 0.9087068770169618]), + ], + ); } } } diff --git a/src/distr/integer.rs b/src/distr/integer.rs new file mode 100644 index 00000000000..37b2081c471 --- /dev/null +++ b/src/distr/integer.rs @@ -0,0 +1,307 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The implementations of the `StandardUniform` distribution for integer types. + +use crate::distr::{Distribution, StandardUniform}; +use crate::Rng; +#[cfg(all(target_arch = "x86", feature = "simd_support"))] +use core::arch::x86::__m512i; +#[cfg(target_arch = "x86")] +use core::arch::x86::{__m128i, __m256i}; +#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] +use core::arch::x86_64::__m512i; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::{__m128i, __m256i}; +use core::num::{ + NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroU128, NonZeroU16, + NonZeroU32, NonZeroU64, NonZeroU8, +}; +#[cfg(feature = "simd_support")] +use core::simd::*; + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u8 { + rng.next_u32() as u8 + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u16 { + rng.next_u32() as u16 + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u32 { + rng.next_u32() + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + rng.next_u64() + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u128 { + // Use LE; we explicitly generate one value before the next. + let x = u128::from(rng.next_u64()); + let y = u128::from(rng.next_u64()); + (y << 64) | x + } +} + +macro_rules! impl_int_from_uint { + ($ty:ty, $uty:ty) => { + impl Distribution<$ty> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> $ty { + rng.random::<$uty>() as $ty + } + } + }; +} + +impl_int_from_uint! { i8, u8 } +impl_int_from_uint! { i16, u16 } +impl_int_from_uint! { i32, u32 } +impl_int_from_uint! { i64, u64 } +impl_int_from_uint! { i128, u128 } + +macro_rules! impl_nzint { + ($ty:ty, $new:path) => { + impl Distribution<$ty> for StandardUniform { + fn sample(&self, rng: &mut R) -> $ty { + loop { + if let Some(nz) = $new(rng.random()) { + break nz; + } + } + } + } + }; +} + +impl_nzint!(NonZeroU8, NonZeroU8::new); +impl_nzint!(NonZeroU16, NonZeroU16::new); +impl_nzint!(NonZeroU32, NonZeroU32::new); +impl_nzint!(NonZeroU64, NonZeroU64::new); +impl_nzint!(NonZeroU128, NonZeroU128::new); + +impl_nzint!(NonZeroI8, NonZeroI8::new); +impl_nzint!(NonZeroI16, NonZeroI16::new); +impl_nzint!(NonZeroI32, NonZeroI32::new); +impl_nzint!(NonZeroI64, NonZeroI64::new); +impl_nzint!(NonZeroI128, NonZeroI128::new); + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +impl Distribution<__m128i> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> __m128i { + // NOTE: It's tempting to use the u128 impl here, but confusingly this + // results in different code (return via rdx, r10 instead of rax, rdx + // with u128 impl) and is much slower (+130 time). This version calls + // impls::fill_bytes_via_next but performs well. + + let mut buf = [0_u8; core::mem::size_of::<__m128i>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + + // SAFETY: All byte sequences of `buf` represent values of the output type. + unsafe { core::mem::transmute(buf) } + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +impl Distribution<__m256i> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> __m256i { + let mut buf = [0_u8; core::mem::size_of::<__m256i>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + + // SAFETY: All byte sequences of `buf` represent values of the output type. + unsafe { core::mem::transmute(buf) } + } +} + +#[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + feature = "simd_support" +))] +impl Distribution<__m512i> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> __m512i { + let mut buf = [0_u8; core::mem::size_of::<__m512i>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + + // SAFETY: All byte sequences of `buf` represent values of the output type. + unsafe { core::mem::transmute(buf) } + } +} + +#[cfg(feature = "simd_support")] +macro_rules! simd_impl { + ($($ty:ty),+) => {$( + /// Requires nightly Rust and the [`simd_support`] feature + /// + /// [`simd_support`]: https://github.com/rust-random/rand#crate-features + #[cfg(feature = "simd_support")] + impl Distribution> for StandardUniform + where + LaneCount: SupportedLaneCount, + { + #[inline] + fn sample(&self, rng: &mut R) -> Simd<$ty, LANES> { + let mut vec = Simd::default(); + rng.fill(vec.as_mut_array().as_mut_slice()); + vec + } + } + )+}; +} + +#[cfg(feature = "simd_support")] +simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_integers() { + let mut rng = crate::test::rng(806); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[test] + fn x86_integers() { + let mut rng = crate::test::rng(807); + + rng.sample::<__m128i, _>(StandardUniform); + rng.sample::<__m256i, _>(StandardUniform); + #[cfg(feature = "simd_support")] + rng.sample::<__m512i, _>(StandardUniform); + } + + #[test] + fn value_stability() { + fn test_samples(zero: T, expected: &[T]) + where + StandardUniform: Distribution, + { + let mut rng = crate::test::rng(807); + let mut buf = [zero; 3]; + for x in &mut buf { + *x = rng.sample(StandardUniform); + } + assert_eq!(&buf, expected); + } + + test_samples(0u8, &[9, 247, 111]); + test_samples(0u16, &[32265, 42999, 38255]); + test_samples(0u32, &[2220326409, 2575017975, 2018088303]); + test_samples( + 0u64, + &[ + 11059617991457472009, + 16096616328739788143, + 1487364411147516184, + ], + ); + test_samples( + 0u128, + &[ + 296930161868957086625409848350820761097, + 145644820879247630242265036535529306392, + 111087889832015897993126088499035356354, + ], + ); + + test_samples(0i8, &[9, -9, 111]); + // Skip further i* types: they are simple reinterpretation of u* samples + + #[cfg(feature = "simd_support")] + { + // We only test a sub-set of types here and make assumptions about the rest. + + test_samples( + u8x4::default(), + &[ + u8x4::from([9, 126, 87, 132]), + u8x4::from([247, 167, 123, 153]), + u8x4::from([111, 149, 73, 120]), + ], + ); + test_samples( + u8x8::default(), + &[ + u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]), + u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]), + u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]), + ], + ); + + test_samples( + i64x8::default(), + &[ + i64x8::from([ + -7387126082252079607, + -2350127744969763473, + 1487364411147516184, + 7895421560427121838, + 602190064936008898, + 6022086574635100741, + -5080089175222015595, + -4066367846667249123, + ]), + i64x8::from([ + 9180885022207963908, + 3095981199532211089, + 6586075293021332726, + 419343203796414657, + 3186951873057035255, + 5287129228749947252, + 444726432079249540, + -1587028029513790706, + ]), + i64x8::from([ + 6075236523189346388, + 1351763722368165432, + -6192309979959753740, + -7697775502176768592, + -4482022114172078123, + 7522501477800909500, + -1837258847956201231, + -586926753024886735, + ]), + ], + ); + } + } +} diff --git a/src/distributions/mod.rs b/src/distr/mod.rs similarity index 50% rename from src/distributions/mod.rs rename to src/distr/mod.rs index 39d967d4f60..a66504624bb 100644 --- a/src/distributions/mod.rs +++ b/src/distr/mod.rs @@ -11,7 +11,7 @@ //! //! This module is the home of the [`Distribution`] trait and several of its //! implementations. It is the workhorse behind some of the convenient -//! functionality of the [`Rng`] trait, e.g. [`Rng::gen`] and of course +//! functionality of the [`Rng`] trait, e.g. [`Rng::random`] and of course //! [`Rng::sample`]. //! //! Abstractly, a [probability distribution] describes the probability of @@ -28,58 +28,51 @@ //! [`Uniform`] allows specification of its sample space as a range within `T`). //! //! -//! # The `Standard` distribution +//! # The Standard Uniform distribution //! -//! The [`Standard`] distribution is important to mention. This is the -//! distribution used by [`Rng::gen`] and represents the "default" way to +//! The [`StandardUniform`] distribution is important to mention. This is the +//! distribution used by [`Rng::random`] and represents the "default" way to //! produce a random value for many different types, including most primitive //! types, tuples, arrays, and a few derived types. See the documentation of -//! [`Standard`] for more details. +//! [`StandardUniform`] for more details. //! -//! Implementing `Distribution` for [`Standard`] for user types `T` makes it -//! possible to generate type `T` with [`Rng::gen`], and by extension also +//! Implementing [`Distribution`] for [`StandardUniform`] for user types `T` makes it +//! possible to generate type `T` with [`Rng::random`], and by extension also //! with the [`random`] function. //! -//! ## Random characters +//! ## Other standard uniform distributions //! //! [`Alphanumeric`] is a simple distribution to sample random letters and -//! numbers of the `char` type; in contrast [`Standard`] may sample any valid +//! numbers of the `char` type; in contrast [`StandardUniform`] may sample any valid //! `char`. //! +//! There's also an [`Alphabetic`] distribution which acts similarly to [`Alphanumeric`] but +//! doesn't include digits. //! -//! # Uniform numeric ranges +//! For floats (`f32`, `f64`), [`StandardUniform`] samples from `[0, 1)`. Also +//! provided are [`Open01`] (samples from `(0, 1)`) and [`OpenClosed01`] +//! (samples from `(0, 1]`). No option is provided to sample from `[0, 1]`; it +//! is suggested to use one of the above half-open ranges since the failure to +//! sample a value which would have a low chance of being sampled anyway is +//! rarely an issue in practice. //! -//! The [`Uniform`] distribution is more flexible than [`Standard`], but also -//! more specialised: it supports fewer target types, but allows the sample -//! space to be specified as an arbitrary range within its target type `T`. -//! Both [`Standard`] and [`Uniform`] are in some sense uniform distributions. +//! # Parameterized Uniform distributions //! -//! Values may be sampled from this distribution using [`Rng::sample(Range)`] or -//! by creating a distribution object with [`Uniform::new`], -//! [`Uniform::new_inclusive`] or `From`. When the range limits are not -//! known at compile time it is typically faster to reuse an existing -//! `Uniform` object than to call [`Rng::sample(Range)`]. +//! The [`Uniform`] distribution provides uniform sampling over a specified +//! range on a subset of the types supported by the above distributions. //! -//! User types `T` may also implement `Distribution` for [`Uniform`], -//! although this is less straightforward than for [`Standard`] (see the -//! documentation in the [`uniform`] module). Doing so enables generation of -//! values of type `T` with [`Rng::sample(Range)`]. -//! -//! ## Open and half-open ranges -//! -//! There are surprisingly many ways to uniformly generate random floats. A -//! range between 0 and 1 is standard, but the exact bounds (open vs closed) -//! and accuracy differ. In addition to the [`Standard`] distribution Rand offers -//! [`Open01`] and [`OpenClosed01`]. See "Floating point implementation" section of -//! [`Standard`] documentation for more details. +//! Implementations support single-value-sampling via +//! [`Rng::random_range(Range)`](Rng::random_range). +//! Where a fixed (non-`const`) range will be sampled many times, it is likely +//! faster to pre-construct a [`Distribution`] object using +//! [`Uniform::new`], [`Uniform::new_inclusive`] or `From`. //! //! # Non-uniform sampling //! //! Sampling a simple true/false outcome with a given probability has a name: -//! the [`Bernoulli`] distribution (this is used by [`Rng::gen_bool`]). +//! the [`Bernoulli`] distribution (this is used by [`Rng::random_bool`]). //! -//! For weighted sampling from a sequence of discrete values, use the -//! [`WeightedIndex`] distribution. +//! For weighted sampling of discrete values see the [`weighted`] module. //! //! This crate no longer includes other non-uniform distributions; instead //! it is recommended that you use either [`rand_distr`] or [`statrs`]. @@ -98,89 +91,88 @@ mod distribution; mod float; mod integer; mod other; -mod slice; mod utils; -#[cfg(feature = "alloc")] -mod weighted_index; #[doc(hidden)] pub mod hidden_export { pub use super::float::IntoFloat; // used by rand_distr } +pub mod slice; pub mod uniform; +#[cfg(feature = "alloc")] +pub mod weighted; pub use self::bernoulli::{Bernoulli, BernoulliError}; -pub use self::distribution::{Distribution, DistIter, DistMap}; #[cfg(feature = "alloc")] -pub use self::distribution::DistString; +pub use self::distribution::SampleString; +pub use self::distribution::{Distribution, Iter, Map}; pub use self::float::{Open01, OpenClosed01}; -pub use self::other::Alphanumeric; -pub use self::slice::Slice; +pub use self::other::{Alphabetic, Alphanumeric}; #[doc(inline)] pub use self::uniform::Uniform; -#[cfg(feature = "alloc")] -pub use self::weighted_index::{Weight, WeightError, WeightedIndex}; #[allow(unused)] use crate::Rng; -/// A generic random value distribution, implemented for many primitive types. -/// Usually generates values with a numerically uniform distribution, and with a -/// range appropriate to the type. +/// The Standard Uniform distribution /// -/// ## Provided implementations +/// This [`Distribution`] is the *standard* parameterization of [`Uniform`]. Bounds +/// are selected according to the output type. /// /// Assuming the provided `Rng` is well-behaved, these implementations /// generate values with the following ranges and distributions: /// -/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed -/// over all values of the type. -/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all +/// * Integers (`i8`, `i32`, `u64`, etc.) are uniformly distributed +/// over the whole range of the type (thus each possible value may be sampled +/// with equal probability). +/// * `char` is uniformly distributed over all Unicode scalar values, i.e. all /// code points in the range `0...0x10_FFFF`, except for the range /// `0xD800...0xDFFF` (the surrogate code points). This includes /// unassigned/reserved code points. -/// * `bool`: Generates `false` or `true`, each with probability 0.5. -/// * Floating point types (`f32` and `f64`): Uniformly distributed in the -/// half-open range `[0, 1)`. See notes below. +/// For some uses, the [`Alphanumeric`] or [`Alphabetic`] distribution will be more +/// appropriate. +/// * `bool` samples `false` or `true`, each with probability 0.5. +/// * Floating point types (`f32` and `f64`) are uniformly distributed in the +/// half-open range `[0, 1)`. See also the [notes below](#floating-point-implementation). /// * Wrapping integers ([`Wrapping`]), besides the type identical to their /// normal integer variants. /// * Non-zero integers ([`NonZeroU8`]), which are like their normal integer -/// variants but cannot produce zero. -/// * SIMD types like x86's [`__m128i`], `std::simd`'s [`u32x4`]/[`f32x4`]/ -/// [`mask32x4`] (requires [`simd_support`]), where each lane is distributed -/// like their scalar `Standard` variants. See the list of `Standard` -/// implementations for more. +/// variants but cannot sample zero. /// -/// The `Standard` distribution also supports generation of the following +/// The `StandardUniform` distribution also supports generation of the following /// compound types where all component types are supported: /// -/// * Tuples (up to 12 elements): each element is generated sequentially. -/// * Arrays: each element is generated sequentially; -/// see also [`Rng::fill`] which supports arbitrary array length for integer -/// and float types and tends to be faster for `u32` and smaller types. -/// Note that [`Rng::fill`] and `Standard`'s array support are *not* equivalent: -/// the former is optimised for integer types (using fewer RNG calls for -/// element types smaller than the RNG word size), while the latter supports -/// any element type supported by `Standard`. -/// * `Option` first generates a `bool`, and if true generates and returns -/// `Some(value)` where `value: T`, otherwise returning `None`. +/// * Tuples (up to 12 elements): each element is sampled sequentially and +/// independently (thus, assuming a well-behaved RNG, there is no correlation +/// between elements). +/// * Arrays `[T; n]` where `T` is supported. Each element is sampled +/// sequentially and independently. Note that for small `T` this usually +/// results in the RNG discarding random bits; see also [`Rng::fill`] which +/// offers a more efficient approach to filling an array of integer types +/// with random data. +/// * SIMD types (requires [`simd_support`] feature) like x86's [`__m128i`] +/// and `std::simd`'s [`u32x4`], [`f32x4`] and [`mask32x4`] types are +/// effectively arrays of integer or floating-point types. Each lane is +/// sampled independently, potentially with more efficient random-bit-usage +/// (and a different resulting value) than would be achieved with sequential +/// sampling (as with the array types above). /// /// ## Custom implementations /// -/// The [`Standard`] distribution may be implemented for user types as follows: +/// The [`StandardUniform`] distribution may be implemented for user types as follows: /// /// ``` /// # #![allow(dead_code)] /// use rand::Rng; -/// use rand::distributions::{Distribution, Standard}; +/// use rand::distr::{Distribution, StandardUniform}; /// /// struct MyF32 { /// x: f32, /// } /// -/// impl Distribution for Standard { +/// impl Distribution for StandardUniform { /// fn sample(&self, rng: &mut R) -> MyF32 { -/// MyF32 { x: rng.gen() } +/// MyF32 { x: rng.random() } /// } /// } /// ``` @@ -188,14 +180,14 @@ use crate::Rng; /// ## Example usage /// ``` /// use rand::prelude::*; -/// use rand::distributions::Standard; +/// use rand::distr::StandardUniform; /// -/// let val: f32 = StdRng::from_entropy().sample(Standard); +/// let val: f32 = rand::rng().sample(StandardUniform); /// println!("f32 from [0, 1): {}", val); /// ``` /// /// # Floating point implementation -/// The floating point implementations for `Standard` generate a random value in +/// The floating point implementations for `StandardUniform` generate a random value in /// the half-open interval `[0, 1)`, i.e. including 0 but not 1. /// /// All values that can be generated are of the form `n * ε/2`. For `f32` @@ -204,7 +196,7 @@ use crate::Rng; /// multiplicative method: `(rng.gen::<$uty>() >> N) as $ty * (ε/2)`. /// /// See also: [`Open01`] which samples from `(0, 1)`, [`OpenClosed01`] which -/// samples from `(0, 1]` and `Rng::gen_range(0..1)` which also samples from +/// samples from `(0, 1]` and `Rng::random_range(0..1)` which also samples from /// `[0, 1)`. Note that `Open01` uses transmute-based methods which yield 1 bit /// less precision but may perform faster on some architectures (on modern Intel /// CPUs all methods have approximately equal performance). @@ -217,6 +209,6 @@ use crate::Rng; /// [`f32x4`]: std::simd::f32x4 /// [`mask32x4`]: std::simd::mask32x4 /// [`simd_support`]: https://github.com/rust-random/rand#crate-features -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Standard; +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct StandardUniform; diff --git a/src/distributions/other.rs b/src/distr/other.rs similarity index 59% rename from src/distributions/other.rs rename to src/distr/other.rs index ebe3d57ed3f..47b99323d6b 100644 --- a/src/distributions/other.rs +++ b/src/distr/other.rs @@ -6,26 +6,25 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The implementations of the `Standard` distribution for other built-in types. +//! The implementations of the `StandardUniform` distribution for other built-in types. -use core::char; -use core::num::Wrapping; #[cfg(feature = "alloc")] use alloc::string::String; +use core::array; +use core::char; +use core::num::Wrapping; -use crate::distributions::{Distribution, Standard, Uniform}; #[cfg(feature = "alloc")] -use crate::distributions::DistString; +use crate::distr::SampleString; +use crate::distr::{Distribution, StandardUniform, Uniform}; use crate::Rng; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; -use core::mem::{self, MaybeUninit}; #[cfg(feature = "simd_support")] use core::simd::prelude::*; #[cfg(feature = "simd_support")] use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; - +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; // ----- Sampling distributions ----- @@ -35,19 +34,19 @@ use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; /// # Example /// /// ``` -/// use rand::{Rng, thread_rng}; -/// use rand::distributions::Alphanumeric; +/// use rand::Rng; +/// use rand::distr::Alphanumeric; /// -/// let mut rng = thread_rng(); +/// let mut rng = rand::rng(); /// let chars: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); /// println!("Random chars: {}", chars); /// ``` /// -/// The [`DistString`] trait provides an easier method of generating -/// a random `String`, and offers more efficient allocation: +/// The [`SampleString`] trait provides an easier method of generating +/// a random [`String`], and offers more efficient allocation: /// ``` -/// use rand::distributions::{Alphanumeric, DistString}; -/// let string = Alphanumeric.sample_string(&mut rand::thread_rng(), 16); +/// use rand::distr::{Alphanumeric, SampleString}; +/// let string = Alphanumeric.sample_string(&mut rand::rng(), 16); /// println!("Random string: {}", string); /// ``` /// @@ -67,14 +66,42 @@ use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; /// /// - [Wikipedia article on Password Strength](https://en.wikipedia.org/wiki/Password_strength) /// - [Diceware for generating memorable passwords](https://en.wikipedia.org/wiki/Diceware) -#[derive(Debug, Clone, Copy)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Alphanumeric; +/// Sample a [`u8`], uniformly distributed over letters: +/// a-z and A-Z. +/// +/// # Example +/// +/// You're able to generate random Alphabetic characters via mapping or via the +/// [`SampleString::sample_string`] method like so: +/// +/// ``` +/// use rand::Rng; +/// use rand::distr::{Alphabetic, SampleString}; +/// +/// // Manual mapping +/// let mut rng = rand::rng(); +/// let chars: String = (0..7).map(|_| rng.sample(Alphabetic) as char).collect(); +/// println!("Random chars: {}", chars); +/// +/// // Using [`SampleString::sample_string`] +/// let string = Alphabetic.sample_string(&mut rand::rng(), 16); +/// println!("Random string: {}", string); +/// ``` +/// +/// # Passwords +/// +/// Refer to [`Alphanumeric#Passwords`]. +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Alphabetic; // ----- Implementations of distributions ----- -impl Distribution for Standard { +impl Distribution for StandardUniform { #[inline] fn sample(&self, rng: &mut R) -> char { // A valid `char` is either in the interval `[0, 0xD800)` or @@ -91,14 +118,13 @@ impl Distribution for Standard { if n <= 0xDFFF { n -= GAP_SIZE; } + // SAFETY: We ensure above that `n` represents a `char`. unsafe { char::from_u32_unchecked(n) } } } -/// Note: the `String` is potentially left with excess capacity; optionally the -/// user may call `string.shrink_to_fit()` afterwards. #[cfg(feature = "alloc")] -impl DistString for Standard { +impl SampleString for StandardUniform { fn append_string(&self, rng: &mut R, s: &mut String, len: usize) { // A char is encoded with at most four bytes, thus this reservation is // guaranteed to be sufficient. We do not shrink_to_fit afterwards so @@ -127,17 +153,47 @@ impl Distribution for Alphanumeric { } } +impl Distribution for Alphabetic { + fn sample(&self, rng: &mut R) -> u8 { + const RANGE: u8 = 26 + 26; + + let offset = rng.random_range(0..RANGE) + b'A'; + + // Account for upper-cases + offset + (offset > b'Z') as u8 * (b'a' - b'Z' - 1) + } +} + +#[cfg(feature = "alloc")] +impl SampleString for Alphanumeric { + fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + // SAFETY: `self` only samples alphanumeric characters, which are valid UTF-8. + unsafe { + let v = string.as_mut_vec(); + v.extend( + self.sample_iter(rng) + .take(len) + .inspect(|b| debug_assert!(b.is_ascii_alphanumeric())), + ); + } + } +} + #[cfg(feature = "alloc")] -impl DistString for Alphanumeric { +impl SampleString for Alphabetic { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + // SAFETY: With this distribution we guarantee that we're working with valid ASCII + // characters. + // See [#1590](https://github.com/rust-random/rand/issues/1590). unsafe { let v = string.as_mut_vec(); + v.reserve_exact(len); v.extend(self.sample_iter(rng).take(len)); } } } -impl Distribution for Standard { +impl Distribution for StandardUniform { #[inline] fn sample(&self, rng: &mut R) -> bool { // We can compare against an arbitrary bit of an u32 to get a bool. @@ -148,18 +204,16 @@ impl Distribution for Standard { } } -/// Requires nightly Rust and the [`simd_support`] feature -/// /// Note that on some hardware like x86/64 mask operations like [`_mm_blendv_epi8`] /// only care about a single bit. This means that you could use uniform random bits /// directly: /// /// ```ignore /// // this may be faster... -/// let x = unsafe { _mm_blendv_epi8(a.into(), b.into(), rng.gen::<__m128i>()) }; +/// let x = unsafe { _mm_blendv_epi8(a.into(), b.into(), rng.random::<__m128i>()) }; /// /// // ...than this -/// let x = rng.gen::().select(b, a); +/// let x = rng.random::().select(b, a); /// ``` /// /// Since most bits are unused you could also generate only as many bits as you need, i.e.: @@ -167,9 +221,9 @@ impl Distribution for Standard { /// #![feature(portable_simd)] /// use std::simd::prelude::*; /// use rand::prelude::*; -/// let mut rng = thread_rng(); +/// let mut rng = rand::rng(); /// -/// let x = u16x8::splat(rng.gen::() as u16); +/// let x = u16x8::splat(rng.random::() as u16); /// let mask = u16x8::splat(1) << u16x8::from([0, 1, 2, 3, 4, 5, 6, 7]); /// let rand_mask = (x & mask).simd_eq(mask); /// ``` @@ -177,29 +231,29 @@ impl Distribution for Standard { /// [`_mm_blendv_epi8`]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_epi8&ig_expand=514/ /// [`simd_support`]: https://github.com/rust-random/rand#crate-features #[cfg(feature = "simd_support")] -impl Distribution> for Standard +impl Distribution> for StandardUniform where T: MaskElement + Default, LaneCount: SupportedLaneCount, - Standard: Distribution>, + StandardUniform: Distribution>, Simd: SimdPartialOrd>, { #[inline] fn sample(&self, rng: &mut R) -> Mask { // `MaskElement` must be a signed integer, so this is equivalent // to the scalar `i32 < 0` method - let var = rng.gen::>(); + let var = rng.random::>(); var.simd_lt(Simd::default()) } } -/// Implement `Distribution<(A, B, C, ...)> for Standard, using the list of +/// Implement `Distribution<(A, B, C, ...)> for StandardUniform`, using the list of /// identifiers macro_rules! tuple_impl { ($($tyvar:ident)*) => { - impl< $($tyvar,)* > Distribution<($($tyvar,)*)> for Standard + impl< $($tyvar,)* > Distribution<($($tyvar,)*)> for StandardUniform where $( - Standard: Distribution< $tyvar >, + StandardUniform: Distribution< $tyvar >, )* { #[inline] @@ -207,7 +261,7 @@ macro_rules! tuple_impl { let out = ($( // use the $tyvar's to get the appropriate number of // repeats (they're not actually needed) - rng.gen::<$tyvar>() + rng.random::<$tyvar>() ,)*); // Suppress the unused variable warning for empty tuple @@ -238,57 +292,37 @@ macro_rules! tuple_impls { tuple_impls! {A B C D E F G H I J K L} -impl Distribution<[T; N]> for Standard -where Standard: Distribution -{ - #[inline] - fn sample(&self, _rng: &mut R) -> [T; N] { - let mut buff: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; - - for elem in &mut buff { - *elem = MaybeUninit::new(_rng.gen()); - } - - unsafe { mem::transmute_copy::<_, _>(&buff) } - } -} - -impl Distribution> for Standard -where Standard: Distribution +impl Distribution<[T; N]> for StandardUniform +where + StandardUniform: Distribution, { #[inline] - fn sample(&self, rng: &mut R) -> Option { - // UFCS is needed here: https://github.com/rust-lang/rust/issues/24066 - if rng.gen::() { - Some(rng.gen()) - } else { - None - } + fn sample(&self, rng: &mut R) -> [T; N] { + array::from_fn(|_| rng.random()) } } -impl Distribution> for Standard -where Standard: Distribution +impl Distribution> for StandardUniform +where + StandardUniform: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> Wrapping { - Wrapping(rng.gen()) + Wrapping(rng.random()) } } - #[cfg(test)] mod tests { use super::*; use crate::RngCore; - #[cfg(feature = "alloc")] use alloc::string::String; #[test] fn test_misc() { let rng: &mut dyn RngCore = &mut crate::test::rng(820); - rng.sample::(Standard); - rng.sample::(Standard); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); } #[cfg(feature = "alloc")] @@ -300,7 +334,7 @@ mod tests { // Test by generating a relatively large number of chars, so we also // take the rejection sampling path. let word: String = iter::repeat(()) - .map(|()| rng.gen::()) + .map(|()| rng.random::()) .take(1000) .collect(); assert!(!word.is_empty()); @@ -315,9 +349,21 @@ mod tests { let mut incorrect = false; for _ in 0..100 { let c: char = rng.sample(Alphanumeric).into(); - incorrect |= !(('0'..='9').contains(&c) || - ('A'..='Z').contains(&c) || - ('a'..='z').contains(&c) ); + incorrect |= !c.is_ascii_alphanumeric(); + } + assert!(!incorrect); + } + + #[test] + fn test_alphabetic() { + let mut rng = crate::test::rng(806); + + // Test by generating a relatively large number of chars, so we also + // take the rejection sampling path. + let mut incorrect = false; + for _ in 0..100 { + let c: char = rng.sample(Alphabetic).into(); + incorrect |= !c.is_ascii_alphabetic(); } assert!(!incorrect); } @@ -325,7 +371,9 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: &D, zero: T, expected: &[T], + distr: &D, + zero: T, + expected: &[T], ) { let mut rng = crate::test::rng(807); let mut buf = [zero; 5]; @@ -335,54 +383,62 @@ mod tests { assert_eq!(&buf, expected); } - test_samples(&Standard, 'a', &[ - '\u{8cdac}', - '\u{a346a}', - '\u{80120}', - '\u{ed692}', - '\u{35888}', - ]); + test_samples( + &StandardUniform, + 'a', + &[ + '\u{8cdac}', + '\u{a346a}', + '\u{80120}', + '\u{ed692}', + '\u{35888}', + ], + ); test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); - test_samples(&Standard, false, &[true, true, false, true, false]); - test_samples(&Standard, None as Option, &[ - Some(true), - None, - Some(false), - None, - Some(false), - ]); - test_samples(&Standard, Wrapping(0i32), &[ - Wrapping(-2074640887), - Wrapping(-1719949321), - Wrapping(2018088303), - Wrapping(-547181756), - Wrapping(838957336), - ]); + test_samples(&Alphabetic, 0, &[97, 102, 89, 116, 75]); + test_samples(&StandardUniform, false, &[true, true, false, true, false]); + test_samples( + &StandardUniform, + Wrapping(0i32), + &[ + Wrapping(-2074640887), + Wrapping(-1719949321), + Wrapping(2018088303), + Wrapping(-547181756), + Wrapping(838957336), + ], + ); // We test only sub-sets of tuple and array impls - test_samples(&Standard, (), &[(), (), (), (), ()]); - test_samples(&Standard, (false,), &[ - (true,), - (true,), - (false,), - (true,), + test_samples(&StandardUniform, (), &[(), (), (), (), ()]); + test_samples( + &StandardUniform, (false,), - ]); - test_samples(&Standard, (false, false), &[ - (true, true), - (false, true), - (false, false), - (true, false), + &[(true,), (true,), (false,), (true,), (false,)], + ); + test_samples( + &StandardUniform, (false, false), - ]); - - test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]); - test_samples(&Standard, [0u8; 3], &[ - [9, 247, 111], - [68, 24, 13], - [174, 19, 194], - [172, 69, 213], - [149, 207, 29], - ]); + &[ + (true, true), + (false, true), + (false, false), + (true, false), + (false, false), + ], + ); + + test_samples(&StandardUniform, [0u8; 0], &[[], [], [], [], []]); + test_samples( + &StandardUniform, + [0u8; 3], + &[ + [9, 247, 111], + [68, 24, 13], + [174, 19, 194], + [172, 69, 213], + [149, 207, 29], + ], + ); } } diff --git a/src/distributions/slice.rs b/src/distr/slice.rs similarity index 52% rename from src/distributions/slice.rs rename to src/distr/slice.rs index 5fc08751f6c..07e243fec5d 100644 --- a/src/distributions/slice.rs +++ b/src/distr/slice.rs @@ -6,40 +6,35 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use crate::distributions::{Distribution, Uniform}; +//! Distributions over slices + +use core::num::NonZeroUsize; + +use crate::distr::uniform::{UniformSampler, UniformUsize}; +use crate::distr::Distribution; #[cfg(feature = "alloc")] use alloc::string::String; -/// A distribution to sample items uniformly from a slice. -/// -/// [`Slice::new`] constructs a distribution referencing a slice and uniformly -/// samples references from the items in the slice. It may do extra work up -/// front to make sampling of multiple values faster; if only one sample from -/// the slice is required, [`IndexedRandom::choose`] can be more efficient. -/// -/// Steps are taken to avoid bias which might be present in naive -/// implementations; for example `slice[rng.gen() % slice.len()]` samples from -/// the slice, but may be more likely to select numbers in the low range than -/// other values. +/// A distribution to uniformly sample elements of a slice /// -/// This distribution samples with replacement; each sample is independent. -/// Sampling without replacement requires state to be retained, and therefore -/// cannot be handled by a distribution; you should instead consider methods -/// on [`IndexedRandom`], such as [`IndexedRandom::choose_multiple`]. +/// Like [`IndexedRandom::choose`], this uniformly samples elements of a slice +/// without modification of the slice (so called "sampling with replacement"). +/// This distribution object may be a little faster for repeated sampling (but +/// slower for small numbers of samples). /// -/// # Example +/// ## Examples /// +/// Since this is a distribution, [`Rng::sample_iter`] and +/// [`Distribution::sample_iter`] may be used, for example: /// ``` -/// use rand::Rng; -/// use rand::distributions::Slice; +/// use rand::distr::{Distribution, slice::Choose}; /// /// let vowels = ['a', 'e', 'i', 'o', 'u']; -/// let vowels_dist = Slice::new(&vowels).unwrap(); -/// let rng = rand::thread_rng(); +/// let vowels_dist = Choose::new(&vowels).unwrap(); /// /// // build a string of 10 vowels -/// let vowel_string: String = rng -/// .sample_iter(&vowels_dist) +/// let vowel_string: String = vowels_dist +/// .sample_iter(&mut rand::rng()) /// .take(10) /// .collect(); /// @@ -48,42 +43,46 @@ use alloc::string::String; /// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); /// ``` /// -/// For a single sample, [`IndexedRandom::choose`][crate::seq::IndexedRandom::choose] -/// may be preferred: -/// +/// For a single sample, [`IndexedRandom::choose`] may be preferred: /// ``` /// use rand::seq::IndexedRandom; /// /// let vowels = ['a', 'e', 'i', 'o', 'u']; -/// let mut rng = rand::thread_rng(); +/// let mut rng = rand::rng(); /// -/// println!("{}", vowels.choose(&mut rng).unwrap()) +/// println!("{}", vowels.choose(&mut rng).unwrap()); /// ``` /// -/// [`IndexedRandom`]: crate::seq::IndexedRandom /// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose -/// [`IndexedRandom::choose_multiple`]: crate::seq::IndexedRandom::choose_multiple +/// [`Rng::sample_iter`]: crate::Rng::sample_iter #[derive(Debug, Clone, Copy)] -pub struct Slice<'a, T> { +pub struct Choose<'a, T> { slice: &'a [T], - range: Uniform, + range: UniformUsize, + num_choices: NonZeroUsize, } -impl<'a, T> Slice<'a, T> { - /// Create a new `Slice` instance which samples uniformly from the slice. - /// Returns `Err` if the slice is empty. - pub fn new(slice: &'a [T]) -> Result { - match slice.len() { - 0 => Err(EmptySlice), - len => Ok(Self { - slice, - range: Uniform::new(0, len).unwrap(), - }), - } +impl<'a, T> Choose<'a, T> { + /// Create a new `Choose` instance which samples uniformly from the slice. + /// + /// Returns error [`Empty`] if the slice is empty. + pub fn new(slice: &'a [T]) -> Result { + let num_choices = NonZeroUsize::new(slice.len()).ok_or(Empty)?; + + Ok(Self { + slice, + range: UniformUsize::new(0, num_choices.get()).unwrap(), + num_choices, + }) + } + + /// Returns the count of choices in this distribution + pub fn num_choices(&self) -> NonZeroUsize { + self.num_choices } } -impl<'a, T> Distribution<&'a T> for Slice<'a, T> { +impl<'a, T> Distribution<&'a T> for Choose<'a, T> { fn sample(&self, rng: &mut R) -> &'a T { let idx = self.range.sample(rng); @@ -101,27 +100,26 @@ impl<'a, T> Distribution<&'a T> for Slice<'a, T> { } } -/// Error type indicating that a [`Slice`] distribution was improperly -/// constructed with an empty slice. +/// Error: empty slice +/// +/// This error is returned when [`Choose::new`] is given an empty slice. #[derive(Debug, Clone, Copy)] -pub struct EmptySlice; +pub struct Empty; -impl core::fmt::Display for EmptySlice { +impl core::fmt::Display for Empty { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!( f, - "Tried to create a `distributions::Slice` with an empty slice" + "Tried to create a `rand::distr::slice::Choose` with an empty slice" ) } } #[cfg(feature = "std")] -impl std::error::Error for EmptySlice {} +impl std::error::Error for Empty {} -/// Note: the `String` is potentially left with excess capacity; optionally the -/// user may call `string.shrink_to_fit()` afterwards. #[cfg(feature = "alloc")] -impl<'a> super::DistString for Slice<'a, char> { +impl super::SampleString for Choose<'_, char> { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { // Get the max char length to minimize extra space. // Limit this check to avoid searching for long slice. @@ -139,7 +137,11 @@ impl<'a> super::DistString for Slice<'a, char> { // Split the extension of string to reuse the unused capacities. // Skip the split for small length or only ascii slice. - let mut extend_len = if max_char_len == 1 || len < 100 { len } else { len / 4 }; + let mut extend_len = if max_char_len == 1 || len < 100 { + len + } else { + len / 4 + }; let mut remain_len = len; while extend_len > 0 { string.reserve(max_char_len * extend_len); @@ -149,3 +151,17 @@ impl<'a> super::DistString for Slice<'a, char> { } } } + +#[cfg(test)] +mod test { + use super::*; + use core::iter; + + #[test] + fn value_stability() { + let rng = crate::test::rng(651); + let slice = Choose::new(b"escaped emus explore extensively").unwrap(); + let expected = b"eaxee"; + assert!(iter::zip(slice.sample_iter(rng), expected).all(|(a, b)| a == b)); + } +} diff --git a/src/distr/uniform.rs b/src/distr/uniform.rs new file mode 100644 index 00000000000..b59fdbf790b --- /dev/null +++ b/src/distr/uniform.rs @@ -0,0 +1,622 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A distribution uniformly sampling numbers within a given range. +//! +//! [`Uniform`] is the standard distribution to sample uniformly from a range; +//! e.g. `Uniform::new_inclusive(1, 6).unwrap()` can sample integers from 1 to 6, like a +//! standard die. [`Rng::random_range`] is implemented over [`Uniform`]. +//! +//! # Example usage +//! +//! ``` +//! use rand::Rng; +//! use rand::distr::Uniform; +//! +//! let mut rng = rand::rng(); +//! let side = Uniform::new(-10.0, 10.0).unwrap(); +//! +//! // sample between 1 and 10 points +//! for _ in 0..rng.random_range(1..=10) { +//! // sample a point from the square with sides -10 - 10 in two dimensions +//! let (x, y) = (rng.sample(side), rng.sample(side)); +//! println!("Point: {}, {}", x, y); +//! } +//! ``` +//! +//! # Extending `Uniform` to support a custom type +//! +//! To extend [`Uniform`] to support your own types, write a back-end which +//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] +//! helper trait to "register" your back-end. See the `MyF32` example below. +//! +//! At a minimum, the back-end needs to store any parameters needed for sampling +//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. +//! Those methods should include an assertion to check the range is valid (i.e. +//! `low < high`). The example below merely wraps another back-end. +//! +//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive` +//! functions use arguments of +//! type `SampleBorrow` to support passing in values by reference or +//! by value. In the implementation of these functions, you can choose to +//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose +//! to copy or clone the value, whatever is appropriate for your type. +//! +//! ``` +//! use rand::prelude::*; +//! use rand::distr::uniform::{Uniform, SampleUniform, +//! UniformSampler, UniformFloat, SampleBorrow, Error}; +//! +//! struct MyF32(f32); +//! +//! #[derive(Clone, Copy, Debug)] +//! struct UniformMyF32(UniformFloat); +//! +//! impl UniformSampler for UniformMyF32 { +//! type X = MyF32; +//! +//! fn new(low: B1, high: B2) -> Result +//! where B1: SampleBorrow + Sized, +//! B2: SampleBorrow + Sized +//! { +//! UniformFloat::::new(low.borrow().0, high.borrow().0).map(UniformMyF32) +//! } +//! fn new_inclusive(low: B1, high: B2) -> Result +//! where B1: SampleBorrow + Sized, +//! B2: SampleBorrow + Sized +//! { +//! UniformFloat::::new_inclusive(low.borrow().0, high.borrow().0).map(UniformMyF32) +//! } +//! fn sample(&self, rng: &mut R) -> Self::X { +//! MyF32(self.0.sample(rng)) +//! } +//! } +//! +//! impl SampleUniform for MyF32 { +//! type Sampler = UniformMyF32; +//! } +//! +//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); +//! let uniform = Uniform::new(low, high).unwrap(); +//! let x = uniform.sample(&mut rand::rng()); +//! ``` +//! +//! [`SampleUniform`]: crate::distr::uniform::SampleUniform +//! [`UniformSampler`]: crate::distr::uniform::UniformSampler +//! [`UniformInt`]: crate::distr::uniform::UniformInt +//! [`UniformFloat`]: crate::distr::uniform::UniformFloat +//! [`UniformDuration`]: crate::distr::uniform::UniformDuration +//! [`SampleBorrow::borrow`]: crate::distr::uniform::SampleBorrow::borrow + +#[path = "uniform_float.rs"] +mod float; +#[doc(inline)] +pub use float::UniformFloat; + +#[path = "uniform_int.rs"] +mod int; +#[doc(inline)] +pub use int::{UniformInt, UniformUsize}; + +#[path = "uniform_other.rs"] +mod other; +#[doc(inline)] +pub use other::{UniformChar, UniformDuration}; + +use core::fmt; +use core::ops::{Range, RangeInclusive, RangeTo, RangeToInclusive}; + +use crate::distr::Distribution; +use crate::{Rng, RngCore}; + +/// Error type returned from [`Uniform::new`] and `new_inclusive`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `low > high`, or equal in case of exclusive range. + EmptyRange, + /// Input or range `high - low` is non-finite. Not relevant to integer types. + NonFinite, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::EmptyRange => "low > high (or equal if exclusive) in uniform distribution", + Error::NonFinite => "Non-finite range in uniform distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Sample values uniformly between two bounds. +/// +/// # Construction +/// +/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform +/// distribution sampling from the given `low` and `high` limits. `Uniform` may +/// also be constructed via [`TryFrom`] as in `Uniform::try_from(1..=6).unwrap()`. +/// +/// Constructors may do extra work up front to allow faster sampling of multiple +/// values. Where only a single sample is required it is suggested to use +/// [`Rng::random_range`] or one of the `sample_single` methods instead. +/// +/// When sampling from a constant range, many calculations can happen at +/// compile-time and all methods should be fast; for floating-point ranges and +/// the full range of integer types, this should have comparable performance to +/// the [`StandardUniform`](super::StandardUniform) distribution. +/// +/// # Provided implementations +/// +/// - `char` ([`UniformChar`]): samples a range over the implementation for `u32` +/// - `f32`, `f64` ([`UniformFloat`]): samples approximately uniformly within a +/// range; bias may be present in the least-significant bit of the significand +/// and the limits of the input range may be sampled even when an open +/// (exclusive) range is used +/// - Integer types ([`UniformInt`]) may show a small bias relative to the +/// expected uniform distribution of output. In the worst case, bias affects +/// 1 in `2^n` samples where n is 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 +/// (`i32` and `u32`), 64 (`i64` and `u64`), 128 (`i128` and `u128`). +/// The `unbiased` feature flag fixes this bias. +/// - `usize` ([`UniformUsize`]) is handled specially, using the `u32` +/// implementation where possible to enable portable results across 32-bit and +/// 64-bit CPU architectures. +/// - `Duration` ([`UniformDuration`]): samples a range over the implementation +/// for `u32` or `u64` +/// - SIMD types (requires [`simd_support`] feature) like x86's [`__m128i`] +/// and `std::simd`'s [`u32x4`], [`f32x4`] and [`mask32x4`] types are +/// effectively arrays of integer or floating-point types. Each lane is +/// sampled independently from its own range, potentially with more efficient +/// random-bit-usage than would be achieved with sequential sampling. +/// +/// # Example +/// +/// ``` +/// use rand::distr::{Distribution, Uniform}; +/// +/// let between = Uniform::try_from(10..10000).unwrap(); +/// let mut rng = rand::rng(); +/// let mut sum = 0; +/// for _ in 0..1000 { +/// sum += between.sample(&mut rng); +/// } +/// println!("{}", sum); +/// ``` +/// +/// For a single sample, [`Rng::random_range`] may be preferred: +/// +/// ``` +/// use rand::Rng; +/// +/// let mut rng = rand::rng(); +/// println!("{}", rng.random_range(0..10)); +/// ``` +/// +/// [`new`]: Uniform::new +/// [`new_inclusive`]: Uniform::new_inclusive +/// [`Rng::random_range`]: Rng::random_range +/// [`__m128i`]: https://doc.rust-lang.org/core/arch/x86/struct.__m128i.html +/// [`u32x4`]: std::simd::u32x4 +/// [`f32x4`]: std::simd::f32x4 +/// [`mask32x4`]: std::simd::mask32x4 +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", serde(bound(serialize = "X::Sampler: Serialize")))] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "X::Sampler: Deserialize<'de>")) +)] +pub struct Uniform(X::Sampler); + +impl Uniform { + /// Create a new `Uniform` instance, which samples uniformly from the half + /// open range `[low, high)` (excluding `high`). + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is + /// non-finite. In release mode, only the range is checked. + pub fn new(low: B1, high: B2) -> Result, Error> + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + X::Sampler::new(low, high).map(Uniform) + } + + /// Create a new `Uniform` instance, which samples uniformly from the closed + /// range `[low, high]` (inclusive). + /// + /// Fails if `low > high`, or if `low`, `high` or the range `high - low` is + /// non-finite. In release mode, only the range is checked. + pub fn new_inclusive(low: B1, high: B2) -> Result, Error> + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + X::Sampler::new_inclusive(low, high).map(Uniform) + } +} + +impl Distribution for Uniform { + fn sample(&self, rng: &mut R) -> X { + self.0.sample(rng) + } +} + +/// Helper trait for creating objects using the correct implementation of +/// [`UniformSampler`] for the sampling type. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// [module documentation]: crate::distr::uniform +pub trait SampleUniform: Sized { + /// The `UniformSampler` implementation supporting type `X`. + type Sampler: UniformSampler; +} + +/// Helper trait handling actual uniform sampling. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// Implementation of [`sample_single`] is optional, and is only useful when +/// the implementation can be faster than `Self::new(low, high).sample(rng)`. +/// +/// [module documentation]: crate::distr::uniform +/// [`sample_single`]: UniformSampler::sample_single +pub trait UniformSampler: Sized { + /// The type sampled by this implementation. + type X; + + /// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`. + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// Usually users should not call this directly but prefer to use + /// [`Uniform::new`]. + fn new(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized; + + /// Construct self, with inclusive bounds `[low, high]`. + /// + /// Usually users should not call this directly but prefer to use + /// [`Uniform::new_inclusive`]. + fn new_inclusive(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized; + + /// Sample a value. + fn sample(&self, rng: &mut R) -> Self::X; + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and exclusive upper bound `[low, high)`. + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// By default this is implemented using + /// `UniformSampler::new(low, high).sample(rng)`. However, for some types + /// more optimal implementations for single usage may be provided via this + /// method (which is the case for integers and floats). + /// Results may not be identical. + /// + /// Note that to use this method in a generic context, the type needs to be + /// retrieved via `SampleUniform::Sampler` as follows: + /// ``` + /// use rand::distr::uniform::{SampleUniform, UniformSampler}; + /// # #[allow(unused)] + /// fn sample_from_range(lb: T, ub: T) -> T { + /// let mut rng = rand::rng(); + /// ::Sampler::sample_single(lb, ub, &mut rng).unwrap() + /// } + /// ``` + fn sample_single( + low: B1, + high: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let uniform: Self = UniformSampler::new(low, high)?; + Ok(uniform.sample(rng)) + } + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and inclusive upper bound `[low, high]`. + /// + /// By default this is implemented using + /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for + /// some types more optimal implementations for single usage may be provided + /// via this method. + /// Results may not be identical. + fn sample_single_inclusive( + low: B1, + high: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let uniform: Self = UniformSampler::new_inclusive(low, high)?; + Ok(uniform.sample(rng)) + } +} + +impl TryFrom> for Uniform { + type Error = Error; + + fn try_from(r: Range) -> Result, Error> { + Uniform::new(r.start, r.end) + } +} + +impl TryFrom> for Uniform { + type Error = Error; + + fn try_from(r: ::core::ops::RangeInclusive) -> Result, Error> { + Uniform::new_inclusive(r.start(), r.end()) + } +} + +/// Helper trait similar to [`Borrow`] but implemented +/// only for [`SampleUniform`] and references to [`SampleUniform`] +/// in order to resolve ambiguity issues. +/// +/// [`Borrow`]: std::borrow::Borrow +pub trait SampleBorrow { + /// Immutably borrows from an owned value. See [`Borrow::borrow`] + /// + /// [`Borrow::borrow`]: std::borrow::Borrow::borrow + fn borrow(&self) -> &Borrowed; +} +impl SampleBorrow for Borrowed +where + Borrowed: SampleUniform, +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} +impl SampleBorrow for &Borrowed +where + Borrowed: SampleUniform, +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} + +/// Range that supports generating a single sample efficiently. +/// +/// Any type implementing this trait can be used to specify the sampled range +/// for `Rng::random_range`. +pub trait SampleRange { + /// Generate a sample from the given range. + fn sample_single(self, rng: &mut R) -> Result; + + /// Check whether the range is empty. + fn is_empty(&self) -> bool; +} + +impl SampleRange for Range { + #[inline] + fn sample_single(self, rng: &mut R) -> Result { + T::Sampler::sample_single(self.start, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start < self.end) + } +} + +impl SampleRange for RangeInclusive { + #[inline] + fn sample_single(self, rng: &mut R) -> Result { + T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start() <= self.end()) + } +} + +macro_rules! impl_sample_range_u { + ($t:ty) => { + impl SampleRange<$t> for RangeTo<$t> { + #[inline] + fn sample_single(self, rng: &mut R) -> Result<$t, Error> { + <$t as SampleUniform>::Sampler::sample_single(0, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + 0 == self.end + } + } + + impl SampleRange<$t> for RangeToInclusive<$t> { + #[inline] + fn sample_single(self, rng: &mut R) -> Result<$t, Error> { + <$t as SampleUniform>::Sampler::sample_single_inclusive(0, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + false + } + } + }; +} + +impl_sample_range_u!(u8); +impl_sample_range_u!(u16); +impl_sample_range_u!(u32); +impl_sample_range_u!(u64); +impl_sample_range_u!(u128); +impl_sample_range_u!(usize); + +#[cfg(test)] +mod tests { + use super::*; + use core::time::Duration; + + #[test] + #[cfg(feature = "serde")] + fn test_uniform_serialization() { + let unit_box: Uniform = Uniform::new(-1, 1).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + assert_eq!(unit_box.0, de_unit_box.0); + + let unit_box: Uniform = Uniform::new(-1., 1.).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + assert_eq!(unit_box.0, de_unit_box.0); + } + + #[test] + fn test_custom_uniform() { + use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformFloat, UniformSampler}; + #[derive(Clone, Copy, PartialEq, PartialOrd)] + struct MyF32 { + x: f32, + } + #[derive(Clone, Copy, Debug)] + struct UniformMyF32(UniformFloat); + impl UniformSampler for UniformMyF32 { + type X = MyF32; + + fn new(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + UniformFloat::::new(low.borrow().x, high.borrow().x).map(UniformMyF32) + } + + fn new_inclusive(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + UniformSampler::new(low, high) + } + + fn sample(&self, rng: &mut R) -> Self::X { + MyF32 { + x: self.0.sample(rng), + } + } + } + impl SampleUniform for MyF32 { + type Sampler = UniformMyF32; + } + + let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); + let uniform = Uniform::new(low, high).unwrap(); + let mut rng = crate::test::rng(804); + for _ in 0..100 { + let x: MyF32 = rng.sample(uniform); + assert!(low <= x && x < high); + } + } + + #[test] + fn value_stability() { + fn test_samples( + lb: T, + ub: T, + expected_single: &[T], + expected_multiple: &[T], + ) where + Uniform: Distribution, + { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 3]; + + for x in &mut buf { + *x = T::Sampler::sample_single(lb, ub, &mut rng).unwrap(); + } + assert_eq!(&buf, expected_single); + + let distr = Uniform::new(lb, ub).unwrap(); + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected_multiple); + } + + test_samples( + 0f32, + 1e-2f32, + &[0.0003070104, 0.0026630748, 0.00979833], + &[0.008194133, 0.00398172, 0.007428536], + ); + test_samples( + -1e10f64, + 1e10f64, + &[-4673848682.871551, 6388267422.932352, 4857075081.198343], + &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], + ); + + test_samples( + Duration::new(2, 0), + Duration::new(4, 0), + &[ + Duration::new(2, 532615131), + Duration::new(3, 638826742), + Duration::new(3, 485707508), + ], + &[ + Duration::new(3, 117337521), + Duration::new(3, 191764285), + Duration::new(3, 236507617), + ], + ); + } + + #[test] + fn uniform_distributions_can_be_compared() { + assert_eq!( + Uniform::new(1.0, 2.0).unwrap(), + Uniform::new(1.0, 2.0).unwrap() + ); + + // To cover UniformInt + assert_eq!( + Uniform::new(1_u32, 2_u32).unwrap(), + Uniform::new(1_u32, 2_u32).unwrap() + ); + } +} diff --git a/src/distr/uniform_float.rs b/src/distr/uniform_float.rs new file mode 100644 index 00000000000..e9b0421aaf0 --- /dev/null +++ b/src/distr/uniform_float.rs @@ -0,0 +1,454 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformFloat` implementation + +use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::float::IntoFloat; +use crate::distr::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; +use crate::Rng; + +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +// #[cfg(feature = "simd_support")] +// use core::simd::{LaneCount, SupportedLaneCount}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The back-end implementing [`UniformSampler`] for floating-point types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// `UniformFloat` implementations convert RNG output to a float in the range +/// `[1, 2)` via transmutation, map this to `[0, 1)`, then scale and translate +/// to the desired range. Values produced this way have what equals 23 bits of +/// random digits for an `f32` and 52 for an `f64`. +/// +/// # Bias and range errors +/// +/// Bias may be expected within the least-significant bit of the significand. +/// It is not guaranteed that exclusive limits of a range are respected; i.e. +/// when sampling the range `[a, b)` it is not guaranteed that `b` is never +/// sampled. +/// +/// [`new`]: UniformSampler::new +/// [`new_inclusive`]: UniformSampler::new_inclusive +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`Uniform`]: super::Uniform +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformFloat { + low: X, + scale: X, +} + +macro_rules! uniform_float_impl { + ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { + $(#[cfg($meta)])? + impl UniformFloat<$ty> { + /// Construct, reducing `scale` as required to ensure that rounding + /// can never yield values greater than `high`. + /// + /// Note: though it may be tempting to use a variant of this method + /// to ensure that samples from `[low, high)` are always strictly + /// less than `high`, this approach may be very slow where + /// `scale.abs()` is much smaller than `high.abs()` + /// (example: `low=0.99999999997819644, high=1.`). + fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self { + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + + loop { + let mask = (scale * max_rand + low).gt_mask(high); + if !mask.any() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + } + + $(#[cfg($meta)])? + impl SampleUniform for $ty { + type Sampler = UniformFloat<$ty>; + } + + $(#[cfg($meta)])? + impl UniformSampler for UniformFloat<$ty> { + type X = $ty; + + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !(low.all_finite()) || !(high.all_finite()) { + return Err(Error::NonFinite); + } + if !(low.all_lt(high)) { + return Err(Error::EmptyRange); + } + + let scale = high - low; + if !(scale.all_finite()) { + return Err(Error::NonFinite); + } + + Ok(Self::new_bounded(low, high, scale)) + } + + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !(low.all_finite()) || !(high.all_finite()) { + return Err(Error::NonFinite); + } + if !low.all_le(high) { + return Err(Error::EmptyRange); + } + + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + let scale = (high - low) / max_rand; + if !scale.all_finite() { + return Err(Error::NonFinite); + } + + Ok(Self::new_bounded(low, high, scale)) + } + + fn sample(&self, rng: &mut R) -> Self::X { + // Generate a value in the range [1, 2) + let value1_2 = (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + + // Get a value in the range [0, 1) to avoid overflow when multiplying by scale + let value0_1 = value1_2 - <$ty>::splat(1.0); + + // We don't use `f64::mul_add`, because it is not available with + // `no_std`. Furthermore, it is slower for some targets (but + // faster for others). However, the order of multiplication and + // addition is important, because on some platforms (e.g. ARM) + // it will be optimized to a single (non-FMA) instruction. + value0_1 * self.scale + self.low + } + + #[inline] + fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + Self::sample_single_inclusive(low_b, high_b, rng) + } + + #[inline] + fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !low.all_finite() || !high.all_finite() { + return Err(Error::NonFinite); + } + if !low.all_le(high) { + return Err(Error::EmptyRange); + } + let scale = high - low; + if !scale.all_finite() { + return Err(Error::NonFinite); + } + + // Generate a value in the range [1, 2) + let value1_2 = + (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + + // Get a value in the range [0, 1) to avoid overflow when multiplying by scale + let value0_1 = value1_2 - <$ty>::splat(1.0); + + // Doing multiply before addition allows some architectures + // to use a single instruction. + Ok(value0_1 * scale + low) + } + } + }; +} + +uniform_float_impl! { , f32, u32, f32, u32, 32 - 23 } +uniform_float_impl! { , f64, u64, f64, u64, 64 - 52 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x2, u32x2, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x4, u32x4, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x8, u32x8, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x16, u32x16, f32, u32, 32 - 23 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x2, u64x2, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 } + +#[cfg(test)] +mod tests { + use super::*; + use crate::distr::{utils::FloatSIMDScalarUtils, Uniform}; + use crate::test::{const_rng, step_rng}; + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_floats() { + let mut rng = crate::test::rng(252); + let mut zero_rng = const_rng(0); + let mut max_rng = const_rng(0xffff_ffff_ffff_ffff); + macro_rules! t { + ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + (0.0, 100.0), + (-1e35, -1e25), + (1e-35, 1e-25), + (-1e35, 1e35), + (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), + (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), + (-<$f_scalar>::from_bits(5), 0.0), + (-<$f_scalar>::from_bits(7), -0.0), + (0.1 * $f_scalar::MAX, $f_scalar::MAX), + (-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::LEN { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + let my_uniform = Uniform::new(low, high).unwrap(); + let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); + for _ in 0..100 { + let v = rng.sample(my_uniform).extract_lane(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = rng.sample(my_incl_uniform).extract_lane(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = + <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) + .unwrap() + .extract_lane(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, high, &mut rng, + ) + .unwrap() + .extract_lane(lane); + assert!(low_scalar <= v && v <= high_scalar); + } + + assert_eq!( + rng.sample(Uniform::new_inclusive(low, low).unwrap()) + .extract_lane(lane), + low_scalar + ); + + assert_eq!(zero_rng.sample(my_uniform).extract_lane(lane), low_scalar); + assert_eq!( + zero_rng.sample(my_incl_uniform).extract_lane(lane), + low_scalar + ); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut zero_rng + ) + .unwrap() + .extract_lane(lane), + low_scalar + ); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, + high, + &mut zero_rng + ) + .unwrap() + .extract_lane(lane), + low_scalar + ); + + assert!(max_rng.sample(my_uniform).extract_lane(lane) <= high_scalar); + assert!(max_rng.sample(my_incl_uniform).extract_lane(lane) <= high_scalar); + // sample_single cannot cope with max_rng: + // assert!(<$ty as SampleUniform>::Sampler + // ::sample_single(low, high, &mut max_rng).unwrap() + // .extract(lane) <= high_scalar); + assert!( + <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, + high, + &mut max_rng + ) + .unwrap() + .extract_lane(lane) + <= high_scalar + ); + + // Don't run this test for really tiny differences between high and low + // since for those rounding might result in selecting high for a very + // long time. + if (high_scalar - low_scalar) > 0.0001 { + let mut lowering_max_rng = + step_rng(0xffff_ffff_ffff_ffff, (-1i64 << $bits_shifted) as u64); + assert!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut lowering_max_rng + ) + .unwrap() + .extract_lane(lane) + <= high_scalar + ); + } + } + } + + assert_eq!( + rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()), + $f_scalar::MAX + ); + assert_eq!( + rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()), + -$f_scalar::MAX + ); + }}; + } + + t!(f32, f32, 32 - 23); + t!(f64, f64, 64 - 52); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32, 32 - 23); + t!(f32x4, f32, 32 - 23); + t!(f32x8, f32, 32 - 23); + t!(f32x16, f32, 32 - 23); + t!(f64x2, f64, 64 - 52); + t!(f64x4, f64, 64 - 52); + t!(f64x8, f64, 64 - 52); + } + } + + #[test] + fn test_float_overflow() { + assert_eq!(Uniform::try_from(f64::MIN..f64::MAX), Err(Error::NonFinite)); + } + + #[test] + #[should_panic] + fn test_float_overflow_single() { + let mut rng = crate::test::rng(252); + rng.random_range(f64::MIN..f64::MAX); + } + + #[test] + #[cfg(all(feature = "std", panic = "unwind"))] + fn test_float_assertions() { + use super::SampleUniform; + fn range(low: T, high: T) -> Result { + let mut rng = crate::test::rng(253); + T::Sampler::sample_single(low, high, &mut rng) + } + + macro_rules! t { + ($ty:ident, $f_scalar:ident) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + ($f_scalar::NAN, 0.0), + (1.0, $f_scalar::NAN), + ($f_scalar::NAN, $f_scalar::NAN), + (1.0, 0.5), + ($f_scalar::MAX, -$f_scalar::MAX), + ($f_scalar::INFINITY, $f_scalar::INFINITY), + ($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY), + ($f_scalar::NEG_INFINITY, 5.0), + (5.0, $f_scalar::INFINITY), + ($f_scalar::NAN, $f_scalar::INFINITY), + ($f_scalar::NEG_INFINITY, $f_scalar::NAN), + ($f_scalar::NEG_INFINITY, $f_scalar::INFINITY), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::LEN { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + assert!(range(low, high).is_err()); + assert!(Uniform::new(low, high).is_err()); + assert!(Uniform::new_inclusive(low, high).is_err()); + assert!(Uniform::new(low, low).is_err()); + } + } + }}; + } + + t!(f32, f32); + t!(f64, f64); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32); + t!(f32x4, f32); + t!(f32x8, f32); + t!(f32x16, f32); + t!(f64x2, f64); + t!(f64x4, f64); + t!(f64x8, f64); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::try_from(2.0f64..7.0).unwrap(); + assert_eq!(r.0.low, 2.0); + assert_eq!(r.0.scale, 5.0); + } + + #[test] + fn test_uniform_from_std_range_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100.0..10.0).is_err()); + assert!(Uniform::try_from(100.0..100.0).is_err()); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::try_from(2.0f64..=7.0).unwrap(); + assert_eq!(r.0.low, 2.0); + assert!(r.0.scale > 5.0); + assert!(r.0.scale < 5.0 + 1e-14); + } + + #[test] + fn test_uniform_from_std_range_inclusive_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100.0..=10.0).is_err()); + assert!(Uniform::try_from(100.0..=99.0).is_err()); + } +} diff --git a/src/distr/uniform_int.rs b/src/distr/uniform_int.rs new file mode 100644 index 00000000000..d5c56b02a0b --- /dev/null +++ b/src/distr/uniform_int.rs @@ -0,0 +1,796 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformInt` implementation + +use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::utils::WideningMultiply; +#[cfg(feature = "simd_support")] +use crate::distr::{Distribution, StandardUniform}; +use crate::Rng; + +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, SupportedLaneCount}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The back-end implementing [`UniformSampler`] for integer types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// For simplicity, we use the same generic struct `UniformInt` for all +/// integer types `X`. This gives us only one field type, `X`; to store unsigned +/// values of this size, we take use the fact that these conversions are no-ops. +/// +/// For a closed range, the number of possible numbers we should generate is +/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of +/// our sample space, `zone`, is a multiple of `range`; other values must be +/// rejected (by replacing with a new random sample). +/// +/// As a special case, we use `range = 0` to represent the full range of the +/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). +/// +/// The optimum `zone` is the largest product of `range` which fits in our +/// (unsigned) target type. We calculate this by calculating how many numbers we +/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) +/// product of `range` will suffice, thus in `sample_single` we multiply by a +/// power of 2 via bit-shifting (faster but may cause more rejections). +/// +/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we +/// use `u32` for our `zone` and samples (because it's not slower and because +/// it reduces the chance of having to reject a sample). In this case we cannot +/// store `zone` in the target type since it is too large, however we know +/// `ints_to_reject < range <= $uty::MAX`. +/// +/// An alternative to using a modulus is widening multiply: After a widening +/// multiply by `range`, the result is in the high word. Then comparing the low +/// word against `zone` makes sure our distribution is uniform. +/// +/// # Bias +/// +/// Unless the `unbiased` feature flag is used, outputs may have a small bias. +/// In the worst case, bias affects 1 in `2^n` samples where n is +/// 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 (`i32` and `u32`), 64 (`i64` +/// and `u64`), 128 (`i128` and `u128`). +/// +/// [`Uniform`]: super::Uniform +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformInt { + pub(super) low: X, + pub(super) range: X, + thresh: X, // effectively 2.pow(max(64, uty_bits)) % range +} + +macro_rules! uniform_int_impl { + ($ty:ty, $uty:ty, $sample_ty:ident) => { + impl SampleUniform for $ty { + type Sampler = UniformInt<$ty>; + } + + impl UniformSampler for UniformInt<$ty> { + // We play free and fast with unsigned vs signed here + // (when $ty is signed), but that's fine, since the + // contract of this macro is for $ty and $uty to be + // "bit-equal", so casting between them is a no-op. + + type X = $ty; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + let range = high.wrapping_sub(low).wrapping_add(1) as $uty; + let thresh = if range > 0 { + let range = $sample_ty::from(range); + (range.wrapping_neg() % range) + } else { + 0 + }; + + Ok(UniformInt { + low, + range: range as $ty, // type: $uty + thresh: thresh as $uty as $ty, // type: $sample_ty + }) + } + + /// Sample from distribution, Lemire's method, unbiased + #[inline] + fn sample(&self, rng: &mut R) -> Self::X { + let range = self.range as $uty as $sample_ty; + if range == 0 { + return rng.random(); + } + + let thresh = self.thresh as $uty as $sample_ty; + let hi = loop { + let (hi, lo) = rng.random::<$sample_ty>().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as $ty) + } + + #[inline] + fn sample_single( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + Self::sample_single_inclusive(low, high - 1, rng) + } + + /// Sample single value, Canon's method, biased + /// + /// In the worst case, bias affects 1 in `2^n` samples where n is + /// 56 (`i8`), 48 (`i16`), 96 (`i32`), 64 (`i64`), 128 (`i128`). + #[cfg(not(feature = "unbiased"))] + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.random()); + } + + // generate a sample using a sensible integer type + let (mut result, lo_order) = rng.random::<$sample_ty>().wmul(range); + + // if the sample is biased... + if lo_order > range.wrapping_neg() { + // ...generate a new sample to reduce bias... + let (new_hi_order, _) = (rng.random::<$sample_ty>()).wmul(range as $sample_ty); + // ... incrementing result on overflow + let is_overflow = lo_order.checked_add(new_hi_order as $sample_ty).is_none(); + result += is_overflow as $sample_ty; + } + + Ok(low.wrapping_add(result as $ty)) + } + + /// Sample single value, Canon's method, unbiased + #[cfg(feature = "unbiased")] + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow<$ty> + Sized, + B2: SampleBorrow<$ty> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.random()); + } + + let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range); + + // In contrast to the biased sampler, we use a loop: + while lo > range.wrapping_neg() { + let (new_hi, new_lo) = (rng.random::<$sample_ty>()).wmul(range); + match lo.checked_add(new_hi) { + Some(x) if x < $sample_ty::MAX => { + // Anything less than MAX: last term is 0 + break; + } + None => { + // Overflow: last term is 1 + result += 1; + break; + } + _ => { + // Unlikely case: must check next sample + lo = new_lo; + continue; + } + } + } + + Ok(low.wrapping_add(result as $ty)) + } + } + }; +} + +uniform_int_impl! { i8, u8, u32 } +uniform_int_impl! { i16, u16, u32 } +uniform_int_impl! { i32, u32, u32 } +uniform_int_impl! { i64, u64, u64 } +uniform_int_impl! { i128, u128, u128 } +uniform_int_impl! { u8, u8, u32 } +uniform_int_impl! { u16, u16, u32 } +uniform_int_impl! { u32, u32, u32 } +uniform_int_impl! { u64, u64, u64 } +uniform_int_impl! { u128, u128, u128 } + +#[cfg(feature = "simd_support")] +macro_rules! uniform_simd_int_impl { + ($ty:ident, $unsigned:ident) => { + // The "pick the largest zone that can fit in an `u32`" optimization + // is less useful here. Multiple lanes complicate things, we don't + // know the PRNG's minimal output size, and casting to a larger vector + // is generally a bad idea for SIMD performance. The user can still + // implement it manually. + + #[cfg(feature = "simd_support")] + impl SampleUniform for Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + Simd<$unsigned, LANES>: + WideningMultiply, Simd<$unsigned, LANES>)>, + StandardUniform: Distribution>, + { + type Sampler = UniformInt>; + } + + #[cfg(feature = "simd_support")] + impl UniformSampler for UniformInt> + where + LaneCount: SupportedLaneCount, + Simd<$unsigned, LANES>: + WideningMultiply, Simd<$unsigned, LANES>)>, + StandardUniform: Distribution>, + { + type X = Simd<$ty, LANES>; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low.simd_lt(high).all()) { + return Err(Error::EmptyRange); + } + UniformSampler::new_inclusive(low, high - Simd::splat(1)) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low.simd_le(high).all()) { + return Err(Error::EmptyRange); + } + + // NOTE: all `Simd` operations are inherently wrapping, + // see https://doc.rust-lang.org/std/simd/struct.Simd.html + let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast(); + + // We must avoid divide-by-zero by using 0 % 1 == 0. + let not_full_range = range.simd_gt(Simd::splat(0)); + let modulo = not_full_range.select(range, Simd::splat(1)); + let ints_to_reject = range.wrapping_neg() % modulo; + + Ok(UniformInt { + low, + // These are really $unsigned values, but store as $ty: + range: range.cast(), + thresh: ints_to_reject.cast(), + }) + } + + fn sample(&self, rng: &mut R) -> Self::X { + let range: Simd<$unsigned, LANES> = self.range.cast(); + let thresh: Simd<$unsigned, LANES> = self.thresh.cast(); + + // This might seem very slow, generating a whole new + // SIMD vector for every sample rejection. For most uses + // though, the chance of rejection is small and provides good + // general performance. With multiple lanes, that chance is + // multiplied. To mitigate this, we replace only the lanes of + // the vector which fail, iteratively reducing the chance of + // rejection. The replacement method does however add a little + // overhead. Benchmarking or calculating probabilities might + // reveal contexts where this replacement method is slower. + let mut v: Simd<$unsigned, LANES> = rng.random(); + loop { + let (hi, lo) = v.wmul(range); + let mask = lo.simd_ge(thresh); + if mask.all() { + let hi: Simd<$ty, LANES> = hi.cast(); + // wrapping addition + let result = self.low + hi; + // `select` here compiles to a blend operation + // When `range.eq(0).none()` the compare and blend + // operations are avoided. + let v: Simd<$ty, LANES> = v.cast(); + return range.simd_gt(Simd::splat(0)).select(result, v); + } + // Replace only the failing lanes + v = mask.select(v, rng.random()); + } + } + } + }; + + // bulk implementation + ($(($unsigned:ident, $signed:ident)),+) => { + $( + uniform_simd_int_impl!($unsigned, $unsigned); + uniform_simd_int_impl!($signed, $unsigned); + )+ + }; +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) } + +/// The back-end implementing [`UniformSampler`] for `usize`. +/// +/// # Implementation notes +/// +/// Sampling a `usize` value is usually used in relation to the length of an +/// array or other memory structure, thus it is reasonable to assume that the +/// vast majority of use-cases will have a maximum size under [`u32::MAX`]. +/// In part to optimise for this use-case, but mostly to ensure that results +/// are portable across 32-bit and 64-bit architectures (as far as is possible), +/// this implementation will use 32-bit sampling when possible. +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct UniformUsize { + low: usize, + range: usize, + thresh: usize, + #[cfg(target_pointer_width = "64")] + mode64: bool, +} + +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl SampleUniform for usize { + type Sampler = UniformUsize; +} + +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl UniformSampler for UniformUsize { + type X = usize; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + #[cfg(target_pointer_width = "64")] + let mode64 = high > (u32::MAX as usize); + #[cfg(target_pointer_width = "32")] + let mode64 = false; + + let (range, thresh); + if cfg!(target_pointer_width = "64") && !mode64 { + let range32 = (high as u32).wrapping_sub(low as u32).wrapping_add(1); + range = range32 as usize; + thresh = if range32 > 0 { + (range32.wrapping_neg() % range32) as usize + } else { + 0 + }; + } else { + range = high.wrapping_sub(low).wrapping_add(1); + thresh = if range > 0 { + range.wrapping_neg() % range + } else { + 0 + }; + } + + Ok(UniformUsize { + low, + range, + thresh, + #[cfg(target_pointer_width = "64")] + mode64, + }) + } + + #[inline] + fn sample(&self, rng: &mut R) -> usize { + #[cfg(target_pointer_width = "32")] + let mode32 = true; + #[cfg(target_pointer_width = "64")] + let mode32 = !self.mode64; + + if mode32 { + let range = self.range as u32; + if range == 0 { + return rng.random::() as usize; + } + + let thresh = self.thresh as u32; + let hi = loop { + let (hi, lo) = rng.random::().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as usize) + } else { + let range = self.range as u64; + if range == 0 { + return rng.random::() as usize; + } + + let thresh = self.thresh as u64; + let hi = loop { + let (hi, lo) = rng.random::().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as usize) + } + } + + #[inline] + fn sample_single( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + + if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) { + return UniformInt::::sample_single(low as u64, high as u64, rng) + .map(|x| x as usize); + } + + UniformInt::::sample_single(low as u32, high as u32, rng).map(|x| x as usize) + } + + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) { + return UniformInt::::sample_single_inclusive(low as u64, high as u64, rng) + .map(|x| x as usize); + } + + UniformInt::::sample_single_inclusive(low as u32, high as u32, rng).map(|x| x as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distr::{Distribution, Uniform}; + use core::fmt::Debug; + use core::ops::Add; + + #[test] + fn test_uniform_bad_limits_equal_int() { + assert_eq!(Uniform::new(10, 10), Err(Error::EmptyRange)); + } + + #[test] + fn test_uniform_good_limits_equal_int() { + let mut rng = crate::test::rng(804); + let dist = Uniform::new_inclusive(10, 10).unwrap(); + for _ in 0..20 { + assert_eq!(rng.sample(dist), 10); + } + } + + #[test] + fn test_uniform_bad_limits_flipped_int() { + assert_eq!(Uniform::new(10, 5), Err(Error::EmptyRange)); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_integers() { + let mut rng = crate::test::rng(251); + macro_rules! t { + ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ + for &(low, high) in $v.iter() { + let my_uniform = Uniform::new(low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + let my_uniform = Uniform::new(&low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(&low, &high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng).unwrap(); + assert!($le(low, v) && $lt(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng).unwrap(); + assert!($le(low, v) && $le(v, high)); + } + } + }}; + + // scalar bulk + ($($ty:ident),*) => {{ + $(t!( + $ty, + [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], + |x, y| x <= y, + |x, y| x < y + );)* + }}; + + // simd bulk + ($($ty:ident),* => $scalar:ident) => {{ + $(t!( + $ty, + [ + ($ty::splat(0), $ty::splat(10)), + ($ty::splat(10), $ty::splat(127)), + ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), + ], + |x: $ty, y| x.simd_le(y).all(), + |x: $ty, y| x.simd_lt(y).all() + );)* + }}; + } + t!(i8, i16, i32, i64, i128, u8, u16, u32, u64, usize, u128); + + #[cfg(feature = "simd_support")] + { + t!(u8x4, u8x8, u8x16, u8x32, u8x64 => u8); + t!(i8x4, i8x8, i8x16, i8x32, i8x64 => i8); + t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); + t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); + t!(u32x2, u32x4, u32x8, u32x16 => u32); + t!(i32x2, i32x4, i32x8, i32x16 => i32); + t!(u64x2, u64x4, u64x8 => u64); + t!(i64x2, i64x4, i64x8 => i64); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::try_from(2u32..7).unwrap(); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + } + + #[test] + fn test_uniform_from_std_range_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100..10).is_err()); + assert!(Uniform::try_from(100..100).is_err()); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::try_from(2u32..=6).unwrap(); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + } + + #[test] + fn test_uniform_from_std_range_inclusive_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100..=10).is_err()); + assert!(Uniform::try_from(100..=99).is_err()); + } + + #[test] + fn value_stability() { + fn test_samples>( + lb: T, + ub: T, + ub_excl: T, + expected: &[T], + ) where + Uniform: Distribution, + { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 6]; + + for x in &mut buf[0..3] { + *x = T::Sampler::sample_single_inclusive(lb, ub, &mut rng).unwrap(); + } + + let distr = Uniform::new_inclusive(lb, ub).unwrap(); + for x in &mut buf[3..6] { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + + let mut rng = crate::test::rng(897); + + for x in &mut buf[0..3] { + *x = T::Sampler::sample_single(lb, ub_excl, &mut rng).unwrap(); + } + + let distr = Uniform::new(lb, ub_excl).unwrap(); + for x in &mut buf[3..6] { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + } + + test_samples(-105i8, 111, 112, &[-99, -48, 107, 72, -19, 56]); + test_samples(2i16, 1352, 1353, &[43, 361, 1325, 1109, 539, 1005]); + test_samples( + -313853i32, + 13513, + 13514, + &[-303803, -226673, 6912, -45605, -183505, -70668], + ); + test_samples( + 131521i64, + 6542165, + 6542166, + &[1838724, 5384489, 4893692, 3712948, 3951509, 4094926], + ); + test_samples( + -0x8000_0000_0000_0000_0000_0000_0000_0000i128, + -1, + 0, + &[ + -30725222750250982319765550926688025855, + -75088619368053423329503924805178012357, + -64950748766625548510467638647674468829, + -41794017901603587121582892414659436495, + -63623852319608406524605295913876414006, + -17404679390297612013597359206379189023, + ], + ); + test_samples(11u8, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u16, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u32, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u64, 218, 219, &[66, 181, 165, 127, 134, 139]); + test_samples(11u128, 218, 219, &[181, 127, 139, 167, 141, 197]); + test_samples(11usize, 218, 219, &[17, 66, 214, 181, 93, 165]); + + #[cfg(feature = "simd_support")] + { + let lb = Simd::from([11u8, 0, 128, 127]); + let ub = Simd::from([218, 254, 254, 254]); + let ub_excl = ub + Simd::splat(1); + test_samples( + lb, + ub, + ub_excl, + &[ + Simd::from([13, 5, 237, 130]), + Simd::from([126, 186, 149, 161]), + Simd::from([103, 86, 234, 252]), + Simd::from([35, 18, 225, 231]), + Simd::from([106, 153, 246, 177]), + Simd::from([195, 168, 149, 222]), + ], + ); + } + } +} diff --git a/src/distr/uniform_other.rs b/src/distr/uniform_other.rs new file mode 100644 index 00000000000..03533debcd8 --- /dev/null +++ b/src/distr/uniform_other.rs @@ -0,0 +1,319 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformChar`, `UniformDuration` implementations + +use super::{Error, SampleBorrow, SampleUniform, Uniform, UniformInt, UniformSampler}; +use crate::distr::Distribution; +use crate::Rng; +use core::time::Duration; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +impl SampleUniform for char { + type Sampler = UniformChar; +} + +/// The back-end implementing [`UniformSampler`] for `char`. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// This differs from integer range sampling since the range `0xD800..=0xDFFF` +/// are used for surrogate pairs in UCS and UTF-16, and consequently are not +/// valid Unicode code points. We must therefore avoid sampling values in this +/// range. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformChar { + sampler: UniformInt, +} + +/// UTF-16 surrogate range start +const CHAR_SURROGATE_START: u32 = 0xD800; +/// UTF-16 surrogate range size +const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; + +/// Convert `char` to compressed `u32` +fn char_to_comp_u32(c: char) -> u32 { + match c as u32 { + c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, + c => c, + } +} + +impl UniformSampler for UniformChar { + type X = char; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new(low, high); + sampler.map(|sampler| UniformChar { sampler }) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new_inclusive(low, high); + sampler.map(|sampler| UniformChar { sampler }) + } + + fn sample(&self, rng: &mut R) -> Self::X { + let mut x = self.sampler.sample(rng); + if x >= CHAR_SURROGATE_START { + x += CHAR_SURROGATE_LEN; + } + // SAFETY: x must not be in surrogate range or greater than char::MAX. + // This relies on range constructors which accept char arguments. + // Validity of input char values is assumed. + unsafe { core::char::from_u32_unchecked(x) } + } +} + +#[cfg(feature = "alloc")] +impl crate::distr::SampleString for Uniform { + fn append_string( + &self, + rng: &mut R, + string: &mut alloc::string::String, + len: usize, + ) { + // Getting the hi value to assume the required length to reserve in string. + let mut hi = self.0.sampler.low + self.0.sampler.range - 1; + if hi >= CHAR_SURROGATE_START { + hi += CHAR_SURROGATE_LEN; + } + // Get the utf8 length of hi to minimize extra space. + let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4); + string.reserve(max_char_len * len); + string.extend(self.sample_iter(rng).take(len)) + } +} + +/// The back-end implementing [`UniformSampler`] for `Duration`. +/// +/// Unless you are implementing [`UniformSampler`] for your own types, this type +/// should not be used directly, use [`Uniform`] instead. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformDuration { + mode: UniformDurationMode, + offset: u32, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum UniformDurationMode { + Small { + secs: u64, + nanos: Uniform, + }, + Medium { + nanos: Uniform, + }, + Large { + max_secs: u64, + max_nanos: u32, + secs: Uniform, + }, +} + +impl SampleUniform for Duration { + type Sampler = UniformDuration; +} + +impl UniformSampler for UniformDuration { + type X = Duration; + + #[inline] + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) + } + + #[inline] + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + let low_s = low.as_secs(); + let low_n = low.subsec_nanos(); + let mut high_s = high.as_secs(); + let mut high_n = high.subsec_nanos(); + + if high_n < low_n { + high_s -= 1; + high_n += 1_000_000_000; + } + + let mode = if low_s == high_s { + UniformDurationMode::Small { + secs: low_s, + nanos: Uniform::new_inclusive(low_n, high_n)?, + } + } else { + let max = high_s + .checked_mul(1_000_000_000) + .and_then(|n| n.checked_add(u64::from(high_n))); + + if let Some(higher_bound) = max { + let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); + UniformDurationMode::Medium { + nanos: Uniform::new_inclusive(lower_bound, higher_bound)?, + } + } else { + // An offset is applied to simplify generation of nanoseconds + let max_nanos = high_n - low_n; + UniformDurationMode::Large { + max_secs: high_s, + max_nanos, + secs: Uniform::new_inclusive(low_s, high_s)?, + } + } + }; + Ok(UniformDuration { + mode, + offset: low_n, + }) + } + + #[inline] + fn sample(&self, rng: &mut R) -> Duration { + match self.mode { + UniformDurationMode::Small { secs, nanos } => { + let n = nanos.sample(rng); + Duration::new(secs, n) + } + UniformDurationMode::Medium { nanos } => { + let nanos = nanos.sample(rng); + Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) + } + UniformDurationMode::Large { + max_secs, + max_nanos, + secs, + } => { + // constant folding means this is at least as fast as `Rng::sample(Range)` + let nano_range = Uniform::new(0, 1_000_000_000).unwrap(); + loop { + let s = secs.sample(rng); + let n = nano_range.sample(rng); + if !(s == max_secs && n > max_nanos) { + let sum = n + self.offset; + break Duration::new(s, sum); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg(feature = "serde")] + fn test_serialization_uniform_duration() { + let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap(); + let de_distr: UniformDuration = + bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); + assert_eq!(distr, de_distr); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_char() { + let mut rng = crate::test::rng(891); + let mut max = core::char::from_u32(0).unwrap(); + for _ in 0..100 { + let c = rng.random_range('A'..='Z'); + assert!(c.is_ascii_uppercase()); + max = max.max(c); + } + assert_eq!(max, 'Z'); + let d = Uniform::new( + core::char::from_u32(0xD7F0).unwrap(), + core::char::from_u32(0xE010).unwrap(), + ) + .unwrap(); + for _ in 0..100 { + let c = d.sample(&mut rng); + assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); + } + #[cfg(feature = "alloc")] + { + use crate::distr::SampleString; + let string1 = d.sample_string(&mut rng, 100); + assert_eq!(string1.capacity(), 300); + let string2 = Uniform::new( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ) + .unwrap() + .sample_string(&mut rng, 100); + assert_eq!(string2.capacity(), 100); + let string3 = Uniform::new_inclusive( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ) + .unwrap() + .sample_string(&mut rng, 100); + assert_eq!(string3.capacity(), 200); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_durations() { + let mut rng = crate::test::rng(253); + + let v = &[ + (Duration::new(10, 50000), Duration::new(100, 1234)), + (Duration::new(0, 100), Duration::new(1, 50)), + (Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)), + ]; + for &(low, high) in v.iter() { + let my_uniform = Uniform::new(low, high).unwrap(); + for _ in 0..1000 { + let v = rng.sample(my_uniform); + assert!(low <= v && v < high); + } + } + } +} diff --git a/src/distributions/utils.rs b/src/distr/utils.rs similarity index 91% rename from src/distributions/utils.rs rename to src/distr/utils.rs index e3ef5bcdb8b..784534f48b0 100644 --- a/src/distributions/utils.rs +++ b/src/distr/utils.rs @@ -8,9 +8,10 @@ //! Math helper functions -#[cfg(feature = "simd_support")] use core::simd::prelude::*; -#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount}; - +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, SimdElement, SupportedLaneCount}; pub(crate) trait WideningMultiply { type Output; @@ -146,8 +147,10 @@ wmul_impl_usize! { u64 } #[cfg(feature = "simd_support")] mod simd_wmul { use super::*; - #[cfg(target_arch = "x86")] use core::arch::x86::*; - #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; wmul_impl! { (u8x4, u16x4), @@ -215,9 +218,7 @@ pub(crate) trait FloatSIMDUtils { fn all_finite(self) -> bool; type Mask; - fn finite_mask(self) -> Self::Mask; fn gt_mask(self, other: Self) -> Self::Mask; - fn ge_mask(self, other: Self) -> Self::Mask; // Decrease all lanes where the mask is `true` to the next lower value // representable by the floating-point type. At least one of the lanes @@ -235,12 +236,14 @@ pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils { type Scalar; fn replace(self, index: usize, new_value: Self::Scalar) -> Self; - fn extract(self, index: usize) -> Self::Scalar; + fn extract_lane(self, index: usize) -> Self::Scalar; } /// Implement functions on f32/f64 to give them APIs similar to SIMD types pub(crate) trait FloatAsSIMD: Sized { + #[cfg(test)] const LEN: usize = 1; + #[inline(always)] fn splat(scalar: Self) -> Self { scalar @@ -289,21 +292,11 @@ macro_rules! scalar_float_impl { self.is_finite() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self > other } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self >= other - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { debug_assert!(mask, "At least one lane must be set"); @@ -327,7 +320,7 @@ macro_rules! scalar_float_impl { } #[inline] - fn extract(self, index: usize) -> Self::Scalar { + fn extract_lane(self, index: usize) -> Self::Scalar { debug_assert_eq!(index, 0); self } @@ -340,12 +333,12 @@ macro_rules! scalar_float_impl { scalar_float_impl!(f32, u32); scalar_float_impl!(f64, u64); - #[cfg(feature = "simd_support")] macro_rules! simd_impl { ($fty:ident, $uty:ident) => { impl FloatSIMDUtils for Simd<$fty, LANES> - where LaneCount: SupportedLaneCount + where + LaneCount: SupportedLaneCount, { type Mask = Mask<<$fty as SimdElement>::Mask, LANES>; type UInt = Simd<$uty, LANES>; @@ -365,21 +358,11 @@ macro_rules! simd_impl { self.is_finite().all() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self.simd_gt(other) } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self.simd_ge(other) - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { // Casting a mask into ints will produce all bits set for @@ -412,7 +395,7 @@ macro_rules! simd_impl { } #[inline] - fn extract(self, index: usize) -> Self::Scalar { + fn extract_lane(self, index: usize) -> Self::Scalar { self.as_array()[index] } } diff --git a/src/distr/weighted/mod.rs b/src/distr/weighted/mod.rs new file mode 100644 index 00000000000..368c5b0703d --- /dev/null +++ b/src/distr/weighted/mod.rs @@ -0,0 +1,115 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! Primarily, this module houses the [`WeightedIndex`] distribution. +//! See also [`rand_distr::weighted`] for alternative implementations supporting +//! potentially-faster sampling or a more easily modifiable tree structure. +//! +//! [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html + +use core::fmt; +mod weighted_index; + +pub use weighted_index::WeightedIndex; + +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; + + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + #[allow(clippy::result_unit_err)] + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +} + +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } +} +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); + +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + }; +} +impl_weight_float!(f32); +impl_weight_float!(f64); + +/// Invalid weight errors +/// +/// This type represents errors from [`WeightedIndex::new`], +/// [`WeightedIndex::update_weights`] and other weighted distributions. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Marked non_exhaustive to allow a new error code in the solution to #1476. +#[non_exhaustive] +pub enum Error { + /// The input weight sequence is empty, too long, or wrongly ordered + InvalidInput, + + /// A weight is negative, too large for the distribution, or not a valid number + InvalidWeight, + + /// Not enough non-zero weights are available to sample values + /// + /// When attempting to sample a single value this implies that all weights + /// are zero. When attempting to sample `amount` values this implies that + /// less than `amount` weights are greater than zero. + InsufficientNonZero, + + /// Overflow when calculating the sum of weights + Overflow, +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + Error::InvalidInput => "Weights sequence is empty/too long/unordered", + Error::InvalidWeight => "A weight is negative, too large or not a valid number", + Error::InsufficientNonZero => "Not enough weights > zero", + Error::Overflow => "Overflow when summing weights", + }) + } +} diff --git a/src/distributions/weighted_index.rs b/src/distr/weighted/weighted_index.rs similarity index 58% rename from src/distributions/weighted_index.rs rename to src/distr/weighted/weighted_index.rs index 49cb02d6ade..4bb9d141fb3 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distr/weighted/weighted_index.rs @@ -6,21 +6,19 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Weighted index sampling - -use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; -use crate::distributions::Distribution; +use super::{Error, Weight}; +use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::Distribution; use crate::Rng; -use core::cmp::PartialOrd; -use core::fmt; // Note that this whole module is only imported if feature="alloc" is enabled. use alloc::vec::Vec; +use core::fmt::{self, Debug}; -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// A distribution using weighted sampling of discrete items +/// A distribution using weighted sampling of discrete items. /// /// Sampling a `WeightedIndex` distribution returns the index of a randomly /// selected element from the iterator used when the `WeightedIndex` was @@ -33,12 +31,9 @@ use serde::{Deserialize, Serialize}; /// # Performance /// /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. There are two alternative implementations with -/// different runtimes characteristics: -/// * [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost. -/// * [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) -/// keeps the weights in a tree structure where sampling and updating is `O(log N)`. +/// `N` is the number of weights. +/// See also [`rand_distr::weighted`] for alternative implementations supporting +/// potentially-faster sampling or a more easily modifiable tree structure. /// /// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its /// size is the sum of the size of those objects, possibly plus some alignment. @@ -59,12 +54,12 @@ use serde::{Deserialize, Serialize}; /// /// ``` /// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; +/// use rand::distr::weighted::WeightedIndex; /// /// let choices = ['a', 'b', 'c']; /// let weights = [2, 1, 1]; /// let dist = WeightedIndex::new(&weights).unwrap(); -/// let mut rng = thread_rng(); +/// let mut rng = rand::rng(); /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' /// println!("{}", choices[dist.sample(&mut rng)]); @@ -78,11 +73,11 @@ use serde::{Deserialize, Serialize}; /// } /// ``` /// -/// [`Uniform`]: crate::distributions::Uniform +/// [`Uniform`]: crate::distr::Uniform /// [`RngCore`]: crate::RngCore +/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html #[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct WeightedIndex { cumulative_weights: Vec, total_weight: X, @@ -95,24 +90,24 @@ impl WeightedIndex { /// implementation of [`Uniform`] exists. /// /// Error cases: - /// - [`WeightError::InvalidInput`] when the iterator `weights` is empty. - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. - /// - [`WeightError::Overflow`] when the sum of all weights overflows. + /// - [`Error::InvalidInput`] when the iterator `weights` is empty. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`Error::Overflow`] when the sum of all weights overflows. /// - /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightError> + /// [`Uniform`]: crate::distr::uniform::Uniform + pub fn new(weights: I) -> Result, Error> where I: IntoIterator, I::Item: SampleBorrow, X: Weight, { let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next().ok_or(WeightError::InvalidInput)?.borrow().clone(); + let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone(); let zero = X::ZERO; if !(total_weight >= zero) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } let mut weights = Vec::::with_capacity(iter.size_hint().0); @@ -120,17 +115,17 @@ impl WeightedIndex { // Note that `!(w >= x)` is not equivalent to `w < x` for partially // ordered types due to NaNs which are equal to nothing. if !(w.borrow() >= &zero) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } weights.push(total_weight.clone()); if let Err(()) = total_weight.checked_add_assign(w.borrow()) { - return Err(WeightError::Overflow); + return Err(Error::Overflow); } } if total_weight == zero { - return Err(WeightError::InsufficientNonZero); + return Err(Error::InsufficientNonZero); } let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); @@ -150,10 +145,10 @@ impl WeightedIndex { /// allocation internally. /// /// In case of error, `self` is not modified. Error cases: - /// - [`WeightError::InvalidInput`] when `new_weights` are not ordered by + /// - [`Error::InvalidInput`] when `new_weights` are not ordered by /// index or an index is too large. - /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`WeightError::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. /// Note that due to floating-point loss of precision, this case is not /// always correctly detected; usage of a fixed-point weight type may be /// preferred. @@ -161,10 +156,10 @@ impl WeightedIndex { /// Updates take `O(N)` time. If you need to frequently update weights, consider /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) /// as an alternative where an update is `O(log N)`. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightError> + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error> where - X: for<'a> ::core::ops::AddAssign<&'a X> - + for<'a> ::core::ops::SubAssign<&'a X> + X: for<'a> core::ops::AddAssign<&'a X> + + for<'a> core::ops::SubAssign<&'a X> + Clone + Default, { @@ -182,14 +177,14 @@ impl WeightedIndex { for &(i, w) in new_weights { if let Some(old_i) = prev_i { if old_i >= i { - return Err(WeightError::InvalidInput); + return Err(Error::InvalidInput); } } if !(*w >= zero) { - return Err(WeightError::InvalidWeight); + return Err(Error::InvalidWeight); } if i > self.cumulative_weights.len() { - return Err(WeightError::InvalidInput); + return Err(Error::InvalidInput); } let mut old_w = if i < self.cumulative_weights.len() { @@ -206,7 +201,7 @@ impl WeightedIndex { prev_i = Some(i); } if total_weight <= zero { - return Err(WeightError::InsufficientNonZero); + return Err(Error::InsufficientNonZero); } // Update the weights. Because we checked all the preconditions in the @@ -244,80 +239,142 @@ impl WeightedIndex { } } -impl Distribution for WeightedIndex +/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution. +/// This is returned by [`WeightedIndex::weights`]. +pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> { + weighted_index: &'a WeightedIndex, + index: usize, +} + +impl Debug for WeightedIndexIter<'_, X> where - X: SampleUniform + PartialOrd, + X: SampleUniform + PartialOrd + Debug, + X::Sampler: Debug, { - fn sample(&self, rng: &mut R) -> usize { - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights - .partition_point(|w| w <= &chosen_weight) + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WeightedIndexIter") + .field("weighted_index", &self.weighted_index) + .field("index", &self.index) + .finish() } } -/// Bounds on a weight -/// -/// See usage in [`WeightedIndex`]. -pub trait Weight: Clone { - /// Representation of 0 - const ZERO: Self; - - /// Checked addition - /// - /// - `Result::Ok`: On success, `v` is added to `self` - /// - `Result::Err`: Returns an error when `Self` cannot represent the - /// result of `self + v` (i.e. overflow). The value of `self` should be - /// discarded. - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +impl Clone for WeightedIndexIter<'_, X> +where + X: SampleUniform + PartialOrd, +{ + fn clone(&self) -> Self { + WeightedIndexIter { + weighted_index: self.weighted_index, + index: self.index, + } + } } -macro_rules! impl_weight_int { - ($t:ty) => { - impl Weight for $t { - const ZERO: Self = 0; - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { - match self.checked_add(*v) { - Some(sum) => { - *self = sum; - Ok(()) - } - None => Err(()), - } +impl Iterator for WeightedIndexIter<'_, X> +where + X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone, +{ + type Item = X; + + fn next(&mut self) -> Option { + match self.weighted_index.weight(self.index) { + None => None, + Some(weight) => { + self.index += 1; + Some(weight) } } - }; - ($t:ty, $($tt:ty),*) => { - impl_weight_int!($t); - impl_weight_int!($($tt),*); } } -impl_weight_int!(i8, i16, i32, i64, i128, isize); -impl_weight_int!(u8, u16, u32, u64, u128, usize); - -macro_rules! impl_weight_float { - ($t:ty) => { - impl Weight for $t { - const ZERO: Self = 0.0; - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { - // Floats have an explicit representation for overflow - *self += *v; - Ok(()) - } + +impl WeightedIndex { + /// Returns the weight at the given index, if it exists. + /// + /// If the index is out of bounds, this will return `None`. + /// + /// # Example + /// + /// ``` + /// use rand::distr::weighted::WeightedIndex; + /// + /// let weights = [0, 1, 2]; + /// let dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weight(0), Some(0)); + /// assert_eq!(dist.weight(1), Some(1)); + /// assert_eq!(dist.weight(2), Some(2)); + /// assert_eq!(dist.weight(3), None); + /// ``` + pub fn weight(&self, index: usize) -> Option + where + X: for<'a> core::ops::SubAssign<&'a X>, + { + use core::cmp::Ordering::*; + + let mut weight = match index.cmp(&self.cumulative_weights.len()) { + Less => self.cumulative_weights[index].clone(), + Equal => self.total_weight.clone(), + Greater => return None, + }; + + if index > 0 { + weight -= &self.cumulative_weights[index - 1]; + } + Some(weight) + } + + /// Returns a lazy-loading iterator containing the current weights of this distribution. + /// + /// If this distribution has not been updated since its creation, this will return the + /// same weights as were passed to `new`. + /// + /// # Example + /// + /// ``` + /// use rand::distr::weighted::WeightedIndex; + /// + /// let weights = [1, 2, 3]; + /// let mut dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![1, 2, 3]); + /// dist.update_weights(&[(0, &2)]).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![2, 2, 3]); + /// ``` + pub fn weights(&self) -> WeightedIndexIter<'_, X> + where + X: for<'a> core::ops::SubAssign<&'a X>, + { + WeightedIndexIter { + weighted_index: self, + index: 0, } - }; + } + + /// Returns the sum of all weights in this distribution. + pub fn total_weight(&self) -> X { + self.total_weight.clone() + } +} + +impl Distribution for WeightedIndex +where + X: SampleUniform + PartialOrd, +{ + fn sample(&self, rng: &mut R) -> usize { + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights + .partition_point(|w| w <= &chosen_weight) + } } -impl_weight_float!(f32); -impl_weight_float!(f64); #[cfg(test)] mod test { use super::*; - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] #[test] - fn test_weightedindex_serde1() { - let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); + fn test_weightedindex_serde() { + let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); let de_weighted_index: WeightedIndex = @@ -333,24 +390,24 @@ mod test { #[test] fn test_accepting_nan() { assert_eq!( - WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), - WeightError::InvalidWeight, + WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(), + Error::InvalidWeight, ); assert_eq!( - WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), - WeightError::InvalidWeight, + WeightedIndex::new([f32::NAN]).unwrap_err(), + Error::InvalidWeight, ); assert_eq!( - WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), - WeightError::InvalidWeight, + WeightedIndex::new([0.5, f32::NAN]).unwrap_err(), + Error::InvalidWeight, ); assert_eq!( - WeightedIndex::new(&[0.5, 7.0]) + WeightedIndex::new([0.5, 7.0]) .unwrap() - .update_weights(&[(0, &core::f32::NAN)]) + .update_weights(&[(0, &f32::NAN)]) .unwrap_err(), - WeightError::InvalidWeight, + Error::InvalidWeight, ) } @@ -398,10 +455,10 @@ mod test { verify(chosen); for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); + assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0); assert_eq!( - WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) + WeightedIndex::new([0, 0, 0, 0, 10, 0]) .unwrap() .sample(&mut r), 4 @@ -410,24 +467,21 @@ mod test { assert_eq!( WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightError::InvalidInput + Error::InvalidInput ); assert_eq!( - WeightedIndex::new(&[0]).unwrap_err(), - WeightError::InsufficientNonZero + WeightedIndex::new([0]).unwrap_err(), + Error::InsufficientNonZero ); assert_eq!( - WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), - WeightError::InvalidWeight + WeightedIndex::new([10, 20, -1, 30]).unwrap_err(), + Error::InvalidWeight ); assert_eq!( - WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), - WeightError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10]).unwrap_err(), - WeightError::InvalidWeight + WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(), + Error::InvalidWeight ); + assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight); } #[test] @@ -459,10 +513,81 @@ mod test { } } + #[test] + fn test_update_weights_errors() { + let data = [ + ( + &[1i32, 0, 0][..], + &[(0, &0)][..], + Error::InsufficientNonZero, + ), + ( + &[10, 10, 10, 10][..], + &[(1, &-11)][..], + Error::InvalidWeight, // A weight is negative + ), + ( + &[1, 2, 3, 4, 5][..], + &[(1, &5), (0, &5)][..], // Wrong order + Error::InvalidInput, + ), + ( + &[1][..], + &[(1, &1)][..], // Index too large + Error::InvalidInput, + ), + ]; + + for (weights, update, err) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + match distr.update_weights(update) { + Ok(_) => panic!("Expected update_weights to fail, but it succeeded"), + Err(e) => assert_eq!(e, *err), + } + } + } + + #[test] + fn test_weight_at() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for (i, weight) in weights.iter().enumerate() { + assert_eq!(distr.weight(i), Some(*weight)); + } + assert_eq!(distr.weight(weights.len()), None); + } + } + + #[test] + fn test_weights() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.weights().collect::>(), weights.to_vec()); + } + } + #[test] fn value_stability() { fn test_samples( - weights: I, buf: &mut [usize], expected: &[usize], + weights: I, + buf: &mut [usize], + expected: &[usize], ) where I: IntoIterator, I::Item: SampleBorrow, @@ -478,17 +603,17 @@ mod test { let mut buf = [0; 10]; test_samples( - &[1i32, 1, 1, 1, 1, 1, 1, 1, 1], + [1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], ); test_samples( - &[0.7f32, 0.1, 0.1, 0.1], + [0.7f32, 0.1, 0.1, 0.1], &mut buf, &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], ); test_samples( - &[1.0f64, 0.999, 0.998, 0.997], + [1.0f64, 0.999, 0.998, 0.997], &mut buf, &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], ); @@ -496,49 +621,11 @@ mod test { #[test] fn weighted_index_distributions_can_be_compared() { - assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); + assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2])); } #[test] fn overflow() { - assert_eq!( - WeightedIndex::new([2, usize::MAX]), - Err(WeightError::Overflow) - ); - } -} - -/// Errors returned by weighted distributions -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightError { - /// The input weight sequence is empty, too long, or wrongly ordered - InvalidInput, - - /// A weight is negative, too large for the distribution, or not a valid number - InvalidWeight, - - /// Not enough non-zero weights are available to sample values - /// - /// When attempting to sample a single value this implies that all weights - /// are zero. When attempting to sample `amount` values this implies that - /// less than `amount` weights are greater than zero. - InsufficientNonZero, - - /// Overflow when calculating the sum of weights - Overflow, -} - -#[cfg(feature = "std")] -impl std::error::Error for WeightError {} - -impl fmt::Display for WeightError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match *self { - WeightError::InvalidInput => "Weights sequence is empty/too long/unordered", - WeightError::InvalidWeight => "A weight is negative, too large or not a valid number", - WeightError::InsufficientNonZero => "Not enough weights > zero", - WeightError::Overflow => "Overflow when summing weights", - }) + assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow)); } } diff --git a/src/distributions/integer.rs b/src/distributions/integer.rs deleted file mode 100644 index e5d320b2550..00000000000 --- a/src/distributions/integer.rs +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The implementations of the `Standard` distribution for integer types. - -use crate::distributions::{Distribution, Standard}; -use crate::Rng; -#[cfg(all(target_arch = "x86", feature = "simd_support"))] -use core::arch::x86::__m512i; -#[cfg(target_arch = "x86")] -use core::arch::x86::{__m128i, __m256i}; -#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] -use core::arch::x86_64::__m512i; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::{__m128i, __m256i}; -use core::num::{ - NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize,NonZeroU128, - NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize,NonZeroI128 -}; -#[cfg(feature = "simd_support")] use core::simd::*; -use core::mem; - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u8 { - rng.next_u32() as u8 - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u16 { - rng.next_u32() as u16 - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u32 { - rng.next_u32() - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u64 { - rng.next_u64() - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u128 { - // Use LE; we explicitly generate one value before the next. - let x = u128::from(rng.next_u64()); - let y = u128::from(rng.next_u64()); - (y << 64) | x - } -} - -impl Distribution for Standard { - #[inline] - #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] - fn sample(&self, rng: &mut R) -> usize { - rng.next_u32() as usize - } - - #[inline] - #[cfg(target_pointer_width = "64")] - fn sample(&self, rng: &mut R) -> usize { - rng.next_u64() as usize - } -} - -macro_rules! impl_int_from_uint { - ($ty:ty, $uty:ty) => { - impl Distribution<$ty> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $ty { - rng.gen::<$uty>() as $ty - } - } - }; -} - -impl_int_from_uint! { i8, u8 } -impl_int_from_uint! { i16, u16 } -impl_int_from_uint! { i32, u32 } -impl_int_from_uint! { i64, u64 } -impl_int_from_uint! { i128, u128 } -impl_int_from_uint! { isize, usize } - -macro_rules! impl_nzint { - ($ty:ty, $new:path) => { - impl Distribution<$ty> for Standard { - fn sample(&self, rng: &mut R) -> $ty { - loop { - if let Some(nz) = $new(rng.gen()) { - break nz; - } - } - } - } - }; -} - -impl_nzint!(NonZeroU8, NonZeroU8::new); -impl_nzint!(NonZeroU16, NonZeroU16::new); -impl_nzint!(NonZeroU32, NonZeroU32::new); -impl_nzint!(NonZeroU64, NonZeroU64::new); -impl_nzint!(NonZeroU128, NonZeroU128::new); -impl_nzint!(NonZeroUsize, NonZeroUsize::new); - -impl_nzint!(NonZeroI8, NonZeroI8::new); -impl_nzint!(NonZeroI16, NonZeroI16::new); -impl_nzint!(NonZeroI32, NonZeroI32::new); -impl_nzint!(NonZeroI64, NonZeroI64::new); -impl_nzint!(NonZeroI128, NonZeroI128::new); -impl_nzint!(NonZeroIsize, NonZeroIsize::new); - -macro_rules! x86_intrinsic_impl { - ($($intrinsic:ident),+) => {$( - /// Available only on x86/64 platforms - impl Distribution<$intrinsic> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $intrinsic { - // On proper hardware, this should compile to SIMD instructions - // Verified on x86 Haswell with __m128i, __m256i - let mut buf = [0_u8; mem::size_of::<$intrinsic>()]; - rng.fill_bytes(&mut buf); - // x86 is little endian so no need for conversion - zerocopy::transmute!(buf) - } - } - )+}; -} - -#[cfg(feature = "simd_support")] -macro_rules! simd_impl { - ($($ty:ty),+) => {$( - /// Requires nightly Rust and the [`simd_support`] feature - /// - /// [`simd_support`]: https://github.com/rust-random/rand#crate-features - impl Distribution> for Standard - where - LaneCount: SupportedLaneCount, - { - #[inline] - fn sample(&self, rng: &mut R) -> Simd<$ty, LANES> { - let mut vec = Simd::default(); - rng.fill(vec.as_mut_array().as_mut_slice()); - vec - } - } - )+}; -} - -#[cfg(feature = "simd_support")] -simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64, usize, isize); - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -x86_intrinsic_impl!(__m128i, __m256i); -#[cfg(all( - any(target_arch = "x86", target_arch = "x86_64"), - feature = "simd_support" -))] -x86_intrinsic_impl!(__m512i); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_integers() { - let mut rng = crate::test::rng(806); - - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - } - - #[test] - fn value_stability() { - fn test_samples(zero: T, expected: &[T]) - where Standard: Distribution { - let mut rng = crate::test::rng(807); - let mut buf = [zero; 3]; - for x in &mut buf { - *x = rng.sample(Standard); - } - assert_eq!(&buf, expected); - } - - test_samples(0u8, &[9, 247, 111]); - test_samples(0u16, &[32265, 42999, 38255]); - test_samples(0u32, &[2220326409, 2575017975, 2018088303]); - test_samples(0u64, &[ - 11059617991457472009, - 16096616328739788143, - 1487364411147516184, - ]); - test_samples(0u128, &[ - 296930161868957086625409848350820761097, - 145644820879247630242265036535529306392, - 111087889832015897993126088499035356354, - ]); - #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] - test_samples(0usize, &[2220326409, 2575017975, 2018088303]); - #[cfg(target_pointer_width = "64")] - test_samples(0usize, &[ - 11059617991457472009, - 16096616328739788143, - 1487364411147516184, - ]); - - test_samples(0i8, &[9, -9, 111]); - // Skip further i* types: they are simple reinterpretation of u* samples - - #[cfg(feature = "simd_support")] - { - // We only test a sub-set of types here and make assumptions about the rest. - - test_samples(u8x4::default(), &[ - u8x4::from([9, 126, 87, 132]), - u8x4::from([247, 167, 123, 153]), - u8x4::from([111, 149, 73, 120]), - ]); - test_samples(u8x8::default(), &[ - u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]), - u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]), - u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]), - ]); - - test_samples(i64x8::default(), &[ - i64x8::from([ - -7387126082252079607, - -2350127744969763473, - 1487364411147516184, - 7895421560427121838, - 602190064936008898, - 6022086574635100741, - -5080089175222015595, - -4066367846667249123, - ]), - i64x8::from([ - 9180885022207963908, - 3095981199532211089, - 6586075293021332726, - 419343203796414657, - 3186951873057035255, - 5287129228749947252, - 444726432079249540, - -1587028029513790706, - ]), - i64x8::from([ - 6075236523189346388, - 1351763722368165432, - -6192309979959753740, - -7697775502176768592, - -4482022114172078123, - 7522501477800909500, - -1837258847956201231, - -586926753024886735, - ]), - ]); - } - } -} diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs deleted file mode 100644 index 5e6b6ae3f9e..00000000000 --- a/src/distributions/uniform.rs +++ /dev/null @@ -1,1765 +0,0 @@ -// Copyright 2018-2020 Developers of the Rand project. -// Copyright 2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A distribution uniformly sampling numbers within a given range. -//! -//! [`Uniform`] is the standard distribution to sample uniformly from a range; -//! e.g. `Uniform::new_inclusive(1, 6).unwrap()` can sample integers from 1 to 6, like a -//! standard die. [`Rng::gen_range`] supports any type supported by [`Uniform`]. -//! -//! This distribution is provided with support for several primitive types -//! (all integer and floating-point types) as well as [`std::time::Duration`], -//! and supports extension to user-defined types via a type-specific *back-end* -//! implementation. -//! -//! The types [`UniformInt`], [`UniformFloat`] and [`UniformDuration`] are the -//! back-ends supporting sampling from primitive integer and floating-point -//! ranges as well as from [`std::time::Duration`]; these types do not normally -//! need to be used directly (unless implementing a derived back-end). -//! -//! # Example usage -//! -//! ``` -//! use rand::{Rng, thread_rng}; -//! use rand::distributions::Uniform; -//! -//! let mut rng = thread_rng(); -//! let side = Uniform::new(-10.0, 10.0).unwrap(); -//! -//! // sample between 1 and 10 points -//! for _ in 0..rng.gen_range(1..=10) { -//! // sample a point from the square with sides -10 - 10 in two dimensions -//! let (x, y) = (rng.sample(side), rng.sample(side)); -//! println!("Point: {}, {}", x, y); -//! } -//! ``` -//! -//! # Extending `Uniform` to support a custom type -//! -//! To extend [`Uniform`] to support your own types, write a back-end which -//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] -//! helper trait to "register" your back-end. See the `MyF32` example below. -//! -//! At a minimum, the back-end needs to store any parameters needed for sampling -//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. -//! Those methods should include an assertion to check the range is valid (i.e. -//! `low < high`). The example below merely wraps another back-end. -//! -//! The `new`, `new_inclusive` and `sample_single` functions use arguments of -//! type `SampleBorrow` to support passing in values by reference or -//! by value. In the implementation of these functions, you can choose to -//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose -//! to copy or clone the value, whatever is appropriate for your type. -//! -//! ``` -//! use rand::prelude::*; -//! use rand::distributions::uniform::{Uniform, SampleUniform, -//! UniformSampler, UniformFloat, SampleBorrow, Error}; -//! -//! struct MyF32(f32); -//! -//! #[derive(Clone, Copy, Debug)] -//! struct UniformMyF32(UniformFloat); -//! -//! impl UniformSampler for UniformMyF32 { -//! type X = MyF32; -//! -//! fn new(low: B1, high: B2) -> Result -//! where B1: SampleBorrow + Sized, -//! B2: SampleBorrow + Sized -//! { -//! UniformFloat::::new(low.borrow().0, high.borrow().0).map(UniformMyF32) -//! } -//! fn new_inclusive(low: B1, high: B2) -> Result -//! where B1: SampleBorrow + Sized, -//! B2: SampleBorrow + Sized -//! { -//! UniformFloat::::new_inclusive(low.borrow().0, high.borrow().0).map(UniformMyF32) -//! } -//! fn sample(&self, rng: &mut R) -> Self::X { -//! MyF32(self.0.sample(rng)) -//! } -//! } -//! -//! impl SampleUniform for MyF32 { -//! type Sampler = UniformMyF32; -//! } -//! -//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); -//! let uniform = Uniform::new(low, high).unwrap(); -//! let x = uniform.sample(&mut thread_rng()); -//! ``` -//! -//! [`SampleUniform`]: crate::distributions::uniform::SampleUniform -//! [`UniformSampler`]: crate::distributions::uniform::UniformSampler -//! [`UniformInt`]: crate::distributions::uniform::UniformInt -//! [`UniformFloat`]: crate::distributions::uniform::UniformFloat -//! [`UniformDuration`]: crate::distributions::uniform::UniformDuration -//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow - -use core::fmt; -use core::time::Duration; -use core::ops::{Range, RangeInclusive}; -use core::convert::TryFrom; - -use crate::distributions::float::IntoFloat; -use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply}; -use crate::distributions::Distribution; -#[cfg(feature = "simd_support")] -use crate::distributions::Standard; -use crate::{Rng, RngCore}; - -#[cfg(feature = "simd_support")] use core::simd::prelude::*; -#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount}; - -/// Error type returned from [`Uniform::new`] and `new_inclusive`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `low > high`, or equal in case of exclusive range. - EmptyRange, - /// Input or range `high - low` is non-finite. Not relevant to integer types. - NonFinite, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::EmptyRange => "low > high (or equal if exclusive) in uniform distribution", - Error::NonFinite => "Non-finite range in uniform distribution", - }) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -impl std::error::Error for Error {} - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// Sample values uniformly between two bounds. -/// -/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform -/// distribution sampling from the given range; these functions may do extra -/// work up front to make sampling of multiple values faster. If only one sample -/// from the range is required, [`Rng::gen_range`] can be more efficient. -/// -/// When sampling from a constant range, many calculations can happen at -/// compile-time and all methods should be fast; for floating-point ranges and -/// the full range of integer types, this should have comparable performance to -/// the `Standard` distribution. -/// -/// Steps are taken to avoid bias, which might be present in naive -/// implementations; for example `rng.gen::() % 170` samples from the range -/// `[0, 169]` but is twice as likely to select numbers less than 85 than other -/// values. Further, the implementations here give more weight to the high-bits -/// generated by the RNG than the low bits, since with some RNGs the low-bits -/// are of lower quality than the high bits. -/// -/// Implementations must sample in `[low, high)` range for -/// `Uniform::new(low, high)`, i.e., excluding `high`. In particular, care must -/// be taken to ensure that rounding never results values `< low` or `>= high`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Distribution, Uniform}; -/// -/// let between = Uniform::try_from(10..10000).unwrap(); -/// let mut rng = rand::thread_rng(); -/// let mut sum = 0; -/// for _ in 0..1000 { -/// sum += between.sample(&mut rng); -/// } -/// println!("{}", sum); -/// ``` -/// -/// For a single sample, [`Rng::gen_range`] may be preferred: -/// -/// ``` -/// use rand::Rng; -/// -/// let mut rng = rand::thread_rng(); -/// println!("{}", rng.gen_range(0..10)); -/// ``` -/// -/// [`new`]: Uniform::new -/// [`new_inclusive`]: Uniform::new_inclusive -/// [`Rng::gen_range`]: Rng::gen_range -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] -#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] -pub struct Uniform(X::Sampler); - -impl Uniform { - /// Create a new `Uniform` instance, which samples uniformly from the half - /// open range `[low, high)` (excluding `high`). - /// - /// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is - /// non-finite. In release mode, only the range is checked. - pub fn new(low: B1, high: B2) -> Result, Error> - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - X::Sampler::new(low, high).map(Uniform) - } - - /// Create a new `Uniform` instance, which samples uniformly from the closed - /// range `[low, high]` (inclusive). - /// - /// Fails if `low > high`, or if `low`, `high` or the range `high - low` is - /// non-finite. In release mode, only the range is checked. - pub fn new_inclusive(low: B1, high: B2) -> Result, Error> - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - X::Sampler::new_inclusive(low, high).map(Uniform) - } -} - -impl Distribution for Uniform { - fn sample(&self, rng: &mut R) -> X { - self.0.sample(rng) - } -} - -/// Helper trait for creating objects using the correct implementation of -/// [`UniformSampler`] for the sampling type. -/// -/// See the [module documentation] on how to implement [`Uniform`] range -/// sampling for a custom type. -/// -/// [module documentation]: crate::distributions::uniform -pub trait SampleUniform: Sized { - /// The `UniformSampler` implementation supporting type `X`. - type Sampler: UniformSampler; -} - -/// Helper trait handling actual uniform sampling. -/// -/// See the [module documentation] on how to implement [`Uniform`] range -/// sampling for a custom type. -/// -/// Implementation of [`sample_single`] is optional, and is only useful when -/// the implementation can be faster than `Self::new(low, high).sample(rng)`. -/// -/// [module documentation]: crate::distributions::uniform -/// [`sample_single`]: UniformSampler::sample_single -pub trait UniformSampler: Sized { - /// The type sampled by this implementation. - type X; - - /// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`. - /// - /// Usually users should not call this directly but prefer to use - /// [`Uniform::new`]. - fn new(low: B1, high: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized; - - /// Construct self, with inclusive bounds `[low, high]`. - /// - /// Usually users should not call this directly but prefer to use - /// [`Uniform::new_inclusive`]. - fn new_inclusive(low: B1, high: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized; - - /// Sample a value. - fn sample(&self, rng: &mut R) -> Self::X; - - /// Sample a single value uniformly from a range with inclusive lower bound - /// and exclusive upper bound `[low, high)`. - /// - /// By default this is implemented using - /// `UniformSampler::new(low, high).sample(rng)`. However, for some types - /// more optimal implementations for single usage may be provided via this - /// method (which is the case for integers and floats). - /// Results may not be identical. - /// - /// Note that to use this method in a generic context, the type needs to be - /// retrieved via `SampleUniform::Sampler` as follows: - /// ``` - /// use rand::{thread_rng, distributions::uniform::{SampleUniform, UniformSampler}}; - /// # #[allow(unused)] - /// fn sample_from_range(lb: T, ub: T) -> T { - /// let mut rng = thread_rng(); - /// ::Sampler::sample_single(lb, ub, &mut rng).unwrap() - /// } - /// ``` - fn sample_single(low: B1, high: B2, rng: &mut R) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let uniform: Self = UniformSampler::new(low, high)?; - Ok(uniform.sample(rng)) - } - - /// Sample a single value uniformly from a range with inclusive lower bound - /// and inclusive upper bound `[low, high]`. - /// - /// By default this is implemented using - /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for - /// some types more optimal implementations for single usage may be provided - /// via this method. - /// Results may not be identical. - fn sample_single_inclusive(low: B1, high: B2, rng: &mut R) - -> Result - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let uniform: Self = UniformSampler::new_inclusive(low, high)?; - Ok(uniform.sample(rng)) - } -} - -impl TryFrom> for Uniform { - type Error = Error; - - fn try_from(r: ::core::ops::Range) -> Result, Error> { - Uniform::new(r.start, r.end) - } -} - -impl TryFrom> for Uniform { - type Error = Error; - - fn try_from(r: ::core::ops::RangeInclusive) -> Result, Error> { - Uniform::new_inclusive(r.start(), r.end()) - } -} - - -/// Helper trait similar to [`Borrow`] but implemented -/// only for SampleUniform and references to SampleUniform in -/// order to resolve ambiguity issues. -/// -/// [`Borrow`]: std::borrow::Borrow -pub trait SampleBorrow { - /// Immutably borrows from an owned value. See [`Borrow::borrow`] - /// - /// [`Borrow::borrow`]: std::borrow::Borrow::borrow - fn borrow(&self) -> &Borrowed; -} -impl SampleBorrow for Borrowed -where Borrowed: SampleUniform -{ - #[inline(always)] - fn borrow(&self) -> &Borrowed { - self - } -} -impl<'a, Borrowed> SampleBorrow for &'a Borrowed -where Borrowed: SampleUniform -{ - #[inline(always)] - fn borrow(&self) -> &Borrowed { - self - } -} - -/// Range that supports generating a single sample efficiently. -/// -/// Any type implementing this trait can be used to specify the sampled range -/// for `Rng::gen_range`. -pub trait SampleRange { - /// Generate a sample from the given range. - fn sample_single(self, rng: &mut R) -> Result; - - /// Check whether the range is empty. - fn is_empty(&self) -> bool; -} - -impl SampleRange for Range { - #[inline] - fn sample_single(self, rng: &mut R) -> Result { - T::Sampler::sample_single(self.start, self.end, rng) - } - - #[inline] - fn is_empty(&self) -> bool { - !(self.start < self.end) - } -} - -impl SampleRange for RangeInclusive { - #[inline] - fn sample_single(self, rng: &mut R) -> Result { - T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) - } - - #[inline] - fn is_empty(&self) -> bool { - !(self.start() <= self.end()) - } -} - - -//////////////////////////////////////////////////////////////////////////////// - -// What follows are all back-ends. - - -/// The back-end implementing [`UniformSampler`] for integer types. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// # Implementation notes -/// -/// For simplicity, we use the same generic struct `UniformInt` for all -/// integer types `X`. This gives us only one field type, `X`; to store unsigned -/// values of this size, we take use the fact that these conversions are no-ops. -/// -/// For a closed range, the number of possible numbers we should generate is -/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of -/// our sample space, `zone`, is a multiple of `range`; other values must be -/// rejected (by replacing with a new random sample). -/// -/// As a special case, we use `range = 0` to represent the full range of the -/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). -/// -/// The optimum `zone` is the largest product of `range` which fits in our -/// (unsigned) target type. We calculate this by calculating how many numbers we -/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) -/// product of `range` will suffice, thus in `sample_single` we multiply by a -/// power of 2 via bit-shifting (faster but may cause more rejections). -/// -/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we -/// use `u32` for our `zone` and samples (because it's not slower and because -/// it reduces the chance of having to reject a sample). In this case we cannot -/// store `zone` in the target type since it is too large, however we know -/// `ints_to_reject < range <= $uty::MAX`. -/// -/// An alternative to using a modulus is widening multiply: After a widening -/// multiply by `range`, the result is in the high word. Then comparing the low -/// word against `zone` makes sure our distribution is uniform. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformInt { - low: X, - range: X, - thresh: X, // effectively 2.pow(max(64, uty_bits)) % range -} - -macro_rules! uniform_int_impl { - ($ty:ty, $uty:ty, $sample_ty:ident) => { - impl SampleUniform for $ty { - type Sampler = UniformInt<$ty>; - } - - impl UniformSampler for UniformInt<$ty> { - // We play free and fast with unsigned vs signed here - // (when $ty is signed), but that's fine, since the - // contract of this macro is for $ty and $uty to be - // "bit-equal", so casting between them is a no-op. - - type X = $ty; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low < high) { - return Err(Error::EmptyRange); - } - UniformSampler::new_inclusive(low, high - 1) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low <= high) { - return Err(Error::EmptyRange); - } - - let range = high.wrapping_sub(low).wrapping_add(1) as $uty; - let thresh = if range > 0 { - let range = $sample_ty::from(range); - (range.wrapping_neg() % range) - } else { - 0 - }; - - Ok(UniformInt { - low, - range: range as $ty, // type: $uty - thresh: thresh as $uty as $ty, // type: $sample_ty - }) - } - - /// Sample from distribution, Lemire's method, unbiased - #[inline] - fn sample(&self, rng: &mut R) -> Self::X { - let range = self.range as $uty as $sample_ty; - if range == 0 { - return rng.gen(); - } - - let thresh = self.thresh as $uty as $sample_ty; - let hi = loop { - let (hi, lo) = rng.gen::<$sample_ty>().wmul(range); - if lo >= thresh { - break hi; - } - }; - self.low.wrapping_add(hi as $ty) - } - - #[inline] - fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low < high) { - return Err(Error::EmptyRange); - } - Self::sample_single_inclusive(low, high - 1, rng) - } - - /// Sample single value, Canon's method, biased - /// - /// In the worst case, bias affects 1 in `2^n` samples where n is - /// 56 (`i8`), 48 (`i16`), 96 (`i32`), 64 (`i64`), 128 (`i128`). - #[cfg(not(feature = "unbiased"))] - #[inline] - fn sample_single_inclusive( - low_b: B1, high_b: B2, rng: &mut R, - ) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low <= high) { - return Err(Error::EmptyRange); - } - let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; - if range == 0 { - // Range is MAX+1 (unrepresentable), so we need a special case - return Ok(rng.gen()); - } - - // generate a sample using a sensible integer type - let (mut result, lo_order) = rng.gen::<$sample_ty>().wmul(range); - - // if the sample is biased... - if lo_order > range.wrapping_neg() { - // ...generate a new sample to reduce bias... - let (new_hi_order, _) = (rng.gen::<$sample_ty>()).wmul(range as $sample_ty); - // ... incrementing result on overflow - let is_overflow = lo_order.checked_add(new_hi_order as $sample_ty).is_none(); - result += is_overflow as $sample_ty; - } - - Ok(low.wrapping_add(result as $ty)) - } - - /// Sample single value, Canon's method, unbiased - #[cfg(feature = "unbiased")] - #[inline] - fn sample_single_inclusive( - low_b: B1, high_b: B2, rng: &mut R, - ) -> Result - where - B1: SampleBorrow<$ty> + Sized, - B2: SampleBorrow<$ty> + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low <= high) { - return Err(Error::EmptyRange); - } - let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; - if range == 0 { - // Range is MAX+1 (unrepresentable), so we need a special case - return Ok(rng.gen()); - } - - let (mut result, mut lo) = rng.gen::<$sample_ty>().wmul(range); - - // In constrast to the biased sampler, we use a loop: - while lo > range.wrapping_neg() { - let (new_hi, new_lo) = (rng.gen::<$sample_ty>()).wmul(range); - match lo.checked_add(new_hi) { - Some(x) if x < $sample_ty::MAX => { - // Anything less than MAX: last term is 0 - break; - } - None => { - // Overflow: last term is 1 - result += 1; - break; - } - _ => { - // Unlikely case: must check next sample - lo = new_lo; - continue; - } - } - } - - Ok(low.wrapping_add(result as $ty)) - } - } - }; -} - -uniform_int_impl! { i8, u8, u32 } -uniform_int_impl! { i16, u16, u32 } -uniform_int_impl! { i32, u32, u32 } -uniform_int_impl! { i64, u64, u64 } -uniform_int_impl! { i128, u128, u128 } -uniform_int_impl! { isize, usize, usize } -uniform_int_impl! { u8, u8, u32 } -uniform_int_impl! { u16, u16, u32 } -uniform_int_impl! { u32, u32, u32 } -uniform_int_impl! { u64, u64, u64 } -uniform_int_impl! { usize, usize, usize } -uniform_int_impl! { u128, u128, u128 } - -#[cfg(feature = "simd_support")] -macro_rules! uniform_simd_int_impl { - ($ty:ident, $unsigned:ident) => { - // The "pick the largest zone that can fit in an `u32`" optimization - // is less useful here. Multiple lanes complicate things, we don't - // know the PRNG's minimal output size, and casting to a larger vector - // is generally a bad idea for SIMD performance. The user can still - // implement it manually. - impl SampleUniform for Simd<$ty, LANES> - where - LaneCount: SupportedLaneCount, - Simd<$unsigned, LANES>: - WideningMultiply, Simd<$unsigned, LANES>)>, - Standard: Distribution>, - { - type Sampler = UniformInt>; - } - - impl UniformSampler for UniformInt> - where - LaneCount: SupportedLaneCount, - Simd<$unsigned, LANES>: - WideningMultiply, Simd<$unsigned, LANES>)>, - Standard: Distribution>, - { - type X = Simd<$ty, LANES>; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Result - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low.simd_lt(high).all()) { - return Err(Error::EmptyRange); - } - UniformSampler::new_inclusive(low, high - Simd::splat(1)) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Result - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low.simd_le(high).all()) { - return Err(Error::EmptyRange); - } - let unsigned_max = Simd::splat(::core::$unsigned::MAX); - - // NOTE: all `Simd` operations are inherently wrapping, - // see https://doc.rust-lang.org/std/simd/struct.Simd.html - let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast(); - // `% 0` will panic at runtime. - let not_full_range = range.simd_gt(Simd::splat(0)); - // replacing 0 with `unsigned_max` allows a faster `select` - // with bitwise OR - let modulo = not_full_range.select(range, unsigned_max); - // wrapping addition - // TODO: replace with `range.wrapping_neg() % module` when Simd supports this. - let ints_to_reject = (Simd::splat(0) - range) % modulo; - // When `range` is 0, `lo` of `v.wmul(range)` will always be - // zero which means only one sample is needed. - - Ok(UniformInt { - low, - // These are really $unsigned values, but store as $ty: - range: range.cast(), - thresh: ints_to_reject.cast(), - }) - } - - fn sample(&self, rng: &mut R) -> Self::X { - let range: Simd<$unsigned, LANES> = self.range.cast(); - let thresh: Simd<$unsigned, LANES> = self.thresh.cast(); - - // This might seem very slow, generating a whole new - // SIMD vector for every sample rejection. For most uses - // though, the chance of rejection is small and provides good - // general performance. With multiple lanes, that chance is - // multiplied. To mitigate this, we replace only the lanes of - // the vector which fail, iteratively reducing the chance of - // rejection. The replacement method does however add a little - // overhead. Benchmarking or calculating probabilities might - // reveal contexts where this replacement method is slower. - let mut v: Simd<$unsigned, LANES> = rng.gen(); - loop { - let (hi, lo) = v.wmul(range); - let mask = lo.simd_ge(thresh); - if mask.all() { - let hi: Simd<$ty, LANES> = hi.cast(); - // wrapping addition - let result = self.low + hi; - // `select` here compiles to a blend operation - // When `range.eq(0).none()` the compare and blend - // operations are avoided. - let v: Simd<$ty, LANES> = v.cast(); - return range.simd_gt(Simd::splat(0)).select(result, v); - } - // Replace only the failing lanes - v = mask.select(v, rng.gen()); - } - } - } - }; - - // bulk implementation - ($(($unsigned:ident, $signed:ident)),+) => { - $( - uniform_simd_int_impl!($unsigned, $unsigned); - uniform_simd_int_impl!($signed, $unsigned); - )+ - }; -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) } - -impl SampleUniform for char { - type Sampler = UniformChar; -} - -/// The back-end implementing [`UniformSampler`] for `char`. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// This differs from integer range sampling since the range `0xD800..=0xDFFF` -/// are used for surrogate pairs in UCS and UTF-16, and consequently are not -/// valid Unicode code points. We must therefore avoid sampling values in this -/// range. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformChar { - sampler: UniformInt, -} - -/// UTF-16 surrogate range start -const CHAR_SURROGATE_START: u32 = 0xD800; -/// UTF-16 surrogate range size -const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; - -/// Convert `char` to compressed `u32` -fn char_to_comp_u32(c: char) -> u32 { - match c as u32 { - c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, - c => c, - } -} - -impl UniformSampler for UniformChar { - type X = char; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = char_to_comp_u32(*low_b.borrow()); - let high = char_to_comp_u32(*high_b.borrow()); - let sampler = UniformInt::::new(low, high); - sampler.map(|sampler| UniformChar { sampler }) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = char_to_comp_u32(*low_b.borrow()); - let high = char_to_comp_u32(*high_b.borrow()); - let sampler = UniformInt::::new_inclusive(low, high); - sampler.map(|sampler| UniformChar { sampler }) - } - - fn sample(&self, rng: &mut R) -> Self::X { - let mut x = self.sampler.sample(rng); - if x >= CHAR_SURROGATE_START { - x += CHAR_SURROGATE_LEN; - } - // SAFETY: x must not be in surrogate range or greater than char::MAX. - // This relies on range constructors which accept char arguments. - // Validity of input char values is assumed. - unsafe { core::char::from_u32_unchecked(x) } - } -} - -/// Note: the `String` is potentially left with excess capacity if the range -/// includes non ascii chars; optionally the user may call -/// `string.shrink_to_fit()` afterwards. -#[cfg(feature = "alloc")] -impl super::DistString for Uniform{ - fn append_string(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) { - // Getting the hi value to assume the required length to reserve in string. - let mut hi = self.0.sampler.low + self.0.sampler.range - 1; - if hi >= CHAR_SURROGATE_START { - hi += CHAR_SURROGATE_LEN; - } - // Get the utf8 length of hi to minimize extra space. - let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4); - string.reserve(max_char_len * len); - string.extend(self.sample_iter(rng).take(len)) - } -} - -/// The back-end implementing [`UniformSampler`] for floating-point types. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// # Implementation notes -/// -/// Instead of generating a float in the `[0, 1)` range using [`Standard`], the -/// `UniformFloat` implementation converts the output of an PRNG itself. This -/// way one or two steps can be optimized out. -/// -/// The floats are first converted to a value in the `[1, 2)` interval using a -/// transmute-based method, and then mapped to the expected range with a -/// multiply and addition. Values produced this way have what equals 23 bits of -/// random digits for an `f32`, and 52 for an `f64`. -/// -/// [`new`]: UniformSampler::new -/// [`new_inclusive`]: UniformSampler::new_inclusive -/// [`Standard`]: crate::distributions::Standard -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformFloat { - low: X, - scale: X, -} - -macro_rules! uniform_float_impl { - ($ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { - impl SampleUniform for $ty { - type Sampler = UniformFloat<$ty>; - } - - impl UniformSampler for UniformFloat<$ty> { - type X = $ty; - - fn new(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - #[cfg(debug_assertions)] - if !(low.all_finite()) || !(high.all_finite()) { - return Err(Error::NonFinite); - } - if !(low.all_lt(high)) { - return Err(Error::EmptyRange); - } - let max_rand = <$ty>::splat( - (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - - let mut scale = high - low; - if !(scale.all_finite()) { - return Err(Error::NonFinite); - } - - loop { - let mask = (scale * max_rand + low).ge_mask(high); - if !mask.any() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - Ok(UniformFloat { low, scale }) - } - - fn new_inclusive(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - #[cfg(debug_assertions)] - if !(low.all_finite()) || !(high.all_finite()) { - return Err(Error::NonFinite); - } - if !low.all_le(high) { - return Err(Error::EmptyRange); - } - let max_rand = <$ty>::splat( - (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - - let mut scale = (high - low) / max_rand; - if !scale.all_finite() { - return Err(Error::NonFinite); - } - - loop { - let mask = (scale * max_rand + low).gt_mask(high); - if !mask.any() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - Ok(UniformFloat { low, scale }) - } - - fn sample(&self, rng: &mut R) -> Self::X { - // Generate a value in the range [1, 2) - let value1_2 = (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); - - // Get a value in the range [0, 1) to avoid overflow when multiplying by scale - let value0_1 = value1_2 - <$ty>::splat(1.0); - - // We don't use `f64::mul_add`, because it is not available with - // `no_std`. Furthermore, it is slower for some targets (but - // faster for others). However, the order of multiplication and - // addition is important, because on some platforms (e.g. ARM) - // it will be optimized to a single (non-FMA) instruction. - value0_1 * self.scale + self.low - } - - #[inline] - fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - #[cfg(debug_assertions)] - if !low.all_finite() || !high.all_finite() { - return Err(Error::NonFinite); - } - if !low.all_lt(high) { - return Err(Error::EmptyRange); - } - let mut scale = high - low; - if !scale.all_finite() { - return Err(Error::NonFinite); - } - - loop { - // Generate a value in the range [1, 2) - let value1_2 = - (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); - - // Get a value in the range [0, 1) to avoid overflow when multiplying by scale - let value0_1 = value1_2 - <$ty>::splat(1.0); - - // Doing multiply before addition allows some architectures - // to use a single instruction. - let res = value0_1 * scale + low; - - debug_assert!(low.all_le(res) || !scale.all_finite()); - if res.all_lt(high) { - return Ok(res); - } - - // This handles a number of edge cases. - // * `low` or `high` is NaN. In this case `scale` and - // `res` are going to end up as NaN. - // * `low` is negative infinity and `high` is finite. - // `scale` is going to be infinite and `res` will be - // NaN. - // * `high` is positive infinity and `low` is finite. - // `scale` is going to be infinite and `res` will - // be infinite or NaN (if value0_1 is 0). - // * `low` is negative infinity and `high` is positive - // infinity. `scale` will be infinite and `res` will - // be NaN. - // * `low` and `high` are finite, but `high - low` - // overflows to infinite. `scale` will be infinite - // and `res` will be infinite or NaN (if value0_1 is 0). - // So if `high` or `low` are non-finite, we are guaranteed - // to fail the `res < high` check above and end up here. - // - // While we technically should check for non-finite `low` - // and `high` before entering the loop, by doing the checks - // here instead, we allow the common case to avoid these - // checks. But we are still guaranteed that if `low` or - // `high` are non-finite we'll end up here and can do the - // appropriate checks. - // - // Likewise, `high - low` overflowing to infinity is also - // rare, so handle it here after the common case. - let mask = !scale.finite_mask(); - if mask.any() { - if !(low.all_finite() && high.all_finite()) { - return Err(Error::NonFinite); - } - scale = scale.decrease_masked(mask); - } - } - } - - #[inline] - fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - #[cfg(debug_assertions)] - if !low.all_finite() || !high.all_finite() { - return Err(Error::NonFinite); - } - if !low.all_le(high) { - return Err(Error::EmptyRange); - } - let scale = high - low; - if !scale.all_finite() { - return Err(Error::NonFinite); - } - - // Generate a value in the range [1, 2) - let value1_2 = - (rng.gen::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); - - // Get a value in the range [0, 1) to avoid overflow when multiplying by scale - let value0_1 = value1_2 - <$ty>::splat(1.0); - - // Doing multiply before addition allows some architectures - // to use a single instruction. - Ok(value0_1 * scale + low) - } - } - }; -} - -uniform_float_impl! { f32, u32, f32, u32, 32 - 23 } -uniform_float_impl! { f64, u64, f64, u64, 64 - 52 } - -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x2, u32x2, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x4, u32x4, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x8, u32x8, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x16, u32x16, f32, u32, 32 - 23 } - -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x2, u64x2, f64, u64, 64 - 52 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x4, u64x4, f64, u64, 64 - 52 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } - - -/// The back-end implementing [`UniformSampler`] for `Duration`. -/// -/// Unless you are implementing [`UniformSampler`] for your own types, this type -/// should not be used directly, use [`Uniform`] instead. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformDuration { - mode: UniformDurationMode, - offset: u32, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum UniformDurationMode { - Small { - secs: u64, - nanos: Uniform, - }, - Medium { - nanos: Uniform, - }, - Large { - max_secs: u64, - max_nanos: u32, - secs: Uniform, - }, -} - -impl SampleUniform for Duration { - type Sampler = UniformDuration; -} - -impl UniformSampler for UniformDuration { - type X = Duration; - - #[inline] - fn new(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low < high) { - return Err(Error::EmptyRange); - } - UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) - } - - #[inline] - fn new_inclusive(low_b: B1, high_b: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - if !(low <= high) { - return Err(Error::EmptyRange); - } - - let low_s = low.as_secs(); - let low_n = low.subsec_nanos(); - let mut high_s = high.as_secs(); - let mut high_n = high.subsec_nanos(); - - if high_n < low_n { - high_s -= 1; - high_n += 1_000_000_000; - } - - let mode = if low_s == high_s { - UniformDurationMode::Small { - secs: low_s, - nanos: Uniform::new_inclusive(low_n, high_n)?, - } - } else { - let max = high_s - .checked_mul(1_000_000_000) - .and_then(|n| n.checked_add(u64::from(high_n))); - - if let Some(higher_bound) = max { - let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); - UniformDurationMode::Medium { - nanos: Uniform::new_inclusive(lower_bound, higher_bound)?, - } - } else { - // An offset is applied to simplify generation of nanoseconds - let max_nanos = high_n - low_n; - UniformDurationMode::Large { - max_secs: high_s, - max_nanos, - secs: Uniform::new_inclusive(low_s, high_s)?, - } - } - }; - Ok(UniformDuration { - mode, - offset: low_n, - }) - } - - #[inline] - fn sample(&self, rng: &mut R) -> Duration { - match self.mode { - UniformDurationMode::Small { secs, nanos } => { - let n = nanos.sample(rng); - Duration::new(secs, n) - } - UniformDurationMode::Medium { nanos } => { - let nanos = nanos.sample(rng); - Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) - } - UniformDurationMode::Large { - max_secs, - max_nanos, - secs, - } => { - // constant folding means this is at least as fast as `Rng::sample(Range)` - let nano_range = Uniform::new(0, 1_000_000_000).unwrap(); - loop { - let s = secs.sample(rng); - let n = nano_range.sample(rng); - if !(s == max_secs && n > max_nanos) { - let sum = n + self.offset; - break Duration::new(s, sum); - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::rngs::mock::StepRng; - use crate::distributions::utils::FloatSIMDScalarUtils; - - #[test] - #[cfg(feature = "serde1")] - fn test_serialization_uniform_duration() { - let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap(); - let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); - assert_eq!(distr, de_distr); - } - - #[test] - #[cfg(feature = "serde1")] - fn test_uniform_serialization() { - let unit_box: Uniform = Uniform::new(-1, 1).unwrap(); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - assert_eq!(unit_box.0, de_unit_box.0); - - let unit_box: Uniform = Uniform::new(-1., 1.).unwrap(); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - assert_eq!(unit_box.0, de_unit_box.0); - } - - #[test] - fn test_uniform_bad_limits_equal_int() { - assert_eq!(Uniform::new(10, 10), Err(Error::EmptyRange)); - } - - #[test] - fn test_uniform_good_limits_equal_int() { - let mut rng = crate::test::rng(804); - let dist = Uniform::new_inclusive(10, 10).unwrap(); - for _ in 0..20 { - assert_eq!(rng.sample(dist), 10); - } - } - - #[test] - fn test_uniform_bad_limits_flipped_int() { - assert_eq!(Uniform::new(10, 5), Err(Error::EmptyRange)); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_integers() { - use core::{i128, u128}; - use core::{i16, i32, i64, i8, isize}; - use core::{u16, u32, u64, u8, usize}; - - let mut rng = crate::test::rng(251); - macro_rules! t { - ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ - for &(low, high) in $v.iter() { - let my_uniform = Uniform::new(low, high).unwrap(); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $lt(v, high)); - } - - let my_uniform = Uniform::new_inclusive(low, high).unwrap(); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $le(v, high)); - } - - let my_uniform = Uniform::new(&low, high).unwrap(); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $lt(v, high)); - } - - let my_uniform = Uniform::new_inclusive(&low, &high).unwrap(); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $le(v, high)); - } - - for _ in 0..1000 { - let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng).unwrap(); - assert!($le(low, v) && $lt(v, high)); - } - - for _ in 0..1000 { - let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng).unwrap(); - assert!($le(low, v) && $le(v, high)); - } - } - }}; - - // scalar bulk - ($($ty:ident),*) => {{ - $(t!( - $ty, - [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], - |x, y| x <= y, - |x, y| x < y - );)* - }}; - - // simd bulk - ($($ty:ident),* => $scalar:ident) => {{ - $(t!( - $ty, - [ - ($ty::splat(0), $ty::splat(10)), - ($ty::splat(10), $ty::splat(127)), - ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), - ], - |x: $ty, y| x.simd_le(y).all(), - |x: $ty, y| x.simd_lt(y).all() - );)* - }}; - } - t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, i128, u128); - - #[cfg(feature = "simd_support")] - { - t!(u8x4, u8x8, u8x16, u8x32, u8x64 => u8); - t!(i8x4, i8x8, i8x16, i8x32, i8x64 => i8); - t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); - t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); - t!(u32x2, u32x4, u32x8, u32x16 => u32); - t!(i32x2, i32x4, i32x8, i32x16 => i32); - t!(u64x2, u64x4, u64x8 => u64); - t!(i64x2, i64x4, i64x8 => i64); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_char() { - let mut rng = crate::test::rng(891); - let mut max = core::char::from_u32(0).unwrap(); - for _ in 0..100 { - let c = rng.gen_range('A'..='Z'); - assert!(('A'..='Z').contains(&c)); - max = max.max(c); - } - assert_eq!(max, 'Z'); - let d = Uniform::new( - core::char::from_u32(0xD7F0).unwrap(), - core::char::from_u32(0xE010).unwrap(), - ).unwrap(); - for _ in 0..100 { - let c = d.sample(&mut rng); - assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); - } - #[cfg(feature = "alloc")] - { - use crate::distributions::DistString; - let string1 = d.sample_string(&mut rng, 100); - assert_eq!(string1.capacity(), 300); - let string2 = Uniform::new( - core::char::from_u32(0x0000).unwrap(), - core::char::from_u32(0x0080).unwrap(), - ).unwrap().sample_string(&mut rng, 100); - assert_eq!(string2.capacity(), 100); - let string3 = Uniform::new_inclusive( - core::char::from_u32(0x0000).unwrap(), - core::char::from_u32(0x0080).unwrap(), - ).unwrap().sample_string(&mut rng, 100); - assert_eq!(string3.capacity(), 200); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_floats() { - let mut rng = crate::test::rng(252); - let mut zero_rng = StepRng::new(0, 0); - let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); - macro_rules! t { - ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ - let v: &[($f_scalar, $f_scalar)] = &[ - (0.0, 100.0), - (-1e35, -1e25), - (1e-35, 1e-25), - (-1e35, 1e35), - (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), - (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), - (-<$f_scalar>::from_bits(5), 0.0), - (-<$f_scalar>::from_bits(7), -0.0), - (0.1 * ::core::$f_scalar::MAX, ::core::$f_scalar::MAX), - (-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7), - ]; - for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::LEN { - let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); - let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - let my_uniform = Uniform::new(low, high).unwrap(); - let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); - for _ in 0..100 { - let v = rng.sample(my_uniform).extract(lane); - assert!(low_scalar <= v && v < high_scalar); - let v = rng.sample(my_incl_uniform).extract(lane); - assert!(low_scalar <= v && v <= high_scalar); - let v = <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut rng).unwrap().extract(lane); - assert!(low_scalar <= v && v < high_scalar); - let v = <$ty as SampleUniform>::Sampler - ::sample_single_inclusive(low, high, &mut rng).unwrap().extract(lane); - assert!(low_scalar <= v && v <= high_scalar); - } - - assert_eq!( - rng.sample(Uniform::new_inclusive(low, low).unwrap()).extract(lane), - low_scalar - ); - - assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); - assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); - assert_eq!(<$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut zero_rng).unwrap() - .extract(lane), low_scalar); - assert_eq!(<$ty as SampleUniform>::Sampler - ::sample_single_inclusive(low, high, &mut zero_rng).unwrap() - .extract(lane), low_scalar); - - assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); - assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); - // sample_single cannot cope with max_rng: - // assert!(<$ty as SampleUniform>::Sampler - // ::sample_single(low, high, &mut max_rng).unwrap() - // .extract(lane) < high_scalar); - assert!(<$ty as SampleUniform>::Sampler - ::sample_single_inclusive(low, high, &mut max_rng).unwrap() - .extract(lane) <= high_scalar); - - // Don't run this test for really tiny differences between high and low - // since for those rounding might result in selecting high for a very - // long time. - if (high_scalar - low_scalar) > 0.0001 { - let mut lowering_max_rng = StepRng::new( - 0xffff_ffff_ffff_ffff, - (-1i64 << $bits_shifted) as u64, - ); - assert!( - <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut lowering_max_rng).unwrap() - .extract(lane) < high_scalar - ); - } - } - } - - assert_eq!( - rng.sample(Uniform::new_inclusive( - ::core::$f_scalar::MAX, - ::core::$f_scalar::MAX - ).unwrap()), - ::core::$f_scalar::MAX - ); - assert_eq!( - rng.sample(Uniform::new_inclusive( - -::core::$f_scalar::MAX, - -::core::$f_scalar::MAX - ).unwrap()), - -::core::$f_scalar::MAX - ); - }}; - } - - t!(f32, f32, 32 - 23); - t!(f64, f64, 64 - 52); - #[cfg(feature = "simd_support")] - { - t!(f32x2, f32, 32 - 23); - t!(f32x4, f32, 32 - 23); - t!(f32x8, f32, 32 - 23); - t!(f32x16, f32, 32 - 23); - t!(f64x2, f64, 64 - 52); - t!(f64x4, f64, 64 - 52); - t!(f64x8, f64, 64 - 52); - } - } - - #[test] - fn test_float_overflow() { - assert_eq!(Uniform::try_from(::core::f64::MIN..::core::f64::MAX), Err(Error::NonFinite)); - } - - #[test] - #[should_panic] - fn test_float_overflow_single() { - let mut rng = crate::test::rng(252); - rng.gen_range(::core::f64::MIN..::core::f64::MAX); - } - - #[test] - #[cfg(all(feature = "std", panic = "unwind"))] - fn test_float_assertions() { - use super::SampleUniform; - use std::panic::catch_unwind; - fn range(low: T, high: T) { - let mut rng = crate::test::rng(253); - T::Sampler::sample_single(low, high, &mut rng).unwrap(); - } - - macro_rules! t { - ($ty:ident, $f_scalar:ident) => {{ - let v: &[($f_scalar, $f_scalar)] = &[ - (::std::$f_scalar::NAN, 0.0), - (1.0, ::std::$f_scalar::NAN), - (::std::$f_scalar::NAN, ::std::$f_scalar::NAN), - (1.0, 0.5), - (::std::$f_scalar::MAX, -::std::$f_scalar::MAX), - (::std::$f_scalar::INFINITY, ::std::$f_scalar::INFINITY), - ( - ::std::$f_scalar::NEG_INFINITY, - ::std::$f_scalar::NEG_INFINITY, - ), - (::std::$f_scalar::NEG_INFINITY, 5.0), - (5.0, ::std::$f_scalar::INFINITY), - (::std::$f_scalar::NAN, ::std::$f_scalar::INFINITY), - (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::NAN), - (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY), - ]; - for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::LEN { - let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); - let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - assert!(catch_unwind(|| range(low, high)).is_err()); - assert!(Uniform::new(low, high).is_err()); - assert!(Uniform::new_inclusive(low, high).is_err()); - assert!(catch_unwind(|| range(low, low)).is_err()); - assert!(Uniform::new(low, low).is_err()); - } - } - }}; - } - - t!(f32, f32); - t!(f64, f64); - #[cfg(feature = "simd_support")] - { - t!(f32x2, f32); - t!(f32x4, f32); - t!(f32x8, f32); - t!(f32x16, f32); - t!(f64x2, f64); - t!(f64x4, f64); - t!(f64x8, f64); - } - } - - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_durations() { - let mut rng = crate::test::rng(253); - - let v = &[ - (Duration::new(10, 50000), Duration::new(100, 1234)), - (Duration::new(0, 100), Duration::new(1, 50)), - ( - Duration::new(0, 0), - Duration::new(u64::max_value(), 999_999_999), - ), - ]; - for &(low, high) in v.iter() { - let my_uniform = Uniform::new(low, high).unwrap(); - for _ in 0..1000 { - let v = rng.sample(my_uniform); - assert!(low <= v && v < high); - } - } - } - - #[test] - fn test_custom_uniform() { - use crate::distributions::uniform::{ - SampleBorrow, SampleUniform, UniformFloat, UniformSampler, - }; - #[derive(Clone, Copy, PartialEq, PartialOrd)] - struct MyF32 { - x: f32, - } - #[derive(Clone, Copy, Debug)] - struct UniformMyF32(UniformFloat); - impl UniformSampler for UniformMyF32 { - type X = MyF32; - - fn new(low: B1, high: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - UniformFloat::::new(low.borrow().x, high.borrow().x).map(UniformMyF32) - } - - fn new_inclusive(low: B1, high: B2) -> Result - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - UniformSampler::new(low, high) - } - - fn sample(&self, rng: &mut R) -> Self::X { - MyF32 { - x: self.0.sample(rng), - } - } - } - impl SampleUniform for MyF32 { - type Sampler = UniformMyF32; - } - - let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); - let uniform = Uniform::new(low, high).unwrap(); - let mut rng = crate::test::rng(804); - for _ in 0..100 { - let x: MyF32 = rng.sample(uniform); - assert!(low <= x && x < high); - } - } - - #[test] - fn test_uniform_from_std_range() { - let r = Uniform::try_from(2u32..7).unwrap(); - assert_eq!(r.0.low, 2); - assert_eq!(r.0.range, 5); - let r = Uniform::try_from(2.0f64..7.0).unwrap(); - assert_eq!(r.0.low, 2.0); - assert_eq!(r.0.scale, 5.0); - } - - #[test] - fn test_uniform_from_std_range_bad_limits() { - #![allow(clippy::reversed_empty_ranges)] - assert!(Uniform::try_from(100..10).is_err()); - assert!(Uniform::try_from(100..100).is_err()); - assert!(Uniform::try_from(100.0..10.0).is_err()); - assert!(Uniform::try_from(100.0..100.0).is_err()); - } - - #[test] - fn test_uniform_from_std_range_inclusive() { - let r = Uniform::try_from(2u32..=6).unwrap(); - assert_eq!(r.0.low, 2); - assert_eq!(r.0.range, 5); - let r = Uniform::try_from(2.0f64..=7.0).unwrap(); - assert_eq!(r.0.low, 2.0); - assert!(r.0.scale > 5.0); - assert!(r.0.scale < 5.0 + 1e-14); - } - - #[test] - fn test_uniform_from_std_range_inclusive_bad_limits() { - #![allow(clippy::reversed_empty_ranges)] - assert!(Uniform::try_from(100..=10).is_err()); - assert!(Uniform::try_from(100..=99).is_err()); - assert!(Uniform::try_from(100.0..=10.0).is_err()); - assert!(Uniform::try_from(100.0..=99.0).is_err()); - } - - #[test] - fn value_stability() { - fn test_samples( - lb: T, ub: T, expected_single: &[T], expected_multiple: &[T], - ) where Uniform: Distribution { - let mut rng = crate::test::rng(897); - let mut buf = [lb; 3]; - - for x in &mut buf { - *x = T::Sampler::sample_single(lb, ub, &mut rng).unwrap(); - } - assert_eq!(&buf, expected_single); - - let distr = Uniform::new(lb, ub).unwrap(); - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(&buf, expected_multiple); - } - - // We test on a sub-set of types; possibly we should do more. - // TODO: SIMD types - - test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]); - test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]); - - test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[ - 0.008194133, - 0.00398172, - 0.007428536, - ]); - test_samples( - -1e10f64, - 1e10f64, - &[-4673848682.871551, 6388267422.932352, 4857075081.198343], - &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], - ); - - test_samples( - Duration::new(2, 0), - Duration::new(4, 0), - &[ - Duration::new(2, 532615131), - Duration::new(3, 638826742), - Duration::new(3, 485707508), - ], - &[ - Duration::new(3, 117337521), - Duration::new(3, 191764285), - Duration::new(3, 236507617), - ], - ); - } - - #[test] - fn uniform_distributions_can_be_compared() { - assert_eq!(Uniform::new(1.0, 2.0).unwrap(), Uniform::new(1.0, 2.0).unwrap()); - - // To cover UniformInt - assert_eq!(Uniform::new(1_u32, 2_u32).unwrap(), Uniform::new(1_u32, 2_u32).unwrap()); - } -} diff --git a/src/lib.rs b/src/lib.rs index dc9e29d6277..9187c9cc16a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,25 +14,24 @@ //! //! # Quick Start //! -//! To get you started quickly, the easiest and highest-level way to get -//! a random value is to use [`random()`]; alternatively you can use -//! [`thread_rng()`]. The [`Rng`] trait provides a useful API on all RNGs, while -//! the [`distributions`] and [`seq`] modules provide further -//! functionality on top of RNGs. -//! //! ``` +//! // The prelude import enables methods we use below, specifically +//! // Rng::random, Rng::sample, SliceRandom::shuffle and IndexedRandom::choose. //! use rand::prelude::*; //! -//! if rand::random() { // generates a boolean -//! // Try printing a random unicode code point (probably a bad idea)! -//! println!("char: {}", rand::random::()); -//! } +//! // Get an RNG: +//! let mut rng = rand::rng(); //! -//! let mut rng = rand::thread_rng(); -//! let y: f64 = rng.gen(); // generates a float between 0 and 1 +//! // Try printing a random unicode code point (probably a bad idea)! +//! println!("char: '{}'", rng.random::()); +//! // Try printing a random alphanumeric value instead! +//! println!("alpha: '{}'", rng.sample(rand::distr::Alphanumeric) as char); //! +//! // Generate and shuffle a sequence: //! let mut nums: Vec = (1..100).collect(); //! nums.shuffle(&mut rng); +//! // And take a random pick (yes, we didn't need to shuffle first!): +//! let _ = nums.choose(&mut rng); //! ``` //! //! # The Book @@ -50,15 +49,22 @@ #![doc(test(attr(allow(unused_variables), deny(warnings))))] #![no_std] #![cfg_attr(feature = "simd_support", feature(portable_simd))] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr( + all(feature = "simd_support", target_feature = "avx512bw"), + feature(stdarch_x86_avx512) +)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![allow( clippy::float_cmp, clippy::neg_cmp_op_on_partial_ord, clippy::nonminimal_bool )] +#![deny(clippy::undocumented_unsafe_blocks)] -#[cfg(feature = "std")] extern crate std; -#[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(feature = "std")] +extern crate std; #[allow(unused)] macro_rules! trace { ($($x:tt)*) => ( @@ -91,30 +97,44 @@ macro_rules! error { ($($x:tt)*) => ( } ) } +// Re-export rand_core itself +pub use rand_core; + // Re-exports from rand_core -pub use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +pub use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; // Public modules -pub mod distributions; +pub mod distr; pub mod prelude; mod rng; pub mod rngs; pub mod seq; // Public exports -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] -pub use crate::rngs::thread::thread_rng; +#[cfg(feature = "thread_rng")] +pub use crate::rngs::thread::rng; + +/// Access the thread-local generator +/// +/// Use [`rand::rng()`](rng()) instead. +#[cfg(feature = "thread_rng")] +#[deprecated(since = "0.9.0", note = "Renamed to `rng`")] +#[inline] +pub fn thread_rng() -> crate::rngs::ThreadRng { + rng() +} + pub use rng::{Fill, Rng}; -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] -use crate::distributions::{Distribution, Standard}; +#[cfg(feature = "thread_rng")] +use crate::distr::{Distribution, StandardUniform}; -/// Generates a random value using the thread-local random number generator. +/// Generate a random value using the thread-local random number generator. /// -/// This function is simply a shortcut for `thread_rng().gen()`: +/// This function is shorthand for [rng()].[random()](Rng::random): /// /// - See [`ThreadRng`] for documentation of the generator and security -/// - See [`Standard`] for documentation of supported types and distributions +/// - See [`StandardUniform`] for documentation of supported types and distributions /// /// # Examples /// @@ -130,35 +150,151 @@ use crate::distributions::{Distribution, Standard}; /// } /// ``` /// -/// If you're calling `random()` in a loop, caching the generator as in the -/// following example can increase performance. +/// If you're calling `random()` repeatedly, consider using a local `rng` +/// handle to save an initialization-check on each usage: /// /// ``` -/// use rand::Rng; +/// use rand::Rng; // provides the `random` method +/// +/// let mut rng = rand::rng(); // a local handle to the generator /// /// let mut v = vec![1, 2, 3]; /// /// for x in v.iter_mut() { -/// *x = rand::random() +/// *x = rng.random(); /// } +/// ``` /// -/// // can be made faster by caching thread_rng +/// [`StandardUniform`]: distr::StandardUniform +/// [`ThreadRng`]: rngs::ThreadRng +#[cfg(feature = "thread_rng")] +#[inline] +pub fn random() -> T +where + StandardUniform: Distribution, +{ + rng().random() +} + +/// Return an iterator over [`random()`] variates /// -/// let mut rng = rand::thread_rng(); +/// This function is shorthand for +/// [rng()].[random_iter](Rng::random_iter)(). /// -/// for x in v.iter_mut() { -/// *x = rng.gen(); -/// } +/// # Example +/// +/// ``` +/// let v: Vec = rand::random_iter().take(5).collect(); +/// println!("{v:?}"); /// ``` +#[cfg(feature = "thread_rng")] +#[inline] +pub fn random_iter() -> distr::Iter +where + StandardUniform: Distribution, +{ + rng().random_iter() +} + +/// Generate a random value in the given range using the thread-local random number generator. /// -/// [`Standard`]: distributions::Standard -/// [`ThreadRng`]: rngs::ThreadRng -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))))] +/// This function is shorthand for +/// [rng()].[random_range](Rng::random_range)(range). +/// +/// # Example +/// +/// ``` +/// let y: f32 = rand::random_range(0.0..=1e9); +/// println!("{}", y); +/// +/// let words: Vec<&str> = "Mary had a little lamb".split(' ').collect(); +/// println!("{}", words[rand::random_range(..words.len())]); +/// ``` +/// Note that the first example can also be achieved (without `collect`'ing +/// to a `Vec`) using [`seq::IteratorRandom::choose`]. +#[cfg(feature = "thread_rng")] #[inline] -pub fn random() -> T -where Standard: Distribution { - thread_rng().gen() +pub fn random_range(range: R) -> T +where + T: distr::uniform::SampleUniform, + R: distr::uniform::SampleRange, +{ + rng().random_range(range) +} + +/// Return a bool with a probability `p` of being true. +/// +/// This function is shorthand for +/// [rng()].[random_bool](Rng::random_bool)(p). +/// +/// # Example +/// +/// ``` +/// println!("{}", rand::random_bool(1.0 / 3.0)); +/// ``` +/// +/// # Panics +/// +/// If `p < 0` or `p > 1`. +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn random_bool(p: f64) -> bool { + rng().random_bool(p) +} + +/// Return a bool with a probability of `numerator/denominator` of being +/// true. +/// +/// That is, `random_ratio(2, 3)` has chance of 2 in 3, or about 67%, of +/// returning true. If `numerator == denominator`, then the returned value +/// is guaranteed to be `true`. If `numerator == 0`, then the returned +/// value is guaranteed to be `false`. +/// +/// See also the [`Bernoulli`] distribution, which may be faster if +/// sampling from the same `numerator` and `denominator` repeatedly. +/// +/// This function is shorthand for +/// [rng()].[random_ratio](Rng::random_ratio)(numerator, denominator). +/// +/// # Panics +/// +/// If `denominator == 0` or `numerator > denominator`. +/// +/// # Example +/// +/// ``` +/// println!("{}", rand::random_ratio(2, 3)); +/// ``` +/// +/// [`Bernoulli`]: distr::Bernoulli +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn random_ratio(numerator: u32, denominator: u32) -> bool { + rng().random_ratio(numerator, denominator) +} + +/// Fill any type implementing [`Fill`] with random data +/// +/// This function is shorthand for +/// [rng()].[fill](Rng::fill)(dest). +/// +/// # Example +/// +/// ``` +/// let mut arr = [0i8; 20]; +/// rand::fill(&mut arr[..]); +/// ``` +/// +/// Note that you can instead use [`random()`] to generate an array of random +/// data, though this is slower for small elements (smaller than the RNG word +/// size). +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn fill(dest: &mut T) { + dest.fill(&mut rng()) } #[cfg(test)] @@ -173,18 +309,52 @@ mod test { rand_pcg::Pcg32::new(seed, INC) } + /// Construct a generator yielding a constant value + pub fn const_rng(x: u64) -> StepRng { + StepRng(x, 0) + } + + /// Construct a generator yielding an arithmetic sequence + pub fn step_rng(x: u64, increment: u64) -> StepRng { + StepRng(x, increment) + } + + #[derive(Clone)] + pub struct StepRng(u64, u64); + impl RngCore for StepRng { + fn next_u32(&mut self) -> u32 { + self.next_u64() as u32 + } + + fn next_u64(&mut self) -> u64 { + let res = self.0; + self.0 = self.0.wrapping_add(self.1); + res + } + + fn fill_bytes(&mut self, dst: &mut [u8]) { + rand_core::impls::fill_bytes_via_next(self, dst) + } + } + #[test] - #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] + #[cfg(feature = "thread_rng")] fn test_random() { - let _n: usize = random(); + let _n: u64 = random(); let _f: f32 = random(); - let _o: Option> = random(); #[allow(clippy::type_complexity)] let _many: ( (), - (usize, isize, Option<(u32, (bool,))>), + [(u32, bool); 3], (u8, i8, u16, i16, u32, i32, u64, i64), (f32, (f64, (f64,))), ) = random(); } + + #[test] + #[cfg(feature = "thread_rng")] + fn test_range() { + let _n: usize = random_range(42..=43); + let _f: f32 = random_range(42.0..43.0); + } } diff --git a/src/prelude.rs b/src/prelude.rs index 35fee3d73fd..b0f563ad5fc 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -14,22 +14,22 @@ //! //! ``` //! use rand::prelude::*; -//! # let mut r = StdRng::from_rng(thread_rng()).unwrap(); -//! # let _: f32 = r.gen(); +//! # let mut r = StdRng::from_rng(&mut rand::rng()); +//! # let _: f32 = r.random(); //! ``` -#[doc(no_inline)] pub use crate::distributions::Distribution; +#[doc(no_inline)] +pub use crate::distr::Distribution; #[cfg(feature = "small_rng")] #[doc(no_inline)] pub use crate::rngs::SmallRng; #[cfg(feature = "std_rng")] -#[doc(no_inline)] pub use crate::rngs::StdRng; #[doc(no_inline)] -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] +pub use crate::rngs::StdRng; +#[doc(no_inline)] +#[cfg(feature = "thread_rng")] pub use crate::rngs::ThreadRng; #[doc(no_inline)] pub use crate::seq::{IndexedMutRandom, IndexedRandom, IteratorRandom, SliceRandom}; #[doc(no_inline)] -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] -pub use crate::{random, thread_rng}; -#[doc(no_inline)] pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; +pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; diff --git a/src/rng.rs b/src/rng.rs index 206275f8d76..c502e1ba476 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -9,16 +9,20 @@ //! [`Rng`] trait -use rand_core::{Error, RngCore}; -use crate::distributions::uniform::{SampleRange, SampleUniform}; -use crate::distributions::{self, Distribution, Standard}; +use crate::distr::uniform::{SampleRange, SampleUniform}; +use crate::distr::{self, Distribution, StandardUniform}; use core::num::Wrapping; use core::{mem, slice}; +use rand_core::RngCore; -/// An automatically-implemented extension trait on [`RngCore`] providing high-level -/// generic methods for sampling values and other convenience methods. +/// User-level interface for RNGs /// -/// This is the primary trait to use when generating random values. +/// [`RngCore`] is the `dyn`-safe implementation-level interface for Random +/// (Number) Generators. This trait, `Rng`, provides a user-level interface on +/// RNGs. It is implemented automatically for any `R: RngCore`. +/// +/// This trait must usually be brought into scope via `use rand::Rng;` or +/// `use rand::prelude::*;`. /// /// # Generic usage /// @@ -33,64 +37,92 @@ use core::{mem, slice}; /// /// An alternative pattern is possible: `fn foo(rng: R)`. This has some /// trade-offs. It allows the argument to be consumed directly without a `&mut` -/// (which is how `from_rng(thread_rng())` works); also it still works directly +/// (which is how `from_rng(rand::rng())` works); also it still works directly /// on references (including type-erased references). Unfortunately within the /// function `foo` it is not known whether `rng` is a reference type or not, /// hence many uses of `rng` require an extra reference, either explicitly -/// (`distr.sample(&mut rng)`) or implicitly (`rng.gen()`); one may hope the +/// (`distr.sample(&mut rng)`) or implicitly (`rng.random()`); one may hope the /// optimiser can remove redundant references later. /// /// Example: /// /// ``` -/// # use rand::thread_rng; /// use rand::Rng; /// /// fn foo(rng: &mut R) -> f32 { -/// rng.gen() +/// rng.random() /// } /// -/// # let v = foo(&mut thread_rng()); +/// # let v = foo(&mut rand::rng()); /// ``` pub trait Rng: RngCore { - /// Return a random value via the [`Standard`] distribution. + /// Return a random value via the [`StandardUniform`] distribution. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); - /// let x: u32 = rng.gen(); + /// let mut rng = rand::rng(); + /// let x: u32 = rng.random(); /// println!("{}", x); - /// println!("{:?}", rng.gen::<(f64, bool)>()); + /// println!("{:?}", rng.random::<(f64, bool)>()); /// ``` /// /// # Arrays and tuples /// - /// The `rng.gen()` method is able to generate arrays + /// The `rng.random()` method is able to generate arrays /// and tuples (up to 12 elements), so long as all element types can be /// generated. /// /// For arrays of integers, especially for those with small element types - /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`]. + /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`], + /// though note that generated values will differ. /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); - /// let tuple: (u8, i32, char) = rng.gen(); // arbitrary tuple support + /// let mut rng = rand::rng(); + /// let tuple: (u8, i32, char) = rng.random(); // arbitrary tuple support /// - /// let arr1: [f32; 32] = rng.gen(); // array construction + /// let arr1: [f32; 32] = rng.random(); // array construction /// let mut arr2 = [0u8; 128]; /// rng.fill(&mut arr2); // array fill /// ``` /// - /// [`Standard`]: distributions::Standard + /// [`StandardUniform`]: distr::StandardUniform #[inline] - fn gen(&mut self) -> T - where Standard: Distribution { - Standard.sample(self) + fn random(&mut self) -> T + where + StandardUniform: Distribution, + { + StandardUniform.sample(self) + } + + /// Return an iterator over [`random`](Self::random) variates + /// + /// This is a just a wrapper over [`Rng::sample_iter`] using + /// [`distr::StandardUniform`]. + /// + /// Note: this method consumes its argument. Use + /// `(&mut rng).random_iter()` to avoid consuming the RNG. + /// + /// # Example + /// + /// ``` + /// use rand::{rngs::SmallRng, Rng, SeedableRng}; + /// + /// let rng = SmallRng::seed_from_u64(0); + /// let v: Vec = rng.random_iter().take(5).collect(); + /// assert_eq!(v.len(), 5); + /// ``` + #[inline] + fn random_iter(self) -> distr::Iter + where + Self: Sized, + StandardUniform: Distribution, + { + StandardUniform.sample_iter(self) } /// Generate a random value in the given range. @@ -99,7 +131,8 @@ pub trait Rng: RngCore { /// made from the given range. See also the [`Uniform`] distribution /// type which may be faster if sampling from the same range repeatedly. /// - /// Only `gen_range(low..high)` and `gen_range(low..=high)` are supported. + /// All types support `low..high_exclusive` and `low..=high` range syntax. + /// Unsigned integer types also support `..high_exclusive` and `..=high` syntax. /// /// # Panics /// @@ -108,55 +141,95 @@ pub trait Rng: RngCore { /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Exclusive range - /// let n: u32 = rng.gen_range(0..10); + /// let n: u32 = rng.random_range(..10); /// println!("{}", n); - /// let m: f64 = rng.gen_range(-40.0..1.3e5); + /// let m: f64 = rng.random_range(-40.0..1.3e5); /// println!("{}", m); /// /// // Inclusive range - /// let n: u32 = rng.gen_range(0..=10); + /// let n: u32 = rng.random_range(..=10); /// println!("{}", n); /// ``` /// - /// [`Uniform`]: distributions::uniform::Uniform - fn gen_range(&mut self, range: R) -> T + /// [`Uniform`]: distr::uniform::Uniform + #[track_caller] + fn random_range(&mut self, range: R) -> T where T: SampleUniform, - R: SampleRange + R: SampleRange, { assert!(!range.is_empty(), "cannot sample empty range"); range.sample_single(self).unwrap() } - /// Generate values via an iterator + /// Return a bool with a probability `p` of being true. /// - /// This is a just a wrapper over [`Rng::sample_iter`] using - /// [`distributions::Standard`]. + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same probability repeatedly. /// - /// Note: this method consumes its argument. Use - /// `(&mut rng).gen_iter()` to avoid consuming the RNG. + /// # Example + /// + /// ``` + /// use rand::Rng; + /// + /// let mut rng = rand::rng(); + /// println!("{}", rng.random_bool(1.0 / 3.0)); + /// ``` + /// + /// # Panics + /// + /// If `p < 0` or `p > 1`. + /// + /// [`Bernoulli`]: distr::Bernoulli + #[inline] + #[track_caller] + fn random_bool(&mut self, p: f64) -> bool { + match distr::Bernoulli::new(p) { + Ok(d) => self.sample(d), + Err(_) => panic!("p={:?} is outside range [0.0, 1.0]", p), + } + } + + /// Return a bool with a probability of `numerator/denominator` of being + /// true. + /// + /// That is, `random_ratio(2, 3)` has chance of 2 in 3, or about 67%, of + /// returning true. If `numerator == denominator`, then the returned value + /// is guaranteed to be `true`. If `numerator == 0`, then the returned + /// value is guaranteed to be `false`. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same `numerator` and `denominator` repeatedly. + /// + /// # Panics + /// + /// If `denominator == 0` or `numerator > denominator`. /// /// # Example /// /// ``` - /// use rand::{rngs::mock::StepRng, Rng}; + /// use rand::Rng; /// - /// let rng = StepRng::new(1, 1); - /// let v: Vec = rng.gen_iter().take(5).collect(); - /// assert_eq!(&v, &[1, 2, 3, 4, 5]); + /// let mut rng = rand::rng(); + /// println!("{}", rng.random_ratio(2, 3)); /// ``` + /// + /// [`Bernoulli`]: distr::Bernoulli #[inline] - fn gen_iter(self) -> distributions::DistIter - where - Self: Sized, - Standard: Distribution, - { - Standard.sample_iter(self) + #[track_caller] + fn random_ratio(&mut self, numerator: u32, denominator: u32) -> bool { + match distr::Bernoulli::from_ratio(numerator, denominator) { + Ok(d) => self.sample(d), + Err(_) => panic!( + "p={}/{} is outside range [0.0, 1.0]", + numerator, denominator + ), + } } /// Sample a new value, using the given distribution. @@ -164,10 +237,10 @@ pub trait Rng: RngCore { /// ### Example /// /// ``` - /// use rand::{thread_rng, Rng}; - /// use rand::distributions::Uniform; + /// use rand::Rng; + /// use rand::distr::Uniform; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// let x = rng.sample(Uniform::new(10u32, 15).unwrap()); /// // Type annotation requires two types, the type and distribution; the /// // distribution can be inferred. @@ -185,13 +258,13 @@ pub trait Rng: RngCore { /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; - /// use rand::distributions::{Alphanumeric, Uniform, Standard}; + /// use rand::Rng; + /// use rand::distr::{Alphanumeric, Uniform, StandardUniform}; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Vec of 16 x f32: - /// let v: Vec = (&mut rng).sample_iter(Standard).take(16).collect(); + /// let v: Vec = (&mut rng).sample_iter(StandardUniform).take(16).collect(); /// /// // String: /// let s: String = (&mut rng).sample_iter(Alphanumeric) @@ -200,7 +273,7 @@ pub trait Rng: RngCore { /// .collect(); /// /// // Combined values - /// println!("{:?}", (&mut rng).sample_iter(Standard).take(5) + /// println!("{:?}", (&mut rng).sample_iter(StandardUniform).take(5) /// .collect::>()); /// /// // Dice-rolling: @@ -210,7 +283,7 @@ pub trait Rng: RngCore { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, distr: D) -> distributions::DistIter + fn sample_iter(self, distr: D) -> distr::Iter where D: Distribution, Self: Sized, @@ -220,106 +293,64 @@ pub trait Rng: RngCore { /// Fill any type implementing [`Fill`] with random data /// + /// This method is implemented for types which may be safely reinterpreted + /// as an (aligned) `[u8]` slice then filled with random data. It is often + /// faster than using [`Rng::random`] but not value-equivalent. + /// /// The distribution is expected to be uniform with portable results, but /// this cannot be guaranteed for third-party implementations. /// - /// This is identical to [`try_fill`] except that it panics on error. - /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// /// let mut arr = [0i8; 20]; - /// thread_rng().fill(&mut arr[..]); + /// rand::rng().fill(&mut arr[..]); /// ``` /// /// [`fill_bytes`]: RngCore::fill_bytes - /// [`try_fill`]: Rng::try_fill + #[track_caller] fn fill(&mut self, dest: &mut T) { - dest.try_fill(self).unwrap_or_else(|_| panic!("Rng::fill failed")) + dest.fill(self) } - /// Fill any type implementing [`Fill`] with random data - /// - /// The distribution is expected to be uniform with portable results, but - /// this cannot be guaranteed for third-party implementations. - /// - /// This is identical to [`fill`] except that it forwards errors. - /// - /// # Example - /// - /// ``` - /// # use rand::Error; - /// use rand::{thread_rng, Rng}; - /// - /// # fn try_inner() -> Result<(), Error> { - /// let mut arr = [0u64; 4]; - /// thread_rng().try_fill(&mut arr[..])?; - /// # Ok(()) - /// # } - /// - /// # try_inner().unwrap() - /// ``` - /// - /// [`try_fill_bytes`]: RngCore::try_fill_bytes - /// [`fill`]: Rng::fill - fn try_fill(&mut self, dest: &mut T) -> Result<(), Error> { - dest.try_fill(self) + /// Alias for [`Rng::random`]. + #[inline] + #[deprecated( + since = "0.9.0", + note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024." + )] + fn r#gen(&mut self) -> T + where + StandardUniform: Distribution, + { + self.random() } - /// Return a bool with a probability `p` of being true. - /// - /// See also the [`Bernoulli`] distribution, which may be faster if - /// sampling from the same probability repeatedly. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// println!("{}", rng.gen_bool(1.0 / 3.0)); - /// ``` - /// - /// # Panics - /// - /// If `p < 0` or `p > 1`. - /// - /// [`Bernoulli`]: distributions::Bernoulli + /// Alias for [`Rng::random_range`]. #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_range`")] + fn gen_range(&mut self, range: R) -> T + where + T: SampleUniform, + R: SampleRange, + { + self.random_range(range) + } + + /// Alias for [`Rng::random_bool`]. + #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_bool`")] fn gen_bool(&mut self, p: f64) -> bool { - let d = distributions::Bernoulli::new(p).unwrap(); - self.sample(d) + self.random_bool(p) } - /// Return a bool with a probability of `numerator/denominator` of being - /// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of - /// returning true. If `numerator == denominator`, then the returned value - /// is guaranteed to be `true`. If `numerator == 0`, then the returned - /// value is guaranteed to be `false`. - /// - /// See also the [`Bernoulli`] distribution, which may be faster if - /// sampling from the same `numerator` and `denominator` repeatedly. - /// - /// # Panics - /// - /// If `denominator == 0` or `numerator > denominator`. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// println!("{}", rng.gen_ratio(2, 3)); - /// ``` - /// - /// [`Bernoulli`]: distributions::Bernoulli + /// Alias for [`Rng::random_ratio`]. #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_ratio`")] fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool { - let d = distributions::Bernoulli::from_ratio(numerator, denominator).unwrap(); - self.sample(d) + self.random_ratio(numerator, denominator) } } @@ -334,18 +365,17 @@ impl Rng for R {} /// [Chapter on Portability](https://rust-random.github.io/book/portability.html)). pub trait Fill { /// Fill self with random data - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error>; + fn fill(&mut self, rng: &mut R); } macro_rules! impl_fill_each { () => {}; ($t:ty) => { impl Fill for [$t] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { for elt in self.iter_mut() { - *elt = rng.gen(); + *elt = rng.random(); } - Ok(()) } } }; @@ -358,79 +388,104 @@ macro_rules! impl_fill_each { impl_fill_each!(bool, char, f32, f64,); impl Fill for [u8] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - rng.try_fill_bytes(self) + fn fill(&mut self, rng: &mut R) { + rng.fill_bytes(self) } } +/// Call target for unsafe macros +const unsafe fn __unsafe() {} + +/// Implement `Fill` for given type `$t`. +/// +/// # Safety +/// All bit patterns of `[u8; size_of::<$t>()]` must represent values of `$t`. macro_rules! impl_fill { () => {}; - ($t:ty) => { + ($t:ty) => {{ + // Force caller to wrap with an `unsafe` block + __unsafe(); + impl Fill for [$t] { - #[inline(never)] // in micro benchmarks, this improves performance - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { - slice::from_raw_parts_mut(self.as_mut_ptr() - as *mut u8, - mem::size_of_val(self) - ) - })?; + let size = mem::size_of_val(self); + rng.fill_bytes( + // SAFETY: `self` non-null and valid for reads and writes within its `size` + // bytes. `self` meets the alignment requirements of `&mut [u8]`. + // The contents of `self` are initialized. Both `[u8]` and `[$t]` are valid + // for all bit-patterns of their contents (note that the SAFETY requirement + // on callers of this macro). `self` is not borrowed. + unsafe { + slice::from_raw_parts_mut(self.as_mut_ptr() + as *mut u8, + size + ) + } + ); for x in self { *x = x.to_le(); } } - Ok(()) } } impl Fill for [Wrapping<$t>] { - #[inline(never)] - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { - slice::from_raw_parts_mut(self.as_mut_ptr() - as *mut u8, - self.len() * mem::size_of::<$t>() - ) - })?; + let size = self.len() * mem::size_of::<$t>(); + rng.fill_bytes( + // SAFETY: `self` non-null and valid for reads and writes within its `size` + // bytes. `self` meets the alignment requirements of `&mut [u8]`. + // The contents of `self` are initialized. Both `[u8]` and `[$t]` are valid + // for all bit-patterns of their contents (note that the SAFETY requirement + // on callers of this macro). `self` is not borrowed. + unsafe { + slice::from_raw_parts_mut(self.as_mut_ptr() + as *mut u8, + size + ) + } + ); for x in self { - *x = Wrapping(x.0.to_le()); + *x = Wrapping(x.0.to_le()); } } - Ok(()) } - } + }} }; - ($t:ty, $($tt:ty,)*) => { + ($t:ty, $($tt:ty,)*) => {{ impl_fill!($t); // TODO: this could replace above impl once Rust #32463 is fixed // impl_fill!(Wrapping<$t>); impl_fill!($($tt,)*); - } + }} } -impl_fill!(u16, u32, u64, usize, u128,); -impl_fill!(i8, i16, i32, i64, isize, i128,); +// SAFETY: All bit patterns of `[u8; size_of::<$t>()]` represent values of `u*`. +const _: () = unsafe { impl_fill!(u16, u32, u64, u128,) }; +// SAFETY: All bit patterns of `[u8; size_of::<$t>()]` represent values of `i*`. +const _: () = unsafe { impl_fill!(i8, i16, i32, i64, i128,) }; impl Fill for [T; N] -where [T]: Fill +where + [T]: Fill, { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - self[..].try_fill(rng) + fn fill(&mut self, rng: &mut R) { + <[T] as Fill>::fill(self, rng) } } #[cfg(test)] mod test { use super::*; - use crate::test::rng; - use crate::rngs::mock::StepRng; - #[cfg(feature = "alloc")] use alloc::boxed::Box; + use crate::test::{const_rng, rng}; + #[cfg(feature = "alloc")] + use alloc::boxed::Box; #[test] fn test_fill_bytes_default() { - let mut r = StepRng::new(0x11_22_33_44_55_66_77_88, 0); + let mut r = const_rng(0x11_22_33_44_55_66_77_88); // check every remainder mod 8, both in small and big vectors. let lengths = [0, 1, 2, 3, 4, 5, 6, 7, 80, 81, 82, 83, 84, 85, 86, 87]; @@ -451,7 +506,7 @@ mod test { #[test] fn test_fill() { let x = 9041086907909331047; // a random u64 - let mut rng = StepRng::new(x, 0); + let mut rng = const_rng(x); // Convert to byte sequence and back to u64; byte-swap twice if BE. let mut array = [0u64; 2]; @@ -474,102 +529,111 @@ mod test { // Check equivalence for generated floats let mut array = [0f32; 2]; rng.fill(&mut array); - let gen: [f32; 2] = rng.gen(); - assert_eq!(array, gen); + let arr2: [f32; 2] = rng.random(); + assert_eq!(array, arr2); } #[test] fn test_fill_empty() { let mut array = [0u32; 0]; - let mut rng = StepRng::new(0, 1); + let mut rng = rng(1); rng.fill(&mut array); rng.fill(&mut array[..]); } #[test] - fn test_gen_range_int() { + fn test_random_range_int() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-4711..17); + let a = r.random_range(-4711..17); assert!((-4711..17).contains(&a)); - let a: i8 = r.gen_range(-3..42); + let a: i8 = r.random_range(-3..42); assert!((-3..42).contains(&a)); - let a: u16 = r.gen_range(10..99); + let a: u16 = r.random_range(10..99); assert!((10..99).contains(&a)); - let a: i32 = r.gen_range(-100..2000); + let a: i32 = r.random_range(-100..2000); assert!((-100..2000).contains(&a)); - let a: u32 = r.gen_range(12..=24); + let a: u32 = r.random_range(12..=24); assert!((12..=24).contains(&a)); - assert_eq!(r.gen_range(0u32..1), 0u32); - assert_eq!(r.gen_range(-12i64..-11), -12i64); - assert_eq!(r.gen_range(3_000_000..3_000_001), 3_000_000); + assert_eq!(r.random_range(..1u32), 0u32); + assert_eq!(r.random_range(-12i64..-11), -12i64); + assert_eq!(r.random_range(3_000_000..3_000_001), 3_000_000); } } #[test] - fn test_gen_range_float() { + fn test_random_range_float() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-4.5..1.7); + let a = r.random_range(-4.5..1.7); assert!((-4.5..1.7).contains(&a)); - let a = r.gen_range(-1.1..=-0.3); + let a = r.random_range(-1.1..=-0.3); assert!((-1.1..=-0.3).contains(&a)); - assert_eq!(r.gen_range(0.0f32..=0.0), 0.); - assert_eq!(r.gen_range(-11.0..=-11.0), -11.); - assert_eq!(r.gen_range(3_000_000.0..=3_000_000.0), 3_000_000.); + assert_eq!(r.random_range(0.0f32..=0.0), 0.); + assert_eq!(r.random_range(-11.0..=-11.0), -11.); + assert_eq!(r.random_range(3_000_000.0..=3_000_000.0), 3_000_000.); } } #[test] #[should_panic] - fn test_gen_range_panic_int() { - #![allow(clippy::reversed_empty_ranges)] + #[allow(clippy::reversed_empty_ranges)] + fn test_random_range_panic_int() { let mut r = rng(102); - r.gen_range(5..-2); + r.random_range(5..-2); } #[test] #[should_panic] - fn test_gen_range_panic_usize() { - #![allow(clippy::reversed_empty_ranges)] + #[allow(clippy::reversed_empty_ranges)] + fn test_random_range_panic_usize() { let mut r = rng(103); - r.gen_range(5..2); + r.random_range(5..2); } #[test] - fn test_gen_bool() { - #![allow(clippy::bool_assert_comparison)] - + #[allow(clippy::bool_assert_comparison)] + fn test_random_bool() { let mut r = rng(105); for _ in 0..5 { - assert_eq!(r.gen_bool(0.0), false); - assert_eq!(r.gen_bool(1.0), true); + assert_eq!(r.random_bool(0.0), false); + assert_eq!(r.random_bool(1.0), true); + } + } + + #[test] + fn test_rng_mut_ref() { + fn use_rng(mut r: impl Rng) { + let _ = r.next_u32(); } + + let mut rng = rng(109); + use_rng(&mut rng); } #[test] fn test_rng_trait_object() { - use crate::distributions::{Distribution, Standard}; + use crate::distr::{Distribution, StandardUniform}; let mut rng = rng(109); let mut r = &mut rng as &mut dyn RngCore; r.next_u32(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); - let _c: u8 = Standard.sample(&mut r); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + let _c: u8 = StandardUniform.sample(&mut r); } #[test] #[cfg(feature = "alloc")] fn test_rng_boxed_trait() { - use crate::distributions::{Distribution, Standard}; + use crate::distr::{Distribution, StandardUniform}; let rng = rng(110); let mut r = Box::new(rng) as Box; r.next_u32(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); - let _c: u8 = Standard.sample(&mut r); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + let _c: u8 = StandardUniform.sample(&mut r); } #[test] @@ -582,7 +646,7 @@ mod test { let mut sum: u32 = 0; let mut rng = rng(111); for _ in 0..N { - if rng.gen_ratio(NUM, DENOM) { + if rng.random_ratio(NUM, DENOM) { sum += 1; } } diff --git a/src/rngs/adapter/mod.rs b/src/rngs/adapter/mod.rs deleted file mode 100644 index bd1d2943233..00000000000 --- a/src/rngs/adapter/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Wrappers / adapters forming RNGs - -mod read; -mod reseeding; - -#[allow(deprecated)] -pub use self::read::{ReadError, ReadRng}; -pub use self::reseeding::ReseedingRng; diff --git a/src/rngs/adapter/read.rs b/src/rngs/adapter/read.rs deleted file mode 100644 index 25a9ca7fca4..00000000000 --- a/src/rngs/adapter/read.rs +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A wrapper around any Read to treat it as an RNG. - -#![allow(deprecated)] - -use std::fmt; -use std::io::Read; - -use rand_core::{impls, Error, RngCore}; - - -/// An RNG that reads random bytes straight from any type supporting -/// [`std::io::Read`], for example files. -/// -/// This will work best with an infinite reader, but that is not required. -/// -/// This can be used with `/dev/urandom` on Unix but it is recommended to use -/// [`OsRng`] instead. -/// -/// # Panics -/// -/// `ReadRng` uses [`std::io::Read::read_exact`], which retries on interrupts. -/// All other errors from the underlying reader, including when it does not -/// have enough data, will only be reported through [`try_fill_bytes`]. -/// The other [`RngCore`] methods will panic in case of an error. -/// -/// [`OsRng`]: crate::rngs::OsRng -/// [`try_fill_bytes`]: RngCore::try_fill_bytes -#[derive(Debug)] -#[deprecated(since="0.8.4", note="removal due to lack of usage")] -pub struct ReadRng { - reader: R, -} - -impl ReadRng { - /// Create a new `ReadRng` from a `Read`. - pub fn new(r: R) -> ReadRng { - ReadRng { reader: r } - } -} - -impl RngCore for ReadRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) - } - - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.try_fill_bytes(dest).unwrap_or_else(|err| { - panic!( - "reading random bytes from Read implementation failed; error: {}", - err - ) - }); - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - if dest.is_empty() { - return Ok(()); - } - // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. - self.reader - .read_exact(dest) - .map_err(|e| Error::new(ReadError(e))) - } -} - -/// `ReadRng` error type -#[derive(Debug)] -#[deprecated(since="0.8.4")] -pub struct ReadError(std::io::Error); - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ReadError: {}", self.0) - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.0) - } -} - - -#[cfg(test)] -mod test { - use std::println; - - use super::ReadRng; - use crate::RngCore; - - #[test] - fn test_reader_rng_u64() { - // transmute from the target to avoid endianness concerns. - #[rustfmt::skip] - let v = [0u8, 0, 0, 0, 0, 0, 0, 1, - 0, 4, 0, 0, 3, 0, 0, 2, - 5, 0, 0, 0, 0, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u64(), 1 << 56); - assert_eq!(rng.next_u64(), (2 << 56) + (3 << 32) + (4 << 8)); - assert_eq!(rng.next_u64(), 5); - } - - #[test] - fn test_reader_rng_u32() { - let v = [0u8, 0, 0, 1, 0, 0, 2, 0, 3, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u32(), 1 << 24); - assert_eq!(rng.next_u32(), 2 << 16); - assert_eq!(rng.next_u32(), 3); - } - - #[test] - fn test_reader_rng_fill_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 8]; - - let mut rng = ReadRng::new(&v[..]); - rng.fill_bytes(&mut w); - - assert!(v == w); - } - - #[test] - fn test_reader_rng_insufficient_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 9]; - - let mut rng = ReadRng::new(&v[..]); - - let result = rng.try_fill_bytes(&mut w); - assert!(result.is_err()); - println!("Error: {}", result.unwrap_err()); - } -} diff --git a/src/rngs/mock.rs b/src/rngs/mock.rs index 0d9e0f905c9..5b6a2253b18 100644 --- a/src/rngs/mock.rs +++ b/src/rngs/mock.rs @@ -8,10 +8,12 @@ //! Mock random number generator -use rand_core::{impls, Error, RngCore}; +#![allow(deprecated)] -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +use rand_core::{impls, RngCore}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A mock generator yielding very predictable output /// @@ -22,7 +24,7 @@ use serde::{Serialize, Deserialize}; /// Other integer types (64-bit and smaller) are produced via cast from `u64`. /// /// Other types are produced via their implementation of [`Rng`](crate::Rng) or -/// [`Distribution`](crate::distributions::Distribution). +/// [`Distribution`](crate::distr::Distribution). /// Output values may not be intuitive and may change in future releases but /// are considered /// [portable](https://rust-random.github.io/book/portability.html). @@ -31,15 +33,17 @@ use serde::{Serialize, Deserialize}; /// # Example /// /// ``` +/// # #![allow(deprecated)] /// use rand::Rng; /// use rand::rngs::mock::StepRng; /// /// let mut my_rng = StepRng::new(2, 1); -/// let sample: [u64; 3] = my_rng.gen(); +/// let sample: [u64; 3] = my_rng.random(); /// assert_eq!(sample, [2, 3, 4]); /// ``` #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[deprecated(since = "0.9.2", note = "Deprecated without replacement")] pub struct StepRng { v: u64, a: u64, @@ -64,48 +68,13 @@ impl RngCore for StepRng { #[inline] fn next_u64(&mut self) -> u64 { - let result = self.v; + let res = self.v; self.v = self.v.wrapping_add(self.a); - result - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - impls::fill_bytes_via_next(self, dest); + res } #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - #[cfg(any(feature = "alloc", feature = "serde1"))] - use super::StepRng; - - #[test] - #[cfg(feature = "serde1")] - fn test_serialization_step_rng() { - let some_rng = StepRng::new(42, 7); - let de_some_rng: StepRng = - bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap(); - assert_eq!(some_rng.v, de_some_rng.v); - assert_eq!(some_rng.a, de_some_rng.a); - - } - - #[test] - #[cfg(feature = "alloc")] - fn test_bool() { - use crate::{Rng, distributions::Standard}; - - // If this result ever changes, update doc on StepRng! - let rng = StepRng::new(0, 1 << 31); - let result: alloc::vec::Vec = - rng.sample_iter(Standard).take(6).collect(); - assert_eq!(&result, &[false, true, false, true, false, true]); + fn fill_bytes(&mut self, dst: &mut [u8]) { + impls::fill_bytes_via_next(self, dst) } } diff --git a/src/rngs/mod.rs b/src/rngs/mod.rs index 9013c57d10c..8ce25759a26 100644 --- a/src/rngs/mod.rs +++ b/src/rngs/mod.rs @@ -8,112 +8,103 @@ //! Random number generators and adapters //! -//! ## Background: Random number generators (RNGs) +//! This crate provides a small selection of non-[portable] generators. +//! See also [Types of generators] and [Our RNGs] in the book. //! -//! Computers cannot produce random numbers from nowhere. We classify -//! random number generators as follows: +//! ## Generators //! -//! - "True" random number generators (TRNGs) use hard-to-predict data sources -//! (e.g. the high-resolution parts of event timings and sensor jitter) to -//! harvest random bit-sequences, apply algorithms to remove bias and -//! estimate available entropy, then combine these bits into a byte-sequence -//! or an entropy pool. This job is usually done by the operating system or -//! a hardware generator (HRNG). -//! - "Pseudo"-random number generators (PRNGs) use algorithms to transform a -//! seed into a sequence of pseudo-random numbers. These generators can be -//! fast and produce well-distributed unpredictable random numbers (or not). -//! They are usually deterministic: given algorithm and seed, the output -//! sequence can be reproduced. They have finite period and eventually loop; -//! with many algorithms this period is fixed and can be proven sufficiently -//! long, while others are chaotic and the period depends on the seed. -//! - "Cryptographically secure" pseudo-random number generators (CSPRNGs) -//! are the sub-set of PRNGs which are secure. Security of the generator -//! relies both on hiding the internal state and using a strong algorithm. +//! This crate provides a small selection of non-[portable] random number generators: //! -//! ## Traits and functionality -//! -//! All RNGs implement the [`RngCore`] trait, as a consequence of which the -//! [`Rng`] extension trait is automatically implemented. Secure RNGs may -//! additionally implement the [`CryptoRng`] trait. -//! -//! All PRNGs require a seed to produce their random number sequence. The -//! [`SeedableRng`] trait provides three ways of constructing PRNGs: -//! -//! - `from_seed` accepts a type specific to the PRNG -//! - `from_rng` allows a PRNG to be seeded from any other RNG -//! - `seed_from_u64` allows any PRNG to be seeded from a `u64` insecurely -//! - `from_entropy` securely seeds a PRNG from fresh entropy -//! -//! Use the [`rand_core`] crate when implementing your own RNGs. -//! -//! ## Our generators -//! -//! This crate provides several random number generators: -//! -//! - [`OsRng`] is an interface to the operating system's random number -//! source. Typically the operating system uses a CSPRNG with entropy -//! provided by a TRNG and some type of on-going re-seeding. -//! - [`ThreadRng`], provided by the [`thread_rng`] function, is a handle to a -//! thread-local CSPRNG with periodic seeding from [`OsRng`]. Because this +//! - [`OsRng`] is a stateless interface over the operating system's random number +//! source. This is typically secure with some form of periodic re-seeding. +//! - [`ThreadRng`], provided by [`crate::rng()`], is a handle to a +//! thread-local generator with periodic seeding from [`OsRng`]. Because this //! is local, it is typically much faster than [`OsRng`]. It should be -//! secure, though the paranoid may prefer [`OsRng`]. +//! secure, but see documentation on [`ThreadRng`]. //! - [`StdRng`] is a CSPRNG chosen for good performance and trust of security //! (based on reviews, maturity and usage). The current algorithm is ChaCha12, //! which is well established and rigorously analysed. -//! [`StdRng`] provides the algorithm used by [`ThreadRng`] but without -//! periodic reseeding. -//! - [`SmallRng`] is an **insecure** PRNG designed to be fast, simple, require -//! little memory, and have good output quality. +//! [`StdRng`] is the deterministic generator used by [`ThreadRng`] but +//! without the periodic reseeding or thread-local management. +//! - [`SmallRng`] is a relatively simple, insecure generator designed to be +//! fast, use little memory, and pass various statistical tests of +//! randomness quality. //! //! The algorithms selected for [`StdRng`] and [`SmallRng`] may change in any -//! release and may be platform-dependent, therefore they should be considered -//! **not reproducible**. +//! release and may be platform-dependent, therefore they are not +//! [reproducible][portable]. //! -//! ## Additional generators +//! ### Additional generators //! -//! **TRNGs**: The [`rdrand`] crate provides an interface to the RDRAND and -//! RDSEED instructions available in modern Intel and AMD CPUs. -//! The [`rand_jitter`] crate provides a user-space implementation of -//! entropy harvesting from CPU timer jitter, but is very slow and has -//! [security issues](https://github.com/rust-random/rand/issues/699). +//! - The [`rdrand`] crate provides an interface to the RDRAND and RDSEED +//! instructions available in modern Intel and AMD CPUs. +//! - The [`rand_jitter`] crate provides a user-space implementation of +//! entropy harvesting from CPU timer jitter, but is very slow and has +//! [security issues](https://github.com/rust-random/rand/issues/699). +//! - The [`rand_chacha`] crate provides [portable] implementations of +//! generators derived from the [ChaCha] family of stream ciphers +//! - The [`rand_pcg`] crate provides [portable] implementations of a subset +//! of the [PCG] family of small, insecure generators +//! - The [`rand_xoshiro`] crate provides [portable] implementations of the +//! [xoshiro] family of small, insecure generators //! -//! **PRNGs**: Several companion crates are available, providing individual or -//! families of PRNG algorithms. These provide the implementations behind -//! [`StdRng`] and [`SmallRng`] but can also be used directly, indeed *should* -//! be used directly when **reproducibility** matters. -//! Some suggestions are: [`rand_chacha`], [`rand_pcg`], [`rand_xoshiro`]. -//! A full list can be found by searching for crates with the [`rng` tag]. +//! For more, search [crates with the `rng` tag]. //! +//! ## Traits and functionality +//! +//! All generators implement [`RngCore`] and thus also [`Rng`][crate::Rng]. +//! See also the [Random Values] chapter in the book. +//! +//! Secure RNGs may additionally implement the [`CryptoRng`] trait. +//! +//! Use the [`rand_core`] crate when implementing your own RNGs. +//! +//! [portable]: https://rust-random.github.io/book/crate-reprod.html +//! [Types of generators]: https://rust-random.github.io/book/guide-gen.html +//! [Our RNGs]: https://rust-random.github.io/book/guide-rngs.html +//! [Random Values]: https://rust-random.github.io/book/guide-values.html //! [`Rng`]: crate::Rng //! [`RngCore`]: crate::RngCore //! [`CryptoRng`]: crate::CryptoRng //! [`SeedableRng`]: crate::SeedableRng -//! [`thread_rng`]: crate::thread_rng //! [`rdrand`]: https://crates.io/crates/rdrand //! [`rand_jitter`]: https://crates.io/crates/rand_jitter //! [`rand_chacha`]: https://crates.io/crates/rand_chacha //! [`rand_pcg`]: https://crates.io/crates/rand_pcg //! [`rand_xoshiro`]: https://crates.io/crates/rand_xoshiro -//! [`rng` tag]: https://crates.io/keywords/rng +//! [crates with the `rng` tag]: https://crates.io/keywords/rng +//! [chacha]: https://cr.yp.to/chacha.html +//! [PCG]: https://www.pcg-random.org/ +//! [xoshiro]: https://prng.di.unimi.it/ -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -#[cfg(feature = "std")] pub mod adapter; +mod reseeding; +pub use reseeding::ReseedingRng; +#[deprecated(since = "0.9.2")] pub mod mock; // Public so we don't export `StepRng` directly, making it a bit // more clear it is intended for testing. +#[cfg(feature = "small_rng")] +mod small; +#[cfg(all( + feature = "small_rng", + any(target_pointer_width = "32", target_pointer_width = "16") +))] +mod xoshiro128plusplus; #[cfg(all(feature = "small_rng", target_pointer_width = "64"))] mod xoshiro256plusplus; -#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))] -mod xoshiro128plusplus; -#[cfg(feature = "small_rng")] mod small; -#[cfg(feature = "std_rng")] mod std; -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] pub(crate) mod thread; +#[cfg(feature = "std_rng")] +mod std; +#[cfg(feature = "thread_rng")] +pub(crate) mod thread; -#[cfg(feature = "small_rng")] pub use self::small::SmallRng; -#[cfg(feature = "std_rng")] pub use self::std::StdRng; -#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] pub use self::thread::ThreadRng; +#[cfg(feature = "small_rng")] +pub use self::small::SmallRng; +#[cfg(feature = "std_rng")] +pub use self::std::StdRng; +#[cfg(feature = "thread_rng")] +pub use self::thread::ThreadRng; -#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] -#[cfg(feature = "getrandom")] pub use rand_core::OsRng; +#[cfg(feature = "os_rng")] +pub use rand_core::OsRng; diff --git a/src/rngs/adapter/reseeding.rs b/src/rngs/reseeding.rs similarity index 56% rename from src/rngs/adapter/reseeding.rs rename to src/rngs/reseeding.rs index 39ddfd71047..69b9045e0de 100644 --- a/src/rngs/adapter/reseeding.rs +++ b/src/rngs/reseeding.rs @@ -13,7 +13,7 @@ use core::mem::size_of_val; use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; /// A wrapper around any PRNG that implements [`BlockRngCore`], that adds the /// ability to reseed it. @@ -22,10 +22,6 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// /// - On a manual call to [`reseed()`]. /// - After `clone()`, the clone will be reseeded on first use. -/// - When a process is forked on UNIX, the RNGs in both the parent and child -/// processes will be reseeded just before the next call to -/// [`BlockRngCore::generate`], i.e. "soon". For ChaCha and Hc128, this is a -/// maximum of 63 and 15, respectively, `u32` values before reseeding. /// - After the PRNG has generated a configurable number of random bytes. /// /// # When should reseeding after a fixed number of generated bytes be used? @@ -43,12 +39,6 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// Use [`ReseedingRng::new`] with a `threshold` of `0` to disable reseeding /// after a fixed number of generated bytes. /// -/// # Limitations -/// -/// It is recommended that a `ReseedingRng` (including `ThreadRng`) not be used -/// from a fork handler. -/// Use `OsRng` or `getrandom`, or defer your use of the RNG until later. -/// /// # Error handling /// /// Although unlikely, reseeding the wrapped PRNG can fail. `ReseedingRng` will @@ -67,15 +57,14 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// use rand_chacha::ChaCha20Core; // Internal part of ChaChaRng that /// // implements BlockRngCore /// use rand::rngs::OsRng; -/// use rand::rngs::adapter::ReseedingRng; +/// use rand::rngs::ReseedingRng; /// -/// let prng = ChaCha20Core::from_entropy(); -/// let mut reseeding_rng = ReseedingRng::new(prng, 0, OsRng); +/// let mut reseeding_rng = ReseedingRng::::new(0, OsRng).unwrap(); /// -/// println!("{}", reseeding_rng.gen::()); +/// println!("{}", reseeding_rng.random::()); /// /// let mut cloned_rng = reseeding_rng.clone(); -/// assert!(reseeding_rng.gen::() != cloned_rng.gen::()); +/// assert!(reseeding_rng.random::() != cloned_rng.random::()); /// ``` /// /// [`BlockRngCore`]: rand_core::block::BlockRngCore @@ -85,12 +74,12 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; pub struct ReseedingRng(BlockRng>) where R: BlockRngCore + SeedableRng, - Rsdr: RngCore; + Rsdr: TryRngCore; impl ReseedingRng where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { /// Create a new `ReseedingRng` from an existing PRNG, combined with a RNG /// to use as reseeder. @@ -98,21 +87,27 @@ where /// `threshold` sets the number of generated bytes after which to reseed the /// PRNG. Set it to zero to never reseed based on the number of generated /// values. - pub fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { - ReseedingRng(BlockRng::new(ReseedingCore::new(rng, threshold, reseeder))) + pub fn new(threshold: u64, reseeder: Rsdr) -> Result { + Ok(ReseedingRng(BlockRng::new(ReseedingCore::new( + threshold, reseeder, + )?))) } - /// Reseed the internal PRNG. - pub fn reseed(&mut self) -> Result<(), Error> { + /// Immediately reseed the generator + /// + /// This discards any remaining random data in the cache. + pub fn reseed(&mut self) -> Result<(), Rsdr::Error> { + self.0.reset(); self.0.core.reseed() } } // TODO: this should be implemented for any type where the inner type // implements RngCore, but we can't specify that because ReseedingCore is private -impl RngCore for ReseedingRng +impl RngCore for ReseedingRng where R: BlockRngCore + SeedableRng, + Rsdr: TryRngCore, { #[inline(always)] fn next_u32(&mut self) -> u32 { @@ -127,16 +122,12 @@ where fn fill_bytes(&mut self, dest: &mut [u8]) { self.0.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } } impl Clone for ReseedingRng where R: BlockRngCore + SeedableRng + Clone, - Rsdr: RngCore + Clone, + Rsdr: TryRngCore + Clone, { fn clone(&self) -> ReseedingRng { // Recreating `BlockRng` seems easier than cloning it and resetting @@ -148,7 +139,7 @@ where impl CryptoRng for ReseedingRng where R: BlockRngCore + SeedableRng + CryptoBlockRng, - Rsdr: CryptoRng, + Rsdr: TryCryptoRng, { } @@ -158,24 +149,22 @@ struct ReseedingCore { reseeder: Rsdr, threshold: i64, bytes_until_reseed: i64, - fork_counter: usize, } impl BlockRngCore for ReseedingCore where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { type Item = ::Item; type Results = ::Results; fn generate(&mut self, results: &mut Self::Results) { - let global_fork_counter = fork::get_fork_counter(); - if self.bytes_until_reseed <= 0 || self.is_forked(global_fork_counter) { + if self.bytes_until_reseed <= 0 { // We get better performance by not calling only `reseed` here // and continuing with the rest of the function, but by directly // returning from a non-inlined function. - return self.reseed_and_generate(results, global_fork_counter); + return self.reseed_and_generate(results); } let num_bytes = size_of_val(results.as_ref()); self.bytes_until_reseed -= num_bytes as i64; @@ -186,66 +175,46 @@ where impl ReseedingCore where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { /// Create a new `ReseedingCore`. - fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { - use ::core::i64::MAX; - fork::register_fork_handler(); - + /// + /// `threshold` is the maximum number of bytes produced by + /// [`BlockRngCore::generate`] before attempting reseeding. + fn new(threshold: u64, mut reseeder: Rsdr) -> Result { // Because generating more values than `i64::MAX` takes centuries on // current hardware, we just clamp to that value. // Also we set a threshold of 0, which indicates no limit, to that // value. let threshold = if threshold == 0 { - MAX - } else if threshold <= MAX as u64 { + i64::MAX + } else if threshold <= i64::MAX as u64 { threshold as i64 } else { - MAX + i64::MAX }; - ReseedingCore { - inner: rng, + let inner = R::try_from_rng(&mut reseeder)?; + + Ok(ReseedingCore { + inner, reseeder, threshold, bytes_until_reseed: threshold, - fork_counter: 0, - } + }) } /// Reseed the internal PRNG. - fn reseed(&mut self) -> Result<(), Error> { - R::from_rng(&mut self.reseeder).map(|result| { + fn reseed(&mut self) -> Result<(), Rsdr::Error> { + R::try_from_rng(&mut self.reseeder).map(|result| { self.bytes_until_reseed = self.threshold; self.inner = result }) } - fn is_forked(&self, global_fork_counter: usize) -> bool { - // In theory, on 32-bit platforms, it is possible for - // `global_fork_counter` to wrap around after ~4e9 forks. - // - // This check will detect a fork in the normal case where - // `fork_counter < global_fork_counter`, and also when the difference - // between both is greater than `isize::MAX` (wrapped around). - // - // It will still fail to detect a fork if there have been more than - // `isize::MAX` forks, without any reseed in between. Seems unlikely - // enough. - (self.fork_counter.wrapping_sub(global_fork_counter) as isize) < 0 - } - #[inline(never)] - fn reseed_and_generate( - &mut self, results: &mut ::Results, global_fork_counter: usize, - ) { - #![allow(clippy::if_same_then_else)] // false positive - if self.is_forked(global_fork_counter) { - info!("Fork detected, reseeding RNG"); - } else { - trace!("Reseeding RNG (periodic reseed)"); - } + fn reseed_and_generate(&mut self, results: &mut ::Results) { + trace!("Reseeding RNG (periodic reseed)"); let num_bytes = size_of_val(results.as_ref()); @@ -253,7 +222,6 @@ where warn!("Reseeding RNG failed: {}", e); let _ = e; } - self.fork_counter = global_fork_counter; self.bytes_until_reseed = self.threshold - num_bytes as i64; self.inner.generate(results); @@ -263,7 +231,7 @@ where impl Clone for ReseedingCore where R: BlockRngCore + SeedableRng + Clone, - Rsdr: RngCore + Clone, + Rsdr: TryRngCore + Clone, { fn clone(&self) -> ReseedingCore { ReseedingCore { @@ -271,7 +239,6 @@ where reseeder: self.reseeder.clone(), threshold: self.threshold, bytes_until_reseed: 0, // reseed clone on first use - fork_counter: self.fork_counter, } } } @@ -279,79 +246,24 @@ where impl CryptoBlockRng for ReseedingCore where R: BlockRngCore + SeedableRng + CryptoBlockRng, - Rsdr: CryptoRng, + Rsdr: TryCryptoRng, { } - -#[cfg(all(unix, not(target_os = "emscripten")))] -mod fork { - use core::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Once; - - // Fork protection - // - // We implement fork protection on Unix using `pthread_atfork`. - // When the process is forked, we increment `RESEEDING_RNG_FORK_COUNTER`. - // Every `ReseedingRng` stores the last known value of the static in - // `fork_counter`. If the cached `fork_counter` is less than - // `RESEEDING_RNG_FORK_COUNTER`, it is time to reseed this RNG. - // - // If reseeding fails, we don't deal with this by setting a delay, but just - // don't update `fork_counter`, so a reseed is attempted as soon as - // possible. - - static RESEEDING_RNG_FORK_COUNTER: AtomicUsize = AtomicUsize::new(0); - - pub fn get_fork_counter() -> usize { - RESEEDING_RNG_FORK_COUNTER.load(Ordering::Relaxed) - } - - extern "C" fn fork_handler() { - // Note: fetch_add is defined to wrap on overflow - // (which is what we want). - RESEEDING_RNG_FORK_COUNTER.fetch_add(1, Ordering::Relaxed); - } - - pub fn register_fork_handler() { - static REGISTER: Once = Once::new(); - REGISTER.call_once(|| { - // Bump the counter before and after forking (see #1169): - let ret = unsafe { libc::pthread_atfork( - Some(fork_handler), - Some(fork_handler), - Some(fork_handler), - ) }; - if ret != 0 { - panic!("libc::pthread_atfork failed with code {}", ret); - } - }); - } -} - -#[cfg(not(all(unix, not(target_os = "emscripten"))))] -mod fork { - pub fn get_fork_counter() -> usize { - 0 - } - pub fn register_fork_handler() {} -} - - #[cfg(feature = "std_rng")] #[cfg(test)] mod test { - use super::ReseedingRng; - use crate::rngs::mock::StepRng; use crate::rngs::std::Core; - use crate::{Rng, SeedableRng}; + use crate::test::const_rng; + use crate::Rng; + + use super::ReseedingRng; #[test] fn test_reseeding() { - let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); + let zero = const_rng(0); let thresh = 1; // reseed every time the buffer is exhausted - let mut reseeding = ReseedingRng::new(rng, thresh, zero); + let mut reseeding = ReseedingRng::::new(thresh, zero).unwrap(); // RNG buffer size is [u32; 64] // Debug is only implemented up to length 32 so use two arrays @@ -367,19 +279,17 @@ mod test { } #[test] + #[allow(clippy::redundant_clone)] fn test_clone_reseeding() { - #![allow(clippy::redundant_clone)] - - let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); - let mut rng1 = ReseedingRng::new(rng, 32 * 4, zero); + let zero = const_rng(0); + let mut rng1 = ReseedingRng::::new(32 * 4, zero).unwrap(); - let first: u32 = rng1.gen(); + let first: u32 = rng1.random(); for _ in 0..10 { - let _ = rng1.gen::(); + let _ = rng1.random::(); } let mut rng2 = rng1.clone(); - assert_eq!(first, rng2.gen::()); + assert_eq!(first, rng2.random::()); } } diff --git a/src/rngs/small.rs b/src/rngs/small.rs index 2841b0b5dd8..67e0d0544f4 100644 --- a/src/rngs/small.rs +++ b/src/rngs/small.rs @@ -8,109 +8,113 @@ //! A small fast RNG -use rand_core::{Error, RngCore, SeedableRng}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] +type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; #[cfg(target_pointer_width = "64")] type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus; -#[cfg(not(target_pointer_width = "64"))] -type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; -/// A small-state, fast non-crypto PRNG +/// A small-state, fast, non-crypto, non-portable PRNG /// -/// `SmallRng` may be a good choice when a PRNG with small state, cheap -/// initialization, good statistical quality and good performance are required. -/// Note that depending on the application, [`StdRng`] may be faster on many -/// modern platforms while providing higher-quality randomness. Furthermore, -/// `SmallRng` is **not** a good choice when: +/// This is the "standard small" RNG, a generator with the following properties: /// -/// - Portability is required. Its implementation is not fixed. Use a named -/// generator from an external crate instead, for example [rand_xoshiro] or -/// [rand_chacha]. Refer also to -/// [The Book](https://rust-random.github.io/book/guide-rngs.html). -/// - Security against prediction is important. Use [`StdRng`] instead. +/// - Non-[portable]: any future library version may replace the algorithm +/// and results may be platform-dependent. +/// (For a small portable generator, use the [rand_pcg] or [rand_xoshiro] crate.) +/// - Non-cryptographic: output is easy to predict (insecure) +/// - [Quality]: statistically good quality +/// - Fast: the RNG is fast for both bulk generation and single values, with +/// consistent cost of method calls +/// - Fast initialization +/// - Small state: little memory usage (current state size is 16-32 bytes +/// depending on platform) /// -/// The PRNG algorithm in `SmallRng` is chosen to be efficient on the current -/// platform, without consideration for cryptography or security. The size of -/// its state is much smaller than [`StdRng`]. The current algorithm is +/// The current algorithm is /// `Xoshiro256PlusPlus` on 64-bit platforms and `Xoshiro128PlusPlus` on 32-bit /// platforms. Both are also implemented by the [rand_xoshiro] crate. /// +/// ## Seeding (construction) +/// +/// This generator implements the [`SeedableRng`] trait. All methods are +/// suitable for seeding, but note that, even with a fixed seed, output is not +/// [portable]. Some suggestions: +/// +/// 1. To automatically seed with a unique seed, use [`SeedableRng::from_rng`]: +/// ``` +/// use rand::SeedableRng; +/// use rand::rngs::SmallRng; +/// let rng = SmallRng::from_rng(&mut rand::rng()); +/// # let _: SmallRng = rng; +/// ``` +/// or [`SeedableRng::from_os_rng`]: +/// ``` +/// # use rand::SeedableRng; +/// # use rand::rngs::SmallRng; +/// let rng = SmallRng::from_os_rng(); +/// # let _: SmallRng = rng; +/// ``` +/// 2. To use a deterministic integral seed, use `seed_from_u64`. This uses a +/// hash function internally to yield a (typically) good seed from any +/// input. +/// ``` +/// # use rand::{SeedableRng, rngs::SmallRng}; +/// let rng = SmallRng::seed_from_u64(1); +/// # let _: SmallRng = rng; +/// ``` +/// 3. To seed deterministically from text or other input, use [`rand_seeder`]. +/// +/// See also [Seeding RNGs] in the book. +/// +/// ## Generation +/// +/// The generators implements [`RngCore`] and thus also [`Rng`][crate::Rng]. +/// See also the [Random Values] chapter in the book. +/// +/// [portable]: https://rust-random.github.io/book/crate-reprod.html +/// [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +/// [Random Values]: https://rust-random.github.io/book/guide-values.html +/// [Quality]: https://rust-random.github.io/book/guide-rngs.html#quality /// [`StdRng`]: crate::rngs::StdRng -/// [rand_chacha]: https://crates.io/crates/rand_chacha +/// [rand_pcg]: https://crates.io/crates/rand_pcg /// [rand_xoshiro]: https://crates.io/crates/rand_xoshiro -#[cfg_attr(doc_cfg, doc(cfg(feature = "small_rng")))] +/// [`rand_chacha::ChaCha8Rng`]: https://docs.rs/rand_chacha/latest/rand_chacha/struct.ChaCha8Rng.html +/// [`rand_seeder`]: https://docs.rs/rand_seeder/latest/rand_seeder/ #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmallRng(Rng); -impl RngCore for SmallRng { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - self.0.next_u32() - } +impl SeedableRng for SmallRng { + // Fix to 256 bits. Changing this is a breaking change! + type Seed = [u8; 32]; #[inline(always)] - fn next_u64(&mut self) -> u64 { - self.0.next_u64() - } - - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); + fn from_seed(seed: Self::Seed) -> Self { + // This is for compatibility with 32-bit platforms where Rng::Seed has a different seed size + // With MSRV >= 1.77: let seed = *seed.first_chunk().unwrap() + const LEN: usize = core::mem::size_of::<::Seed>(); + let seed = (&seed[..LEN]).try_into().unwrap(); + SmallRng(Rng::from_seed(seed)) } #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) + fn seed_from_u64(state: u64) -> Self { + SmallRng(Rng::seed_from_u64(state)) } } -impl SmallRng { - /// Construct an instance seeded from another `Rng` - /// - /// We recommend that the source (master) RNG uses a different algorithm - /// (i.e. is not `SmallRng`) to avoid correlations between the child PRNGs. - /// - /// # Example - /// ``` - /// # use rand::rngs::SmallRng; - /// let rng = SmallRng::from_rng(rand::thread_rng()); - /// ``` +impl RngCore for SmallRng { #[inline(always)] - pub fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(SmallRng) + fn next_u32(&mut self) -> u32 { + self.0.next_u32() } - /// Construct an instance seeded from the thread-local RNG - /// - /// # Panics - /// - /// This method panics only if [`thread_rng`](crate::thread_rng) fails to - /// initialize. - #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] #[inline(always)] - pub fn from_thread_rng() -> Self { - let mut seed = ::Seed::default(); - crate::thread_rng().fill_bytes(seed.as_mut()); - SmallRng(Rng::from_seed(seed)) + fn next_u64(&mut self) -> u64 { + self.0.next_u64() } - /// Construct an instance from a `u64` seed - /// - /// This provides a convenient method of seeding a `SmallRng` from a simple - /// number by use of another algorithm to mutate and expand the input. - /// This is suitable for use with low Hamming Weight numbers like 0 and 1. - /// - /// **Warning:** the implementation is deterministic but not portable: - /// output values may differ according to platform and may be changed by a - /// future version of the library. - /// - /// # Example - /// ``` - /// # use rand::rngs::SmallRng; - /// let rng = SmallRng::seed_from_u64(1); - /// ``` #[inline(always)] - pub fn seed_from_u64(state: u64) -> Self { - SmallRng(Rng::seed_from_u64(state)) + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest) } } diff --git a/src/rngs/std.rs b/src/rngs/std.rs index 31b20a2dc5d..6e1658e7453 100644 --- a/src/rngs/std.rs +++ b/src/rngs/std.rs @@ -8,29 +8,64 @@ //! The standard RNG -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; -#[cfg(feature = "getrandom")] +#[cfg(any(test, feature = "os_rng"))] pub(crate) use rand_chacha::ChaCha12Core as Core; use rand_chacha::ChaCha12Rng as Rng; -/// The standard RNG. The PRNG algorithm in `StdRng` is chosen to be efficient -/// on the current platform, to be statistically strong and unpredictable -/// (meaning a cryptographically secure PRNG). +/// A strong, fast (amortized), non-portable RNG +/// +/// This is the "standard" RNG, a generator with the following properties: +/// +/// - Non-[portable]: any future library version may replace the algorithm +/// and results may be platform-dependent. +/// (For a portable version, use the [rand_chacha] crate directly.) +/// - [CSPRNG]: statistically good quality of randomness and [unpredictable] +/// - Fast ([amortized](https://en.wikipedia.org/wiki/Amortized_analysis)): +/// the RNG is fast for bulk generation, but the cost of method calls is not +/// consistent due to usage of an output buffer. /// /// The current algorithm used is the ChaCha block cipher with 12 rounds. Please -/// see this relevant [rand issue] for the discussion. This may change as new +/// see this relevant [rand issue] for the discussion. This may change as new /// evidence of cipher security and performance becomes available. /// -/// The algorithm is deterministic but should not be considered reproducible -/// due to dependence on configuration and possible replacement in future -/// library versions. For a secure reproducible generator, we recommend use of -/// the [rand_chacha] crate directly. +/// ## Seeding (construction) +/// +/// This generator implements the [`SeedableRng`] trait. Any method may be used, +/// but note that `seed_from_u64` is not suitable for usage where security is +/// important. Also note that, even with a fixed seed, output is not [portable]. +/// +/// Using a fresh seed **direct from the OS** is the most secure option: +/// ``` +/// # use rand::{SeedableRng, rngs::StdRng}; +/// let rng = StdRng::from_os_rng(); +/// # let _: StdRng = rng; +/// ``` +/// +/// Seeding via [`rand::rng()`](crate::rng()) may be faster: +/// ``` +/// # use rand::{SeedableRng, rngs::StdRng}; +/// let rng = StdRng::from_rng(&mut rand::rng()); +/// # let _: StdRng = rng; +/// ``` +/// +/// Any [`SeedableRng`] method may be used, but note that `seed_from_u64` is not +/// suitable where security is required. See also [Seeding RNGs] in the book. +/// +/// ## Generation +/// +/// The generators implements [`RngCore`] and thus also [`Rng`][crate::Rng]. +/// See also the [Random Values] chapter in the book. /// +/// [portable]: https://rust-random.github.io/book/crate-reprod.html +/// [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +/// [unpredictable]: https://rust-random.github.io/book/guide-rngs.html#security +/// [Random Values]: https://rust-random.github.io/book/guide-values.html +/// [CSPRNG]: https://rust-random.github.io/book/guide-gen.html#cryptographically-secure-pseudo-random-number-generator /// [rand_chacha]: https://crates.io/crates/rand_chacha /// [rand issue]: https://github.com/rust-random/rand/issues/932 -#[cfg_attr(doc_cfg, doc(cfg(feature = "std_rng")))] #[derive(Clone, Debug, PartialEq, Eq)] pub struct StdRng(Rng); @@ -46,13 +81,8 @@ impl RngCore for StdRng { } #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.fill_bytes(dst) } } @@ -64,16 +94,10 @@ impl SeedableRng for StdRng { fn from_seed(seed: Self::Seed) -> Self { StdRng(Rng::from_seed(seed)) } - - #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(StdRng) - } } impl CryptoRng for StdRng {} - #[cfg(test)] mod test { use crate::rngs::StdRng; @@ -92,7 +116,7 @@ mod test { let mut rng0 = StdRng::from_seed(seed); let x0 = rng0.next_u64(); - let mut rng1 = StdRng::from_rng(rng0).unwrap(); + let mut rng1 = StdRng::from_rng(&mut rng0); let x1 = rng1.next_u64(); assert_eq!([x0, x1], target); diff --git a/src/rngs/thread.rs b/src/rngs/thread.rs index 6c8d83c02eb..7e5203214a4 100644 --- a/src/rngs/thread.rs +++ b/src/rngs/thread.rs @@ -9,14 +9,15 @@ //! Thread-local random number generator use core::cell::UnsafeCell; +use std::fmt; use std::rc::Rc; use std::thread_local; -use std::fmt; + +use rand_core::{CryptoRng, RngCore}; use super::std::Core; -use crate::rngs::adapter::ReseedingRng; use crate::rngs::OsRng; -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use crate::rngs::ReseedingRng; // Rationale for using `UnsafeCell` in `ThreadRng`: // @@ -32,7 +33,6 @@ use crate::{CryptoRng, Error, RngCore, SeedableRng}; // `ThreadRng` internally, which is nonsensical anyway. We should also never run // `ThreadRng` in destructors of its implementation, which is also nonsensical. - // Number of generated bytes after which to reseed `ThreadRng`. // According to benchmarks, reseeding has a noticeable impact with thresholds // of 32 kB and less. We choose 64 kB to avoid significant overhead. @@ -41,35 +41,73 @@ const THREAD_RNG_RESEED_THRESHOLD: u64 = 1024 * 64; /// A reference to the thread-local generator /// /// This type is a reference to a lazily-initialized thread-local generator. -/// An instance can be obtained via [`thread_rng`] or via `ThreadRng::default()`. -/// This handle is safe to use everywhere (including thread-local destructors), -/// though it is recommended not to use inside a fork handler. +/// An instance can be obtained via [`rand::rng()`][crate::rng()] or via +/// [`ThreadRng::default()`]. /// The handle cannot be passed between threads (is not `Send` or `Sync`). /// -/// `ThreadRng` uses the same CSPRNG as [`StdRng`], ChaCha12. As with -/// [`StdRng`], the algorithm may be changed, subject to reasonable expectations -/// of security and performance. -/// -/// `ThreadRng` is automatically seeded from [`OsRng`] with periodic reseeding -/// (every 64 kiB, as well as "soon" after a fork on Unix — see [`ReseedingRng`] -/// documentation for details). +/// # Security /// /// Security must be considered relative to a threat model and validation -/// requirements. `ThreadRng` attempts to meet basic security considerations -/// for producing unpredictable random numbers: use a CSPRNG, use a -/// recommended platform-specific seed ([`OsRng`]), and avoid -/// leaking internal secrets e.g. via [`Debug`] implementation or serialization. -/// Memory is not zeroized on drop. +/// requirements. The Rand project can provide no guarantee of fitness for +/// purpose. The design criteria for `ThreadRng` are as follows: +/// +/// - Automatic seeding via [`OsRng`] and periodically thereafter (see +/// ([`ReseedingRng`] documentation). Limitation: there is no automatic +/// reseeding on process fork (see [below](#fork)). +/// - A rigorusly analyzed, unpredictable (cryptographic) pseudo-random generator +/// (see [the book on security](https://rust-random.github.io/book/guide-rngs.html#security)). +/// The currently selected algorithm is ChaCha (12-rounds). +/// See also [`StdRng`] documentation. +/// - Not to leak internal state through [`Debug`] or serialization +/// implementations. +/// - No further protections exist to in-memory state. In particular, the +/// implementation is not required to zero memory on exit (of the process or +/// thread). (This may change in the future.) +/// - Be fast enough for general-purpose usage. Note in particular that +/// `ThreadRng` is designed to be a "fast, reasonably secure generator" +/// (where "reasonably secure" implies the above criteria). +/// +/// We leave it to the user to determine whether this generator meets their +/// security requirements. For an alternative, see [`OsRng`]. +/// +/// # Fork /// -/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng +/// `ThreadRng` is not automatically reseeded on fork. It is recommended to +/// explicitly call [`ThreadRng::reseed`] immediately after a fork, for example: +/// ```ignore +/// fn do_fork() { +/// let pid = unsafe { libc::fork() }; +/// if pid == 0 { +/// // Reseed ThreadRng in child processes: +/// rand::rng().reseed(); +/// } +/// } +/// ``` +/// +/// Methods on `ThreadRng` are not reentrant-safe and thus should not be called +/// from an interrupt (e.g. a fork handler) unless it can be guaranteed that no +/// other method on the same `ThreadRng` is currently executing. +/// +/// [`ReseedingRng`]: crate::rngs::ReseedingRng /// [`StdRng`]: crate::rngs::StdRng -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))))] #[derive(Clone)] pub struct ThreadRng { // Rc is explicitly !Send and !Sync rng: Rc>>, } +impl ThreadRng { + /// Immediately reseed the generator + /// + /// This discards any remaining random data in the cache. + pub fn reseed(&mut self) -> Result<(), rand_core::OsError> { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.reseed() + } +} + /// Debug implementation does not leak internal state impl fmt::Debug for ThreadRng { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { @@ -78,44 +116,52 @@ impl fmt::Debug for ThreadRng { } thread_local!( - // We require Rc<..> to avoid premature freeing when thread_rng is used + // We require Rc<..> to avoid premature freeing when ThreadRng is used // within thread-local destructors. See #968. static THREAD_RNG_KEY: Rc>> = { - let r = Core::from_rng(OsRng).unwrap_or_else(|err| - panic!("could not initialize thread_rng: {}", err)); - let rng = ReseedingRng::new(r, - THREAD_RNG_RESEED_THRESHOLD, - OsRng); + let rng = ReseedingRng::new(THREAD_RNG_RESEED_THRESHOLD, + OsRng).unwrap_or_else(|err| + panic!("could not initialize ThreadRng: {}", err)); Rc::new(UnsafeCell::new(rng)) } ); -/// Access the thread-local generator +/// Access a fast, pre-initialized generator +/// +/// This is a handle to the local [`ThreadRng`]. +/// +/// See also [`crate::rngs`] for alternatives. /// -/// Returns a reference to the local [`ThreadRng`], initializing the generator -/// on the first call on each thread. +/// # Example /// -/// Example usage: /// ``` -/// use rand::Rng; +/// use rand::prelude::*; /// /// # fn main() { -/// // rand::random() may be used instead of rand::thread_rng().gen(): -/// println!("A random boolean: {}", rand::random::()); /// -/// let mut rng = rand::thread_rng(); -/// println!("A simulated die roll: {}", rng.gen_range(1..=6)); +/// let mut numbers = [1, 2, 3, 4, 5]; +/// numbers.shuffle(&mut rand::rng()); +/// println!("Numbers: {numbers:?}"); +/// +/// // Using a local binding avoids an initialization-check on each usage: +/// let mut rng = rand::rng(); +/// +/// println!("True or false: {}", rng.random::()); +/// println!("A simulated die roll: {}", rng.random_range(1..=6)); /// # } /// ``` -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))))] -pub fn thread_rng() -> ThreadRng { +/// +/// # Security +/// +/// Refer to [`ThreadRng#Security`]. +pub fn rng() -> ThreadRng { let rng = THREAD_RNG_KEY.with(|t| t.clone()); ThreadRng { rng } } impl Default for ThreadRng { fn default() -> ThreadRng { - thread_rng() + rng() } } @@ -136,38 +182,31 @@ impl RngCore for ThreadRng { rng.next_u64() } + #[inline(always)] fn fill_bytes(&mut self, dest: &mut [u8]) { // SAFETY: We must make sure to stop using `rng` before anyone else // creates another mutable reference let rng = unsafe { &mut *self.rng.get() }; rng.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - // SAFETY: We must make sure to stop using `rng` before anyone else - // creates another mutable reference - let rng = unsafe { &mut *self.rng.get() }; - rng.try_fill_bytes(dest) - } } impl CryptoRng for ThreadRng {} - #[cfg(test)] mod test { #[test] fn test_thread_rng() { use crate::Rng; - let mut r = crate::thread_rng(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); + let mut r = crate::rng(); + r.random::(); + assert_eq!(r.random_range(0..1), 0); } #[test] fn test_debug_output() { // We don't care about the exact output here, but it must not include // private CSPRNG state or the cache stored by BlockRng! - assert_eq!(std::format!("{:?}", crate::thread_rng()), "ThreadRng { .. }"); + assert_eq!(std::format!("{:?}", crate::rng()), "ThreadRng { .. }"); } } diff --git a/src/rngs/xoshiro128plusplus.rs b/src/rngs/xoshiro128plusplus.rs index ece98fafd6a..69fe7ca9202 100644 --- a/src/rngs/xoshiro128plusplus.rs +++ b/src/rngs/xoshiro128plusplus.rs @@ -6,10 +6,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; -use rand_core::impls::{next_u64_via_u32, fill_bytes_via_next}; +use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32}; use rand_core::le::read_u32_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A xoshiro128++ random number generator. /// @@ -20,7 +21,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro128plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Xoshiro128PlusPlus { s: [u32; 4], } @@ -32,36 +33,43 @@ impl SeedableRng for Xoshiro128PlusPlus { /// mapped to a different seed. #[inline] fn from_seed(seed: [u8; 16]) -> Xoshiro128PlusPlus { - if seed.iter().all(|&x| x == 0) { - return Self::seed_from_u64(0); - } let mut state = [0; 4]; read_u32_into(&seed, &mut state); + // Check for zero on aligned integers for better code generation. + // Furtermore, seed_from_u64(0) will expand to a constant when optimized. + if state.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } Xoshiro128PlusPlus { s: state } } /// Create a new `Xoshiro128PlusPlus` from a `u64` seed. /// /// This uses the SplitMix64 generator internally. + #[inline] fn seed_from_u64(mut state: u64) -> Self { const PHI: u64 = 0x9e3779b97f4a7c15; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(8) { + let mut s = [0; 4]; + for i in s.chunks_exact_mut(2) { state = state.wrapping_add(PHI); let mut z = state; z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); z = z ^ (z >> 31); - chunk.copy_from_slice(&z.to_le_bytes()); + i[0] = z as u32; + i[1] = (z >> 32) as u32; } - Self::from_seed(seed) + // By using a non-zero PHI we are guaranteed to generate a non-zero state + // Thus preventing a recursion between from_seed and seed_from_u64. + debug_assert_ne!(s, [0; 4]); + Xoshiro128PlusPlus { s } } } impl RngCore for Xoshiro128PlusPlus { #[inline] fn next_u32(&mut self) -> u32 { - let result_starstar = self.s[0] + let res = self.s[0] .wrapping_add(self.s[3]) .rotate_left(7) .wrapping_add(self.s[0]); @@ -77,7 +85,7 @@ impl RngCore for Xoshiro128PlusPlus { self.s[3] = self.s[3].rotate_left(11); - result_starstar + res } #[inline] @@ -86,30 +94,39 @@ impl RngCore for Xoshiro128PlusPlus { } #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_via_next(self, dest); - } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { - use super::*; + use super::Xoshiro128PlusPlus; + use rand_core::{RngCore, SeedableRng}; #[test] fn reference() { - let mut rng = Xoshiro128PlusPlus::from_seed( - [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); + let mut rng = + Xoshiro128PlusPlus::from_seed([1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro128plusplus.c let expected = [ - 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, - 3867114732, 1355841295, 495546011, 621204420, + 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, 3867114732, 1355841295, + 495546011, 621204420, + ]; + for &e in &expected { + assert_eq!(rng.next_u32(), e); + } + } + + #[test] + fn stable_seed_from_u64() { + // We don't guarantee value-stability for SmallRng but this + // could influence keeping stability whenever possible (e.g. after optimizations). + let mut rng = Xoshiro128PlusPlus::seed_from_u64(0); + let expected = [ + 1179900579, 1938959192, 3089844957, 3657088315, 1015453891, 479942911, 3433842246, + 669252886, 3985671746, 2737205563, ]; for &e in &expected { assert_eq!(rng.next_u32(), e); diff --git a/src/rngs/xoshiro256plusplus.rs b/src/rngs/xoshiro256plusplus.rs index 8ffb18b8033..7b39c6109a7 100644 --- a/src/rngs/xoshiro256plusplus.rs +++ b/src/rngs/xoshiro256plusplus.rs @@ -6,10 +6,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; use rand_core::impls::fill_bytes_via_next; use rand_core::le::read_u64_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A xoshiro256++ random number generator. /// @@ -20,7 +21,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro256plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Xoshiro256PlusPlus { s: [u64; 4], } @@ -32,29 +33,35 @@ impl SeedableRng for Xoshiro256PlusPlus { /// mapped to a different seed. #[inline] fn from_seed(seed: [u8; 32]) -> Xoshiro256PlusPlus { - if seed.iter().all(|&x| x == 0) { - return Self::seed_from_u64(0); - } let mut state = [0; 4]; read_u64_into(&seed, &mut state); + // Check for zero on aligned integers for better code generation. + // Furtermore, seed_from_u64(0) will expand to a constant when optimized. + if state.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } Xoshiro256PlusPlus { s: state } } /// Create a new `Xoshiro256PlusPlus` from a `u64` seed. /// /// This uses the SplitMix64 generator internally. + #[inline] fn seed_from_u64(mut state: u64) -> Self { const PHI: u64 = 0x9e3779b97f4a7c15; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(8) { + let mut s = [0; 4]; + for i in s.iter_mut() { state = state.wrapping_add(PHI); let mut z = state; z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); z = z ^ (z >> 31); - chunk.copy_from_slice(&z.to_le_bytes()); + *i = z; } - Self::from_seed(seed) + // By using a non-zero PHI we are guaranteed to generate a non-zero state + // Thus preventing a recursion between from_seed and seed_from_u64. + debug_assert_ne!(s, [0; 4]); + Xoshiro256PlusPlus { s } } } @@ -63,12 +70,13 @@ impl RngCore for Xoshiro256PlusPlus { fn next_u32(&mut self) -> u32 { // The lowest bits have some linear dependencies, so we use the // upper bits instead. - (self.next_u64() >> 32) as u32 + let val = self.next_u64(); + (val >> 32) as u32 } #[inline] fn next_u64(&mut self) -> u64 { - let result_plusplus = self.s[0] + let res = self.s[0] .wrapping_add(self.s[3]) .rotate_left(23) .wrapping_add(self.s[0]); @@ -84,36 +92,61 @@ impl RngCore for Xoshiro256PlusPlus { self.s[3] = self.s[3].rotate_left(45); - result_plusplus - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_via_next(self, dest); + res } #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { - use super::*; + use super::Xoshiro256PlusPlus; + use rand_core::{RngCore, SeedableRng}; #[test] fn reference() { - let mut rng = Xoshiro256PlusPlus::from_seed( - [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, - 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); + let mut rng = Xoshiro256PlusPlus::from_seed([ + 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, + 0, 0, 0, + ]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro256plusplus.c let expected = [ - 41943041, 58720359, 3588806011781223, 3591011842654386, - 9228616714210784205, 9973669472204895162, 14011001112246962877, - 12406186145184390807, 15849039046786891736, 10450023813501588000, + 41943041, + 58720359, + 3588806011781223, + 3591011842654386, + 9228616714210784205, + 9973669472204895162, + 14011001112246962877, + 12406186145184390807, + 15849039046786891736, + 10450023813501588000, + ]; + for &e in &expected { + assert_eq!(rng.next_u64(), e); + } + } + + #[test] + fn stable_seed_from_u64() { + // We don't guarantee value-stability for SmallRng but this + // could influence keeping stability whenever possible (e.g. after optimizations). + let mut rng = Xoshiro256PlusPlus::seed_from_u64(0); + let expected = [ + 5987356902031041503, + 7051070477665621255, + 6633766593972829180, + 211316841551650330, + 9136120204379184874, + 379361710973160858, + 15813423377499357806, + 15596884590815070553, + 5439680534584881407, + 1369371744833522710, ]; for &e in &expected { assert_eq!(rng.next_u64(), e); diff --git a/src/seq/coin_flipper.rs b/src/seq/coin_flipper.rs index 7a97fa8aaf0..7e8f53116ce 100644 --- a/src/seq/coin_flipper.rs +++ b/src/seq/coin_flipper.rs @@ -10,7 +10,7 @@ use crate::RngCore; pub(crate) struct CoinFlipper { pub rng: R, - chunk: u32, //TODO(opt): this should depend on RNG word size + chunk: u32, // TODO(opt): this should depend on RNG word size chunk_remaining: u32, } @@ -27,17 +27,17 @@ impl CoinFlipper { /// Returns true with a probability of 1 / d /// Uses an expected two bits of randomness /// Panics if d == 0 - pub fn gen_ratio_one_over(&mut self, d: usize) -> bool { + pub fn random_ratio_one_over(&mut self, d: usize) -> bool { debug_assert_ne!(d, 0); - // This uses the same logic as `gen_ratio` but is optimized for the case that + // This uses the same logic as `random_ratio` but is optimized for the case that // the starting numerator is one (which it always is for `Sequence::Choose()`) - // In this case (but not `gen_ratio`), this way of calculating c is always accurate + // In this case (but not `random_ratio`), this way of calculating c is always accurate let c = (usize::BITS - 1 - d.leading_zeros()).min(32); if self.flip_c_heads(c) { let numerator = 1 << c; - self.gen_ratio(numerator, d) + self.random_ratio(numerator, d) } else { false } @@ -46,7 +46,7 @@ impl CoinFlipper { #[inline] /// Returns true with a probability of n / d /// Uses an expected two bits of randomness - fn gen_ratio(&mut self, mut n: usize, d: usize) -> bool { + fn random_ratio(&mut self, mut n: usize, d: usize) -> bool { // Explanation: // We are trying to return true with a probability of n / d // If n >= d, we can just return true @@ -92,7 +92,7 @@ impl CoinFlipper { // If n * 2^c > `usize::MAX` we always return `true` anyway n = n.saturating_mul(2_usize.pow(c)); } else { - //At least one tail + // At least one tail if c == 1 { // Calculate 2n - d. // We need to use wrapping as 2n might be greater than `usize::MAX` diff --git a/src/seq/increasing_uniform.rs b/src/seq/increasing_uniform.rs index 3208c656fb5..10dd48a652a 100644 --- a/src/seq/increasing_uniform.rs +++ b/src/seq/increasing_uniform.rs @@ -41,7 +41,7 @@ impl IncreasingUniform { let next_n = self.n + 1; // There's room for further optimisation here: - // gen_range uses rejection sampling (or other method; see #1196) to avoid bias. + // random_range uses rejection sampling (or other method; see #1196) to avoid bias. // When the initial sample is biased for range 0..bound // it may still be viable to use for a smaller bound // (especially if small biases are considered acceptable). @@ -50,7 +50,7 @@ impl IncreasingUniform { // If the chunk is empty, generate a new chunk let (bound, remaining) = calculate_bound_u32(next_n); // bound = (n + 1) * (n + 2) *..* (n + remaining) - self.chunk = self.rng.gen_range(0..bound); + self.chunk = self.rng.random_range(..bound); // Chunk is a random number in // [0, (n + 1) * (n + 2) *..* (n + remaining) ) diff --git a/src/seq/index.rs b/src/seq/index.rs index e98b7ec1061..7dd0513850c 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -7,52 +7,56 @@ // except according to those terms. //! Low-level API for sampling indices - -#[cfg(feature = "alloc")] use core::slice; - -#[cfg(feature = "alloc")] use alloc::vec::{self, Vec}; +use alloc::vec::{self, Vec}; +use core::slice; +use core::{hash::Hash, ops::AddAssign}; // BTreeMap is not as fast in tests, but better than nothing. -#[cfg(all(feature = "alloc", not(feature = "std")))] -use alloc::collections::BTreeSet; -#[cfg(feature = "std")] use std::collections::HashSet; - #[cfg(feature = "std")] use super::WeightError; +use crate::distr::uniform::SampleUniform; +use crate::distr::{Distribution, Uniform}; +use crate::Rng; +#[cfg(not(feature = "std"))] +use alloc::collections::BTreeSet; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +#[cfg(feature = "std")] +use std::collections::HashSet; -#[cfg(feature = "alloc")] -use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}}; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))] +compile_error!("unsupported pointer width"); /// A vector of indices. /// /// Multiple internal representations are possible. #[derive(Clone, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum IndexVec { #[doc(hidden)] U32(Vec), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(Vec), + U64(Vec), } impl IndexVec { /// Returns the number of indices #[inline] pub fn len(&self) -> usize { - match *self { - IndexVec::U32(ref v) => v.len(), - IndexVec::USize(ref v) => v.len(), + match self { + IndexVec::U32(v) => v.len(), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.len(), } } /// Returns `true` if the length is 0. #[inline] pub fn is_empty(&self) -> bool { - match *self { - IndexVec::U32(ref v) => v.is_empty(), - IndexVec::USize(ref v) => v.is_empty(), + match self { + IndexVec::U32(v) => v.is_empty(), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.is_empty(), } } @@ -62,9 +66,10 @@ impl IndexVec { /// restrictions.) #[inline] pub fn index(&self, index: usize) -> usize { - match *self { - IndexVec::U32(ref v) => v[index] as usize, - IndexVec::USize(ref v) => v[index], + match self { + IndexVec::U32(v) => v[index] as usize, + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v[index] as usize, } } @@ -73,30 +78,33 @@ impl IndexVec { pub fn into_vec(self) -> Vec { match self { IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(), - IndexVec::USize(v) => v, + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.into_iter().map(|i| i as usize).collect(), } } /// Iterate over the indices as a sequence of `usize` values #[inline] pub fn iter(&self) -> IndexVecIter<'_> { - match *self { - IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()), - IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()), + match self { + IndexVec::U32(v) => IndexVecIter::U32(v.iter()), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => IndexVecIter::U64(v.iter()), } } } impl IntoIterator for IndexVec { - type Item = usize; type IntoIter = IndexVecIntoIter; + type Item = usize; /// Convert into an iterator over the indices as a sequence of `usize` values #[inline] fn into_iter(self) -> IndexVecIntoIter { match self { IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()), - IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => IndexVecIntoIter::U64(v.into_iter()), } } } @@ -106,12 +114,15 @@ impl PartialEq for IndexVec { use self::IndexVec::*; match (self, other) { (U32(v1), U32(v2)) => v1 == v2, - (USize(v1), USize(v2)) => v1 == v2, - (U32(v1), USize(v2)) => { - (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y)) + #[cfg(target_pointer_width = "64")] + (U64(v1), U64(v2)) => v1 == v2, + #[cfg(target_pointer_width = "64")] + (U32(v1), U64(v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as u64 == *y)) } - (USize(v1), U32(v2)) => { - (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize)) + #[cfg(target_pointer_width = "64")] + (U64(v1), U32(v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as u64)) } } } @@ -124,10 +135,11 @@ impl From> for IndexVec { } } -impl From> for IndexVec { +#[cfg(target_pointer_width = "64")] +impl From> for IndexVec { #[inline] - fn from(v: Vec) -> Self { - IndexVec::USize(v) + fn from(v: Vec) -> Self { + IndexVec::U64(v) } } @@ -136,40 +148,44 @@ impl From> for IndexVec { pub enum IndexVecIter<'a> { #[doc(hidden)] U32(slice::Iter<'a, u32>), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(slice::Iter<'a, usize>), + U64(slice::Iter<'a, u64>), } -impl<'a> Iterator for IndexVecIter<'a> { +impl Iterator for IndexVecIter<'_> { type Item = usize; #[inline] fn next(&mut self) -> Option { use self::IndexVecIter::*; - match *self { - U32(ref mut iter) => iter.next().map(|i| *i as usize), - USize(ref mut iter) => iter.next().cloned(), + match self { + U32(iter) => iter.next().map(|i| *i as usize), + #[cfg(target_pointer_width = "64")] + U64(iter) => iter.next().map(|i| *i as usize), } } #[inline] fn size_hint(&self) -> (usize, Option) { - match *self { - IndexVecIter::U32(ref v) => v.size_hint(), - IndexVecIter::USize(ref v) => v.size_hint(), + match self { + IndexVecIter::U32(v) => v.size_hint(), + #[cfg(target_pointer_width = "64")] + IndexVecIter::U64(v) => v.size_hint(), } } } -impl<'a> ExactSizeIterator for IndexVecIter<'a> {} +impl ExactSizeIterator for IndexVecIter<'_> {} /// Return type of `IndexVec::into_iter`. #[derive(Clone, Debug)] pub enum IndexVecIntoIter { #[doc(hidden)] U32(vec::IntoIter), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(vec::IntoIter), + U64(vec::IntoIter), } impl Iterator for IndexVecIntoIter { @@ -178,25 +194,26 @@ impl Iterator for IndexVecIntoIter { #[inline] fn next(&mut self) -> Option { use self::IndexVecIntoIter::*; - match *self { - U32(ref mut v) => v.next().map(|i| i as usize), - USize(ref mut v) => v.next(), + match self { + U32(v) => v.next().map(|i| i as usize), + #[cfg(target_pointer_width = "64")] + U64(v) => v.next().map(|i| i as usize), } } #[inline] fn size_hint(&self) -> (usize, Option) { use self::IndexVecIntoIter::*; - match *self { - U32(ref v) => v.size_hint(), - USize(ref v) => v.size_hint(), + match self { + U32(v) => v.size_hint(), + #[cfg(target_pointer_width = "64")] + U64(v) => v.size_hint(), } } } impl ExactSizeIterator for IndexVecIntoIter {} - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in random order (fully shuffled). /// @@ -219,15 +236,22 @@ impl ExactSizeIterator for IndexVecIntoIter {} /// to adapt the internal `sample_floyd` implementation. /// /// Panics if `amount > length`. +#[track_caller] pub fn sample(rng: &mut R, length: usize, amount: usize) -> IndexVec -where R: Rng + ?Sized { +where + R: Rng + ?Sized, +{ if amount > length { panic!("`amount` of samples must be less than or equal to `length`"); } - if length > (::core::u32::MAX as usize) { + if length > (u32::MAX as usize) { + #[cfg(target_pointer_width = "32")] + unreachable!(); + // We never want to use inplace here, but could use floyd's alg // Lazy version: always use the cache alg. - return sample_rejection(rng, length, amount); + #[cfg(target_pointer_width = "64")] + return sample_rejection(rng, length as u64, amount as u64); } let amount = amount as u32; let length = length as u32; @@ -258,10 +282,13 @@ where R: Rng + ?Sized { } } -/// Randomly sample exactly `amount` distinct indices from `0..length`, and -/// return them in an arbitrary order (there is no guarantee of shuffling or -/// ordering). The weights are to be provided by the input function `weights`, -/// which will be called once for each index. +/// Randomly sample `amount` distinct indices from `0..length` +/// +/// The result may contain less than `amount` indices if insufficient non-zero +/// weights are available. Results are returned in an arbitrary order (there is +/// no guarantee of shuffling or ordering). +/// +/// Function `weight` is called once for each index to provide weights. /// /// This method is used internally by the slice sampling methods, but it can /// sometimes be useful to have the indices themselves so this is provided as @@ -269,45 +296,58 @@ where R: Rng + ?Sized { /// /// Error cases: /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. -/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. /// /// This implementation uses `O(length + amount)` space and `O(length)` time. #[cfg(feature = "std")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] pub fn sample_weighted( - rng: &mut R, length: usize, weight: F, amount: usize, + rng: &mut R, + length: usize, + weight: F, + amount: usize, ) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, X: Into, { - if length > (core::u32::MAX as usize) { - sample_efraimidis_spirakis(rng, length, weight, amount) + if length > (u32::MAX as usize) { + #[cfg(target_pointer_width = "32")] + unreachable!(); + + #[cfg(target_pointer_width = "64")] + { + let amount = amount as u64; + let length = length as u64; + sample_efraimidis_spirakis(rng, length, weight, amount) + } } else { - assert!(amount <= core::u32::MAX as usize); + assert!(amount <= u32::MAX as usize); let amount = amount as u32; let length = length as u32; sample_efraimidis_spirakis(rng, length, weight, amount) } } - -/// Randomly sample exactly `amount` distinct indices from `0..length`, and -/// return them in an arbitrary order (there is no guarantee of shuffling or -/// ordering). The weights are to be provided by the input function `weights`, -/// which will be called once for each index. +/// Randomly sample `amount` distinct indices from `0..length` +/// +/// The result may contain less than `amount` indices if insufficient non-zero +/// weights are available. Results are returned in an arbitrary order (there is +/// no guarantee of shuffling or ordering). /// -/// This implementation uses the algorithm described by Efraimidis and Spirakis -/// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 +/// Function `weight` is called once for each index to provide weights. +/// +/// This implementation is based on the algorithm A-ExpJ as found in +/// [Efraimidis and Spirakis, 2005](https://doi.org/10.1016/j.ipl.2005.11.003). /// It uses `O(length + amount)` space and `O(length)` time. /// /// Error cases: /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. -/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. #[cfg(feature = "std")] fn sample_efraimidis_spirakis( - rng: &mut R, length: N, weight: F, amount: N, + rng: &mut R, + length: N, + weight: F, + amount: N, ) -> Result where R: Rng + ?Sized, @@ -316,6 +356,8 @@ where N: UInt, IndexVec: From>, { + use std::{cmp::Ordering, collections::BinaryHeap}; + if amount == N::zero() { return Ok(IndexVec::U32(Vec::new())); } @@ -324,31 +366,37 @@ where index: N, key: f64, } + impl PartialOrd for Element { - fn partial_cmp(&self, other: &Self) -> Option { - self.key.partial_cmp(&other.key) + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } + impl Ord for Element { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - // partial_cmp will always produce a value, - // because we check that the weights are not nan - self.partial_cmp(other).unwrap() + fn cmp(&self, other: &Self) -> Ordering { + // unwrap() should not panic since weights should not be NaN + // We reverse so that BinaryHeap::peek shows the smallest item + self.key.partial_cmp(&other.key).unwrap().reverse() } } + impl PartialEq for Element { fn eq(&self, other: &Self) -> bool { self.key == other.key } } + impl Eq for Element {} - let mut candidates = Vec::with_capacity(length.as_usize()); + let mut candidates = BinaryHeap::with_capacity(amount.as_usize()); let mut index = N::zero(); - while index < length { + while index < length && candidates.len() < amount.as_usize() { let weight = weight(index.as_usize()).into(); if weight > 0.0 { - let key = rng.gen::().powf(1.0 / weight); + // We use the log of the key used in A-ExpJ to improve precision + // for small weights: + let key = rng.random::().ln() / weight; candidates.push(Element { index, key }); } else if !(weight >= 0.0) { return Err(WeightError::InvalidWeight); @@ -357,24 +405,31 @@ where index += N::one(); } - let avail = candidates.len(); - if avail < amount.as_usize() { - return Err(WeightError::InsufficientNonZero); - } - - // Partially sort the array to find the `amount` elements with the greatest - // keys. Do this by using `select_nth_unstable` to put the elements with - // the *smallest* keys at the beginning of the list in `O(n)` time, which - // provides equivalent information about the elements with the *greatest* keys. - let (_, mid, greater) - = candidates.select_nth_unstable(avail - amount.as_usize()); + if index < length { + let mut x = rng.random::().ln() / candidates.peek().unwrap().key; + while index < length { + let weight = weight(index.as_usize()).into(); + if weight > 0.0 { + x -= weight; + if x <= 0.0 { + let min_candidate = candidates.pop().unwrap(); + let t = (min_candidate.key * weight).exp(); + let key = rng.random_range(t..1.0).ln() / weight; + candidates.push(Element { index, key }); + + x = rng.random::().ln() / candidates.peek().unwrap().key; + } + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); + } - let mut result: Vec = Vec::with_capacity(amount.as_usize()); - result.push(mid.index); - for element in greater { - result.push(element.index); + index += N::one(); + } } - Ok(IndexVec::from(result)) + + Ok(IndexVec::from( + candidates.iter().map(|elt| elt.index).collect(), + )) } /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's @@ -384,14 +439,16 @@ where /// /// This implementation uses `O(amount)` memory and `O(amount^2)` time. fn sample_floyd(rng: &mut R, length: u32, amount: u32) -> IndexVec -where R: Rng + ?Sized { - // Note that the values returned by `rng.gen_range()` can be +where + R: Rng + ?Sized, +{ + // Note that the values returned by `rng.random_range()` can be // inferred from the returned vector by working backwards from // the last entry. This bijection proves the algorithm fair. debug_assert!(amount <= length); let mut indices = Vec::with_capacity(amount as usize); for j in length - amount..length { - let t = rng.gen_range(0..=j); + let t = rng.random_range(..=j); if let Some(pos) = indices.iter().position(|&x| x == t) { indices[pos] = j; } @@ -413,12 +470,14 @@ where R: Rng + ?Sized { /// /// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. fn sample_inplace(rng: &mut R, length: u32, amount: u32) -> IndexVec -where R: Rng + ?Sized { +where + R: Rng + ?Sized, +{ debug_assert!(amount <= length); let mut indices: Vec = Vec::with_capacity(length as usize); indices.extend(0..length); for i in 0..amount { - let j: u32 = rng.gen_range(i..length); + let j: u32 = rng.random_range(i..length); indices.swap(i as usize, j as usize); } indices.truncate(amount as usize); @@ -426,12 +485,13 @@ where R: Rng + ?Sized { IndexVec::from(indices) } -trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform - + core::hash::Hash + core::ops::AddAssign { +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + Hash + AddAssign { fn zero() -> Self; + #[cfg_attr(feature = "alloc", allow(dead_code))] fn one() -> Self; fn as_usize(self) -> usize; } + impl UInt for u32 { #[inline] fn zero() -> Self { @@ -448,7 +508,9 @@ impl UInt for u32 { self as usize } } -impl UInt for usize { + +#[cfg(target_pointer_width = "64")] +impl UInt for u64 { #[inline] fn zero() -> Self { 0 @@ -461,7 +523,7 @@ impl UInt for usize { #[inline] fn as_usize(self) -> usize { - self + self as usize } } @@ -501,25 +563,17 @@ where #[cfg(test)] mod test { use super::*; + use alloc::vec; #[test] - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] fn test_serialization_index_vec() { - let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]); - let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); - match (some_index_vec, de_some_index_vec) { - (IndexVec::U32(a), IndexVec::U32(b)) => { - assert_eq!(a, b); - }, - (IndexVec::USize(a), IndexVec::USize(b)) => { - assert_eq!(a, b); - }, - _ => {panic!("failed to seralize/deserialize `IndexVec`")} - } + let some_index_vec = IndexVec::from(vec![254_u32, 234, 2, 1]); + let de_some_index_vec: IndexVec = + bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); + assert_eq!(some_index_vec, de_some_index_vec); } - #[cfg(feature = "alloc")] use alloc::vec; - #[test] fn test_sample_boundaries() { let mut r = crate::test::rng(404); @@ -592,13 +646,14 @@ mod test { for &i in &indices { assert!((i as usize) < len); } - }, - IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), + } + #[cfg(target_pointer_width = "64")] + _ => panic!("expected `IndexVec::U32`"), } } let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10); - assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); + assert_eq!(r.unwrap().len(), 9); } #[test] @@ -627,11 +682,15 @@ mod test { do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace - do_test(1_000_000, 8, &[ - 103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573, - ]); // floyd - do_test(1_000_000, 180, &[ - 103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573, - ]); // rejection + do_test( + 1_000_000, + 8, + &[103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573], + ); // floyd + do_test( + 1_000_000, + 180, + &[103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573], + ); // rejection } } diff --git a/src/seq/iterator.rs b/src/seq/iterator.rs new file mode 100644 index 00000000000..a9a9e56155c --- /dev/null +++ b/src/seq/iterator.rs @@ -0,0 +1,668 @@ +// Copyright 2018-2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `IteratorRandom` + +use super::coin_flipper::CoinFlipper; +#[allow(unused)] +use super::IndexedRandom; +use crate::Rng; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Extension trait on iterators, providing random sampling methods. +/// +/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` +/// and provides methods for +/// choosing one or more elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::IteratorRandom; +/// +/// let faces = "😀😎😐😕😠😢"; +/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap()); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// I am 😀! +/// ``` +pub trait IteratorRandom: Iterator + Sized { + /// Uniformly sample one element + /// + /// Assuming that the [`Iterator::size_hint`] is correct, this method + /// returns one uniformly-sampled random element of the slice, or `None` + /// only if the slice is empty. Incorrect bounds on the `size_hint` may + /// cause this method to incorrectly return `None` if fewer elements than + /// the advertised `lower` bound are present and may prevent sampling of + /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint` + /// is memory-safe, but may result in unexpected `None` result and + /// non-uniform distribution). + /// + /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is + /// a constant-time operation, this method can offer `O(1)` performance. + /// Where no size hint is + /// available, complexity is `O(n)` where `n` is the iterator length. + /// Partial hints (where `lower > 0`) also improve performance. + /// + /// Note further that [`Iterator::size_hint`] may affect the number of RNG + /// samples used as well as the result (while remaining uniform sampling). + /// Consider instead using [`IteratorRandom::choose_stable`] to avoid + /// [`Iterator`] combinators which only change size hints from affecting the + /// results. + /// + /// # Example + /// + /// ``` + /// use rand::seq::IteratorRandom; + /// + /// let words = "Mary had a little lamb".split(' '); + /// println!("{}", words.choose(&mut rand::rng()).unwrap()); + /// ``` + fn choose(mut self, rng: &mut R) -> Option + where + R: Rng + ?Sized, + { + let (mut lower, mut upper) = self.size_hint(); + let mut result = None; + + // Handling for this condition outside the loop allows the optimizer to eliminate the loop + // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. + // seq_iter_choose_from_1000. + if upper == Some(lower) { + return match lower { + 0 => None, + 1 => self.next(), + _ => self.nth(rng.random_range(..lower)), + }; + } + + let mut coin_flipper = CoinFlipper::new(rng); + let mut consumed = 0; + + // Continue until the iterator is exhausted + loop { + if lower > 1 { + let ix = coin_flipper.rng.random_range(..lower + consumed); + let skip = if ix < lower { + result = self.nth(ix); + lower - (ix + 1) + } else { + lower + }; + if upper == Some(lower) { + return result; + } + consumed += lower; + if skip > 0 { + self.nth(skip - 1); + } + } else { + let elem = self.next(); + if elem.is_none() { + return result; + } + consumed += 1; + if coin_flipper.random_ratio_one_over(consumed) { + result = elem; + } + } + + let hint = self.size_hint(); + lower = hint.0; + upper = hint.1; + } + } + + /// Uniformly sample one element (stable) + /// + /// This method is very similar to [`choose`] except that the result + /// only depends on the length of the iterator and the values produced by + /// `rng`. Notably for any iterator of a given length this will make the + /// same requests to `rng` and if the same sequence of values are produced + /// the same index will be selected from `self`. This may be useful if you + /// need consistent results no matter what type of iterator you are working + /// with. If you do not need this stability prefer [`choose`]. + /// + /// Note that this method still uses [`Iterator::size_hint`] to skip + /// constructing elements where possible, however the selection and `rng` + /// calls are the same in the face of this optimization. If you want to + /// force every element to be created regardless call `.inspect(|e| ())`. + /// + /// [`choose`]: IteratorRandom::choose + // + // Clippy is wrong here: we need to iterate over all entries with the RNG to + // ensure that choosing is *stable*. + #[allow(clippy::double_ended_iterator_last)] + fn choose_stable(mut self, rng: &mut R) -> Option + where + R: Rng + ?Sized, + { + let mut consumed = 0; + let mut result = None; + let mut coin_flipper = CoinFlipper::new(rng); + + loop { + // Currently the only way to skip elements is `nth()`. So we need to + // store what index to access next here. + // This should be replaced by `advance_by()` once it is stable: + // https://github.com/rust-lang/rust/issues/77404 + let mut next = 0; + + let (lower, _) = self.size_hint(); + if lower >= 2 { + let highest_selected = (0..lower) + .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1)) + .last(); + + consumed += lower; + next = lower; + + if let Some(ix) = highest_selected { + result = self.nth(ix); + next -= ix + 1; + debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); + } + } + + let elem = self.nth(next); + if elem.is_none() { + return result; + } + + if coin_flipper.random_ratio_one_over(consumed + 1) { + result = elem; + } + consumed += 1; + } + } + + /// Uniformly sample `amount` distinct elements into a buffer + /// + /// Collects values at random from the iterator into a supplied buffer + /// until that buffer is filled. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// Returns the number of elements added to the buffer. This equals the length + /// of the buffer unless the iterator contains insufficient elements, in which + /// case this equals the number of elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. + fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize + where + R: Rng + ?Sized, + { + let amount = buf.len(); + let mut len = 0; + while len < amount { + if let Some(elem) = self.next() { + buf[len] = elem; + len += 1; + } else { + // Iterator exhausted; stop early + return len; + } + } + + // Continue, since the iterator was not exhausted + for (i, elem) in self.enumerate() { + let k = rng.random_range(..i + 1 + amount); + if let Some(slot) = buf.get_mut(k) { + *slot = elem; + } + } + len + } + + /// Uniformly sample `amount` distinct elements into a [`Vec`] + /// + /// This is equivalent to `choose_multiple_fill` except for the result type. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// The length of the returned vector equals `amount` unless the iterator + /// contains insufficient elements, in which case it equals the number of + /// elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. + #[cfg(feature = "alloc")] + fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec + where + R: Rng + ?Sized, + { + let mut reservoir = Vec::with_capacity(amount); + reservoir.extend(self.by_ref().take(amount)); + + // Continue unless the iterator was exhausted + // + // note: this prevents iterators that "restart" from causing problems. + // If the iterator stops once, then so do we. + if reservoir.len() == amount { + for (i, elem) in self.enumerate() { + let k = rng.random_range(..i + 1 + amount); + if let Some(slot) = reservoir.get_mut(k) { + *slot = elem; + } + } + } else { + // Don't hang onto extra memory. There is a corner case where + // `amount` was much less than `self.len()`. + reservoir.shrink_to_fit(); + } + reservoir + } +} + +impl IteratorRandom for I where I: Iterator + Sized {} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(all(feature = "alloc", not(feature = "std")))] + use alloc::vec::Vec; + + #[derive(Clone)] + struct UnhintedIterator { + iter: I, + } + impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + } + + #[derive(Clone)] + struct ChunkHintedIterator { + iter: I, + chunk_remaining: usize, + chunk_size: usize, + hint_total_size: bool, + } + impl Iterator for ChunkHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + if self.chunk_remaining == 0 { + self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len()); + } + self.chunk_remaining = self.chunk_remaining.saturating_sub(1); + + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.chunk_remaining, + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[derive(Clone)] + struct WindowHintedIterator { + iter: I, + window_size: usize, + hint_total_size: bool, + } + impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + ( + core::cmp::min(self.iter.len(), self.window_size), + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose() { + let r = &mut crate::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable() { + let r = &mut crate::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable_stability() { + fn test_iter(iter: impl Iterator + Clone) -> [i32; 9] { + let r = &mut crate::test::rng(109); + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + chosen + } + + let reference = test_iter(0..9); + assert_eq!( + test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), + reference + ); + + #[cfg(feature = "alloc")] + assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); + assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }), + reference + ); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_sample_iter() { + let min_val = 1; + let max_val = 100; + + let mut r = crate::test::rng(401); + let vals = (min_val..max_val).collect::>(); + let small_sample = vals.iter().choose_multiple(&mut r, 5); + let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); + + assert_eq!(small_sample.len(), 5); + assert_eq!(large_sample.len(), vals.len()); + // no randomization happens when amount >= len + assert_eq!(large_sample, vals.iter().collect::>()); + + assert!(small_sample + .iter() + .all(|e| { **e >= min_val && **e <= max_val })); + } + + #[test] + fn value_stability_choose() { + fn choose>(iter: I) -> Option { + let mut rng = crate::test::rng(411); + iter.choose(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(33)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(91) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(91) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(34) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(34) + ); + } + + #[test] + fn value_stability_choose_stable() { + fn choose>(iter: I) -> Option { + let mut rng = crate::test::rng(411); + iter.choose_stable(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(27)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(27) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(27) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(27) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(27) + ); + } + + #[test] + fn value_stability_choose_multiple() { + fn do_test>(iter: I, v: &[u32]) { + let mut rng = crate::test::rng(412); + let mut buf = [0u32; 8]; + assert_eq!( + iter.clone().choose_multiple_fill(&mut rng, &mut buf), + v.len() + ); + assert_eq!(&buf[0..v.len()], v); + + #[cfg(feature = "alloc")] + { + let mut rng = crate::test::rng(412); + assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); + } + } + + do_test(0..4, &[0, 1, 2, 3]); + do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); + do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]); + } +} diff --git a/src/seq/mod.rs b/src/seq/mod.rs index f5cbc6008e9..91d634d865e 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -19,7 +19,7 @@ //! //! Also see: //! -//! * [`crate::distributions::WeightedIndex`] distribution which provides +//! * [`crate::distr::weighted::WeightedIndex`] distribution which provides //! weighted index sampling. //! //! In order to make results reproducible across 32-64 bit architectures, all @@ -27,1393 +27,54 @@ //! small performance boost in some cases). mod coin_flipper; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod index; - mod increasing_uniform; +mod iterator; +mod slice; #[cfg(feature = "alloc")] -#[doc(no_inline)] -pub use crate::distributions::WeightError; - -use core::ops::{Index, IndexMut}; +#[path = "index.rs"] +mod index_; #[cfg(feature = "alloc")] -use alloc::vec::Vec; - +#[doc(no_inline)] +pub use crate::distr::weighted::Error as WeightError; +pub use iterator::IteratorRandom; #[cfg(feature = "alloc")] -use crate::distributions::uniform::{SampleBorrow, SampleUniform}; -#[cfg(feature = "alloc")] use crate::distributions::Weight; -use crate::Rng; - -use self::coin_flipper::CoinFlipper; -use self::increasing_uniform::IncreasingUniform; - -/// Extension trait on indexable lists, providing random sampling methods. -/// -/// This trait is implemented on `[T]` slice types. Other types supporting -/// [`std::ops::Index`] may implement this (only [`Self::len`] must be -/// specified). -pub trait IndexedRandom: Index { - /// The length - fn len(&self) -> usize; - - /// True when the length is zero - #[inline] - fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Uniformly sample one element - /// - /// Returns a reference to one uniformly-sampled random element of - /// the slice, or `None` if the slice is empty. - /// - /// For slices, complexity is `O(1)`. - /// - /// # Example - /// - /// ``` - /// use rand::thread_rng; - /// use rand::seq::IndexedRandom; - /// - /// let choices = [1, 2, 4, 8, 16, 32]; - /// let mut rng = thread_rng(); - /// println!("{:?}", choices.choose(&mut rng)); - /// assert_eq!(choices[..0].choose(&mut rng), None); - /// ``` - fn choose(&self, rng: &mut R) -> Option<&Self::Output> - where - R: Rng + ?Sized, - { - if self.is_empty() { - None - } else { - Some(&self[gen_index(rng, self.len())]) - } - } - - /// Uniformly sample `amount` distinct elements - /// - /// Chooses `amount` elements from the slice at random, without repetition, - /// and in random order. The returned iterator is appropriate both for - /// collection into a `Vec` and filling an existing buffer (see example). - /// - /// In case this API is not sufficiently flexible, use [`index::sample`]. - /// - /// For slices, complexity is the same as [`index::sample`]. - /// - /// # Example - /// ``` - /// use rand::seq::IndexedRandom; - /// - /// let mut rng = &mut rand::thread_rng(); - /// let sample = "Hello, audience!".as_bytes(); - /// - /// // collect the results into a vector: - /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); - /// - /// // store in a buffer: - /// let mut buf = [0u8; 5]; - /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { - /// *slot = *b; - /// } - /// ``` - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where - Self::Output: Sized, - R: Rng + ?Sized, - { - let amount = ::core::cmp::min(amount, self.len()); - SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample(rng, self.len(), amount).into_iter(), - } - } - - /// Biased sampling for one element - /// - /// Returns a reference to one element of the slice, sampled according - /// to the provided weights. Returns `None` only if the slice is empty. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// For more information about the underlying algorithm, - /// see [`distributions::WeightedIndex`]. - /// - /// See also [`choose_weighted_mut`]. - /// - /// # Example - /// - /// ``` - /// use rand::prelude::*; - /// - /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)]; - /// let mut rng = thread_rng(); - /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c', - /// // and 'd' will never be printed - /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); - /// ``` - /// [`choose`]: IndexedRandom::choose - /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut - /// [`distributions::WeightedIndex`]: crate::distributions::WeightedIndex - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted( - &self, rng: &mut R, weight: F, - ) -> Result<&Self::Output, WeightError> - where - R: Rng + ?Sized, - F: Fn(&Self::Output) -> B, - B: SampleBorrow, - X: SampleUniform + Weight + ::core::cmp::PartialOrd, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; - Ok(&self[distr.sample(rng)]) - } +pub use slice::SliceChooseIter; +pub use slice::{IndexedMutRandom, IndexedRandom, SliceRandom}; - /// Biased sampling of `amount` distinct elements - /// - /// Similar to [`choose_multiple`], but where the likelihood of each element's - /// inclusion in the output may be specified. The elements are returned in an - /// arbitrary, unspecified order. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// If all of the weights are equal, even if they are all zero, each element has - /// an equal likelihood of being selected. - /// - /// This implementation uses `O(length + amount)` space and `O(length)` time - /// if the "nightly" feature is enabled, or `O(length)` space and - /// `O(length + amount * log length)` time otherwise. - /// - /// # Example - /// - /// ``` - /// use rand::prelude::*; - /// - /// let choices = [('a', 2), ('b', 1), ('c', 1)]; - /// let mut rng = thread_rng(); - /// // First Draw * Second Draw = total odds - /// // ----------------------- - /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. - /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. - /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. - /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); - /// ``` - /// [`choose_multiple`]: IndexedRandom::choose_multiple - // - // Note: this is feature-gated on std due to usage of f64::powf. - // If necessary, we may use alloc+libm as an alternative (see PR #1089). - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - fn choose_multiple_weighted( - &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightError> - where - Self::Output: Sized, - R: Rng + ?Sized, - F: Fn(&Self::Output) -> X, - X: Into, - { - let amount = ::core::cmp::min(amount, self.len()); - Ok(SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample_weighted( - rng, - self.len(), - |idx| weight(&self[idx]).into(), - amount, - )? - .into_iter(), - }) - } -} - -/// Extension trait on indexable lists, providing random sampling methods. -/// -/// This trait is implemented automatically for every type implementing -/// [`IndexedRandom`] and [`std::ops::IndexMut`]. -pub trait IndexedMutRandom: IndexedRandom + IndexMut { - /// Uniformly sample one element (mut) - /// - /// Returns a mutable reference to one uniformly-sampled random element of - /// the slice, or `None` if the slice is empty. - /// - /// For slices, complexity is `O(1)`. - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Output> - where - R: Rng + ?Sized, - { - if self.is_empty() { - None - } else { - let len = self.len(); - Some(&mut self[gen_index(rng, len)]) - } - } +/// Low-level API for sampling indices +pub mod index { + use crate::Rng; - /// Biased sampling for one element (mut) - /// - /// Returns a mutable reference to one element of the slice, sampled according - /// to the provided weights. Returns `None` only if the slice is empty. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// For more information about the underlying algorithm, - /// see [`distributions::WeightedIndex`]. - /// - /// See also [`choose_weighted`]. - /// - /// [`choose_mut`]: IndexedMutRandom::choose_mut - /// [`choose_weighted`]: IndexedRandom::choose_weighted - /// [`distributions::WeightedIndex`]: crate::distributions::WeightedIndex #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Output, WeightError> - where - R: Rng + ?Sized, - F: Fn(&Self::Output) -> B, - B: SampleBorrow, - X: SampleUniform + Weight + ::core::cmp::PartialOrd, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; - let index = distr.sample(rng); - Ok(&mut self[index]) - } -} - -/// Extension trait on slices, providing shuffling methods. -/// -/// This trait is implemented on all `[T]` slice types, providing several -/// methods for choosing and shuffling elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::SliceRandom; -/// -/// let mut rng = rand::thread_rng(); -/// let mut bytes = "Hello, random!".to_string().into_bytes(); -/// bytes.shuffle(&mut rng); -/// let str = String::from_utf8(bytes).unwrap(); -/// println!("{}", str); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// l,nmroHado !le -/// ``` -pub trait SliceRandom: IndexedMutRandom { - /// Shuffle a mutable slice in place. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// The resulting permutation is picked uniformly from the set of all possible permutations. - /// - /// # Example - /// - /// ``` - /// use rand::seq::SliceRandom; - /// use rand::thread_rng; - /// - /// let mut rng = thread_rng(); - /// let mut y = [1, 2, 3, 4, 5]; - /// println!("Unshuffled: {:?}", y); - /// y.shuffle(&mut rng); - /// println!("Shuffled: {:?}", y); - /// ``` - fn shuffle(&mut self, rng: &mut R) - where - R: Rng + ?Sized; - - /// Shuffle a slice in place, but exit early. - /// - /// Returns two mutable slices from the source slice. The first contains - /// `amount` elements randomly permuted. The second has the remaining - /// elements that are not fully shuffled. - /// - /// This is an efficient method to select `amount` elements at random from - /// the slice, provided the slice may be mutated. - /// - /// If you only need to choose elements randomly and `amount > self.len()/2` - /// then you may improve performance by taking - /// `amount = self.len() - amount` and using only the second slice. - /// - /// If `amount` is greater than the number of elements in the slice, this - /// will perform a full shuffle. - /// - /// For slices, complexity is `O(m)` where `m = amount`. - fn partial_shuffle( - &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Output], &mut [Self::Output]) - where - Self::Output: Sized, - R: Rng + ?Sized; -} - -/// Extension trait on iterators, providing random sampling methods. -/// -/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` -/// and provides methods for -/// choosing one or more elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::IteratorRandom; -/// -/// let mut rng = rand::thread_rng(); -/// -/// let faces = "😀😎😐😕😠😢"; -/// println!("I am {}!", faces.chars().choose(&mut rng).unwrap()); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// I am 😀! -/// ``` -pub trait IteratorRandom: Iterator + Sized { - /// Uniformly sample one element - /// - /// Assuming that the [`Iterator::size_hint`] is correct, this method - /// returns one uniformly-sampled random element of the slice, or `None` - /// only if the slice is empty. Incorrect bounds on the `size_hint` may - /// cause this method to incorrectly return `None` if fewer elements than - /// the advertised `lower` bound are present and may prevent sampling of - /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint` - /// is memory-safe, but may result in unexpected `None` result and - /// non-uniform distribution). - /// - /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is - /// a constant-time operation, this method can offer `O(1)` performance. - /// Where no size hint is - /// available, complexity is `O(n)` where `n` is the iterator length. - /// Partial hints (where `lower > 0`) also improve performance. - /// - /// Note further that [`Iterator::size_hint`] may affect the number of RNG - /// samples used as well as the result (while remaining uniform sampling). - /// Consider instead using [`IteratorRandom::choose_stable`] to avoid - /// [`Iterator`] combinators which only change size hints from affecting the - /// results. - fn choose(mut self, rng: &mut R) -> Option - where - R: Rng + ?Sized, - { - let (mut lower, mut upper) = self.size_hint(); - let mut result = None; - - // Handling for this condition outside the loop allows the optimizer to eliminate the loop - // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. - // seq_iter_choose_from_1000. - if upper == Some(lower) { - return match lower { - 0 => None, - 1 => self.next(), - _ => self.nth(gen_index(rng, lower)), - }; - } - - let mut coin_flipper = coin_flipper::CoinFlipper::new(rng); - let mut consumed = 0; - - // Continue until the iterator is exhausted - loop { - if lower > 1 { - let ix = gen_index(coin_flipper.rng, lower + consumed); - let skip = if ix < lower { - result = self.nth(ix); - lower - (ix + 1) - } else { - lower - }; - if upper == Some(lower) { - return result; - } - consumed += lower; - if skip > 0 { - self.nth(skip - 1); - } - } else { - let elem = self.next(); - if elem.is_none() { - return result; - } - consumed += 1; - if coin_flipper.gen_ratio_one_over(consumed) { - result = elem; - } - } - - let hint = self.size_hint(); - lower = hint.0; - upper = hint.1; - } - } - - /// Uniformly sample one element (stable) - /// - /// This method is very similar to [`choose`] except that the result - /// only depends on the length of the iterator and the values produced by - /// `rng`. Notably for any iterator of a given length this will make the - /// same requests to `rng` and if the same sequence of values are produced - /// the same index will be selected from `self`. This may be useful if you - /// need consistent results no matter what type of iterator you are working - /// with. If you do not need this stability prefer [`choose`]. - /// - /// Note that this method still uses [`Iterator::size_hint`] to skip - /// constructing elements where possible, however the selection and `rng` - /// calls are the same in the face of this optimization. If you want to - /// force every element to be created regardless call `.inspect(|e| ())`. - /// - /// [`choose`]: IteratorRandom::choose - fn choose_stable(mut self, rng: &mut R) -> Option - where - R: Rng + ?Sized, - { - let mut consumed = 0; - let mut result = None; - let mut coin_flipper = CoinFlipper::new(rng); - - loop { - // Currently the only way to skip elements is `nth()`. So we need to - // store what index to access next here. - // This should be replaced by `advance_by()` once it is stable: - // https://github.com/rust-lang/rust/issues/77404 - let mut next = 0; - - let (lower, _) = self.size_hint(); - if lower >= 2 { - let highest_selected = (0..lower) - .filter(|ix| coin_flipper.gen_ratio_one_over(consumed + ix + 1)) - .last(); + #[doc(inline)] + pub use super::index_::*; - consumed += lower; - next = lower; - - if let Some(ix) = highest_selected { - result = self.nth(ix); - next -= ix + 1; - debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); - } - } - - let elem = self.nth(next); - if elem.is_none() { - return result; - } - - if coin_flipper.gen_ratio_one_over(consumed + 1) { - result = elem; - } - consumed += 1; - } - } - - /// Uniformly sample `amount` distinct elements into a buffer - /// - /// Collects values at random from the iterator into a supplied buffer - /// until that buffer is filled. + /// Randomly sample exactly `N` distinct indices from `0..len`, and + /// return them in random order (fully shuffled). /// - /// Although the elements are selected randomly, the order of elements in - /// the buffer is neither stable nor fully random. If random ordering is - /// desired, shuffle the result. + /// This is implemented via Floyd's algorithm. Time complexity is `O(N^2)` + /// and memory complexity is `O(N)`. /// - /// Returns the number of elements added to the buffer. This equals the length - /// of the buffer unless the iterator contains insufficient elements, in which - /// case this equals the number of elements available. - /// - /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`IndexedRandom::choose_multiple`]. - fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize + /// Returns `None` if (and only if) `N > len`. + pub fn sample_array(rng: &mut R, len: usize) -> Option<[usize; N]> where R: Rng + ?Sized, { - let amount = buf.len(); - let mut len = 0; - while len < amount { - if let Some(elem) = self.next() { - buf[len] = elem; - len += 1; - } else { - // Iterator exhausted; stop early - return len; - } + if N > len { + return None; } - // Continue, since the iterator was not exhausted - for (i, elem) in self.enumerate() { - let k = gen_index(rng, i + 1 + amount); - if let Some(slot) = buf.get_mut(k) { - *slot = elem; + // Floyd's algorithm + let mut indices = [0; N]; + for (i, j) in (len - N..len).enumerate() { + let t = rng.random_range(..j + 1); + if let Some(pos) = indices[0..i].iter().position(|&x| x == t) { + indices[pos] = j; } + indices[i] = t; } - len - } - - /// Uniformly sample `amount` distinct elements into a [`Vec`] - /// - /// This is equivalent to `choose_multiple_fill` except for the result type. - /// - /// Although the elements are selected randomly, the order of elements in - /// the buffer is neither stable nor fully random. If random ordering is - /// desired, shuffle the result. - /// - /// The length of the returned vector equals `amount` unless the iterator - /// contains insufficient elements, in which case it equals the number of - /// elements available. - /// - /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`IndexedRandom::choose_multiple`]. - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec - where - R: Rng + ?Sized, - { - let mut reservoir = Vec::with_capacity(amount); - reservoir.extend(self.by_ref().take(amount)); - - // Continue unless the iterator was exhausted - // - // note: this prevents iterators that "restart" from causing problems. - // If the iterator stops once, then so do we. - if reservoir.len() == amount { - for (i, elem) in self.enumerate() { - let k = gen_index(rng, i + 1 + amount); - if let Some(slot) = reservoir.get_mut(k) { - *slot = elem; - } - } - } else { - // Don't hang onto extra memory. There is a corner case where - // `amount` was much less than `self.len()`. - reservoir.shrink_to_fit(); - } - reservoir - } -} - -impl IndexedRandom for [T] { - fn len(&self) -> usize { - self.len() - } -} - -impl + ?Sized> IndexedMutRandom for IR {} - -impl SliceRandom for [T] { - fn shuffle(&mut self, rng: &mut R) - where - R: Rng + ?Sized, - { - if self.len() <= 1 { - // There is no need to shuffle an empty or single element slice - return; - } - self.partial_shuffle(rng, self.len()); - } - - fn partial_shuffle( - &mut self, rng: &mut R, amount: usize, - ) -> (&mut [T], &mut [T]) - where - R: Rng + ?Sized, - { - let m = self.len().saturating_sub(amount); - - // The algorithm below is based on Durstenfeld's algorithm for the - // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) - // for an unbiased permutation. - // It ensures that the last `amount` elements of the slice - // are randomly selected from the whole slice. - - //`IncreasingUniform::next_index()` is faster than `gen_index` - //but only works for 32 bit integers - //So we must use the slow method if the slice is longer than that. - if self.len() < (u32::MAX as usize) { - let mut chooser = IncreasingUniform::new(rng, m as u32); - for i in m..self.len() { - let index = chooser.next_index(); - self.swap(i, index); - } - } else { - for i in m..self.len() { - let index = gen_index(rng, i + 1); - self.swap(i, index); - } - } - let r = self.split_at_mut(m); - (r.1, r.0) - } -} - -impl IteratorRandom for I where I: Iterator + Sized {} - -/// An iterator over multiple slice elements. -/// -/// This struct is created by -/// [`IndexedRandom::choose_multiple`](trait.IndexedRandom.html#tymethod.choose_multiple). -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug)] -pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { - slice: &'a S, - _phantom: ::core::marker::PhantomData, - indices: index::IndexVecIntoIter, -} - -#[cfg(feature = "alloc")] -impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { - type Item = &'a T; - - fn next(&mut self) -> Option { - // TODO: investigate using SliceIndex::get_unchecked when stable - self.indices.next().map(|i| &self.slice[i]) - } - - fn size_hint(&self) -> (usize, Option) { - (self.indices.len(), Some(self.indices.len())) - } -} - -#[cfg(feature = "alloc")] -impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator - for SliceChooseIter<'a, S, T> -{ - fn len(&self) -> usize { - self.indices.len() - } -} - -// Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where -// possible, primarily in order to produce the same output on 32-bit and 64-bit -// platforms. -#[inline] -fn gen_index(rng: &mut R, ubound: usize) -> usize { - - if ubound <= (core::u32::MAX as usize) { - rng.gen_range(0..ubound as u32) as usize - } else { - rng.gen_range(0..ubound) - } -} - -#[cfg(test)] -mod test { - use super::*; - #[cfg(feature = "alloc")] - use crate::Rng; - #[cfg(all(feature = "alloc", not(feature = "std")))] - use alloc::vec::Vec; - - #[test] - fn test_slice_choose() { - let mut r = crate::test::rng(107); - let chars = [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - ]; - let mut chosen = [0i32; 14]; - // The below all use a binomial distribution with n=1000, p=1/14. - // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 - for _ in 0..1000 { - let picked = *chars.choose(&mut r).unwrap(); - chosen[(picked as usize) - ('a' as usize)] += 1; - } - for count in chosen.iter() { - assert!(40 < *count && *count < 106); - } - - chosen.iter_mut().for_each(|x| *x = 0); - for _ in 0..1000 { - *chosen.choose_mut(&mut r).unwrap() += 1; - } - for count in chosen.iter() { - assert!(40 < *count && *count < 106); - } - - let mut v: [isize; 0] = []; - assert_eq!(v.choose(&mut r), None); - assert_eq!(v.choose_mut(&mut r), None); - } - - #[test] - fn value_stability_slice() { - let mut r = crate::test::rng(413); - let chars = [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - ]; - let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - - assert_eq!(chars.choose(&mut r), Some(&'l')); - assert_eq!(nums.choose_mut(&mut r), Some(&mut 3)); - - #[cfg(feature = "alloc")] - assert_eq!( - &chars - .choose_multiple(&mut r, 8) - .cloned() - .collect::>(), - &['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'] - ); - - #[cfg(feature = "alloc")] - assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'l')); - #[cfg(feature = "alloc")] - assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 8)); - - let mut r = crate::test::rng(414); - nums.shuffle(&mut r); - assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]); - nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let res = nums.partial_shuffle(&mut r, 6); - assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]); - assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]); - } - - #[derive(Clone)] - struct UnhintedIterator { - iter: I, - } - impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - } - - #[derive(Clone)] - struct ChunkHintedIterator { - iter: I, - chunk_remaining: usize, - chunk_size: usize, - hint_total_size: bool, - } - impl Iterator for ChunkHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - if self.chunk_remaining == 0 { - self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len()); - } - self.chunk_remaining = self.chunk_remaining.saturating_sub(1); - - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.chunk_remaining, - if self.hint_total_size { - Some(self.iter.len()) - } else { - None - }, - ) - } - } - - #[derive(Clone)] - struct WindowHintedIterator { - iter: I, - window_size: usize, - hint_total_size: bool, - } - impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - ( - ::core::cmp::min(self.iter.len(), self.window_size), - if self.hint_total_size { - Some(self.iter.len()) - } else { - None - }, - ) - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose() { - let r = &mut crate::test::rng(109); - fn test_iter + Clone>(r: &mut R, iter: Iter) { - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose(r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - // Samples should follow Binomial(1000, 1/9) - // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x - // Note: have seen 153, which is unlikely but not impossible. - assert!( - 72 < *count && *count < 154, - "count not close to 1000/9: {}", - count - ); - } - } - - test_iter(r, 0..9); - test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); - #[cfg(feature = "alloc")] - test_iter(r, (0..9).collect::>().into_iter()); - test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter( - r, - ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }, - ); - test_iter( - r, - ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }, - ); - test_iter( - r, - WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }, - ); - test_iter( - r, - WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }, - ); - - assert_eq!((0..0).choose(r), None); - assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose_stable() { - let r = &mut crate::test::rng(109); - fn test_iter + Clone>(r: &mut R, iter: Iter) { - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose_stable(r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - // Samples should follow Binomial(1000, 1/9) - // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x - // Note: have seen 153, which is unlikely but not impossible. - assert!( - 72 < *count && *count < 154, - "count not close to 1000/9: {}", - count - ); - } - } - - test_iter(r, 0..9); - test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); - #[cfg(feature = "alloc")] - test_iter(r, (0..9).collect::>().into_iter()); - test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter( - r, - ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }, - ); - test_iter( - r, - ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }, - ); - test_iter( - r, - WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }, - ); - test_iter( - r, - WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }, - ); - - assert_eq!((0..0).choose(r), None); - assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose_stable_stability() { - fn test_iter(iter: impl Iterator + Clone) -> [i32; 9] { - let r = &mut crate::test::rng(109); - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose_stable(r).unwrap(); - chosen[picked] += 1; - } - chosen - } - - let reference = test_iter(0..9); - assert_eq!( - test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), - reference - ); - - #[cfg(feature = "alloc")] - assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); - assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); - assert_eq!( - test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }), - reference - ); - assert_eq!( - test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }), - reference - ); - assert_eq!( - test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }), - reference - ); - assert_eq!( - test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }), - reference - ); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_shuffle() { - let mut r = crate::test::rng(108); - let empty: &mut [isize] = &mut []; - empty.shuffle(&mut r); - let mut one = [1]; - one.shuffle(&mut r); - let b: &[_] = &[1]; - assert_eq!(one, b); - - let mut two = [1, 2]; - two.shuffle(&mut r); - assert!(two == [1, 2] || two == [2, 1]); - - fn move_last(slice: &mut [usize], pos: usize) { - // use slice[pos..].rotate_left(1); once we can use that - let last_val = slice[pos]; - for i in pos..slice.len() - 1 { - slice[i] = slice[i + 1]; - } - *slice.last_mut().unwrap() = last_val; - } - let mut counts = [0i32; 24]; - for _ in 0..10000 { - let mut arr: [usize; 4] = [0, 1, 2, 3]; - arr.shuffle(&mut r); - let mut permutation = 0usize; - let mut pos_value = counts.len(); - for i in 0..4 { - pos_value /= 4 - i; - let pos = arr.iter().position(|&x| x == i).unwrap(); - assert!(pos < (4 - i)); - permutation += pos * pos_value; - move_last(&mut arr, pos); - assert_eq!(arr[3], i); - } - for (i, &a) in arr.iter().enumerate() { - assert_eq!(a, i); - } - counts[permutation] += 1; - } - for count in counts.iter() { - // Binomial(10000, 1/24) with average 416.667 - // Octave: binocdf(n, 10000, 1/24) - // 99.9% chance samples lie within this range: - assert!(352 <= *count && *count <= 483, "count: {}", count); - } - } - - #[test] - fn test_partial_shuffle() { - let mut r = crate::test::rng(118); - - let mut empty: [u32; 0] = []; - let res = empty.partial_shuffle(&mut r, 10); - assert_eq!((res.0.len(), res.1.len()), (0, 0)); - - let mut v = [1, 2, 3, 4, 5]; - let res = v.partial_shuffle(&mut r, 2); - assert_eq!((res.0.len(), res.1.len()), (2, 3)); - assert!(res.0[0] != res.0[1]); - // First elements are only modified if selected, so at least one isn't modified: - assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); - } - - #[test] - #[cfg(feature = "alloc")] - fn test_sample_iter() { - let min_val = 1; - let max_val = 100; - - let mut r = crate::test::rng(401); - let vals = (min_val..max_val).collect::>(); - let small_sample = vals.iter().choose_multiple(&mut r, 5); - let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); - - assert_eq!(small_sample.len(), 5); - assert_eq!(large_sample.len(), vals.len()); - // no randomization happens when amount >= len - assert_eq!(large_sample, vals.iter().collect::>()); - - assert!(small_sample - .iter() - .all(|e| { **e >= min_val && **e <= max_val })); - } - - #[test] - #[cfg(feature = "alloc")] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted() { - let mut r = crate::test::rng(406); - const N_REPS: u32 = 3000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // choose_weighted - fn get_weight(item: &(u32, T)) -> u32 { - item.0 - } - let mut chosen = [0i32; 14]; - let mut items = [(0u32, 0usize); 14]; // (weight, index) - for (i, item) in items.iter_mut().enumerate() { - *item = (weights[i], i); - } - for _ in 0..N_REPS { - let item = items.choose_weighted(&mut r, get_weight).unwrap(); - chosen[item.1] += 1; - } - verify(chosen); - - // choose_weighted_mut - let mut items = [(0u32, 0i32); 14]; // (weight, count) - for (i, item) in items.iter_mut().enumerate() { - *item = (weights[i], 0); - } - for _ in 0..N_REPS { - items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; - } - for (ch, item) in chosen.iter_mut().zip(items.iter()) { - *ch = item.1; - } - verify(chosen); - - // Check error cases - let empty_slice = &mut [10][0..0]; - assert_eq!( - empty_slice.choose_weighted(&mut r, |_| 1), - Err(WeightError::InvalidInput) - ); - assert_eq!( - empty_slice.choose_weighted_mut(&mut r, |_| 1), - Err(WeightError::InvalidInput) - ); - assert_eq!( - ['x'].choose_weighted_mut(&mut r, |_| 0), - Err(WeightError::InsufficientNonZero) - ); - assert_eq!( - [0, -1].choose_weighted_mut(&mut r, |x| *x), - Err(WeightError::InvalidWeight) - ); - assert_eq!( - [-1, 0].choose_weighted_mut(&mut r, |x| *x), - Err(WeightError::InvalidWeight) - ); - } - - #[test] - fn value_stability_choose() { - fn choose>(iter: I) -> Option { - let mut rng = crate::test::rng(411); - iter.choose(&mut rng) - } - - assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(33)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: false, - }), - Some(91) - ); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: true, - }), - Some(91) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: false, - }), - Some(34) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: true, - }), - Some(34) - ); - } - - #[test] - fn value_stability_choose_stable() { - fn choose>(iter: I) -> Option { - let mut rng = crate::test::rng(411); - iter.choose_stable(&mut rng) - } - - assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(27)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: false, - }), - Some(27) - ); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: true, - }), - Some(27) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: false, - }), - Some(27) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: true, - }), - Some(27) - ); - } - - #[test] - fn value_stability_choose_multiple() { - fn do_test>(iter: I, v: &[u32]) { - let mut rng = crate::test::rng(412); - let mut buf = [0u32; 8]; - assert_eq!(iter.clone().choose_multiple_fill(&mut rng, &mut buf), v.len()); - assert_eq!(&buf[0..v.len()], v); - - #[cfg(feature = "alloc")] - { - let mut rng = crate::test::rng(412); - assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); - } - } - - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]); - } - - #[test] - #[cfg(feature = "std")] - fn test_multiple_weighted_edge_cases() { - use super::*; - - let mut rng = crate::test::rng(413); - - // Case 1: One of the weights is 0 - let choices = [('a', 2), ('b', 1), ('c', 0)]; - for _ in 0..100 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - - assert_eq!(result.len(), 2); - assert!(!result.iter().any(|val| val.0 == 'c')); - } - - // Case 2: All of the weights are 0 - let choices = [('a', 0), ('b', 0), ('c', 0)]; - let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); - assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); - - // Case 3: Negative weights - let choices = [('a', -1), ('b', 1), ('c', 1)]; - let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); - assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); - - // Case 4: Empty list - let choices = []; - let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0); - assert_eq!(r.unwrap().count(), 0); - - // Case 5: NaN weights - let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; - let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); - assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); - - // Case 6: +infinity weights - let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; - for _ in 0..100 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - assert_eq!(result.len(), 2); - assert!(result.iter().any(|val| val.0 == 'a')); - } - - // Case 7: -infinity weights - let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; - let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); - assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); - - // Case 8: -0 weights - let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; - let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); - assert!(r.is_ok()); - } - - #[test] - #[cfg(feature = "std")] - fn test_multiple_weighted_distributions() { - use super::*; - - // The theoretical probabilities of the different outcomes are: - // AB: 0.5 * 0.5 = 0.250 - // AC: 0.5 * 0.5 = 0.250 - // BA: 0.25 * 0.67 = 0.167 - // BC: 0.25 * 0.33 = 0.082 - // CA: 0.25 * 0.67 = 0.167 - // CB: 0.25 * 0.33 = 0.082 - let choices = [('a', 2), ('b', 1), ('c', 1)]; - let mut rng = crate::test::rng(414); - - let mut results = [0i32; 3]; - let expected_results = [4167, 4167, 1666]; - for _ in 0..10000 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - - assert_eq!(result.len(), 2); - - match (result[0].0, result[1].0) { - ('a', 'b') | ('b', 'a') => { - results[0] += 1; - } - ('a', 'c') | ('c', 'a') => { - results[1] += 1; - } - ('b', 'c') | ('c', 'b') => { - results[2] += 1; - } - (_, _) => panic!("unexpected result"), - } - } - - let mut diffs = results - .iter() - .zip(&expected_results) - .map(|(a, b)| (a - b).abs()); - assert!(!diffs.any(|deviation| deviation > 100)); + Some(indices) } } diff --git a/src/seq/slice.rs b/src/seq/slice.rs new file mode 100644 index 00000000000..f909418bc48 --- /dev/null +++ b/src/seq/slice.rs @@ -0,0 +1,766 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `IndexedRandom`, `IndexedMutRandom`, `SliceRandom` + +use super::increasing_uniform::IncreasingUniform; +use super::index; +#[cfg(feature = "alloc")] +use crate::distr::uniform::{SampleBorrow, SampleUniform}; +#[cfg(feature = "alloc")] +use crate::distr::weighted::{Error as WeightError, Weight}; +use crate::Rng; +use core::ops::{Index, IndexMut}; + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented on `[T]` slice types. Other types supporting +/// [`std::ops::Index`] may implement this (only [`Self::len`] must be +/// specified). +pub trait IndexedRandom: Index { + /// The length + fn len(&self) -> usize; + + /// True when the length is zero + #[inline] + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Uniformly sample one element + /// + /// Returns a reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + /// + /// # Example + /// + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let choices = [1, 2, 4, 8, 16, 32]; + /// let mut rng = rand::rng(); + /// println!("{:?}", choices.choose(&mut rng)); + /// assert_eq!(choices[..0].choose(&mut rng), None); + /// ``` + fn choose(&self, rng: &mut R) -> Option<&Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + Some(&self[rng.random_range(..self.len())]) + } + } + + /// Uniformly sample `amount` distinct elements from self + /// + /// Chooses `amount` elements from the slice at random, without repetition, + /// and in random order. The returned iterator is appropriate both for + /// collection into a `Vec` and filling an existing buffer (see example). + /// + /// In case this API is not sufficiently flexible, use [`index::sample`]. + /// + /// For slices, complexity is the same as [`index::sample`]. + /// + /// # Example + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let mut rng = &mut rand::rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// // collect the results into a vector: + /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); + /// + /// // store in a buffer: + /// let mut buf = [0u8; 5]; + /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// *slot = *b; + /// } + /// ``` + #[cfg(feature = "alloc")] + fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter + where + Self::Output: Sized, + R: Rng + ?Sized, + { + let amount = core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample(rng, self.len(), amount).into_iter(), + } + } + + /// Uniformly sample a fixed-size array of distinct elements from self + /// + /// Chooses `N` elements from the slice at random, without repetition, + /// and in random order. + /// + /// For slices, complexity is the same as [`index::sample_array`]. + /// + /// # Example + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let mut rng = &mut rand::rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// let a: [u8; 3] = sample.choose_multiple_array(&mut rng).unwrap(); + /// ``` + fn choose_multiple_array(&self, rng: &mut R) -> Option<[Self::Output; N]> + where + Self::Output: Clone + Sized, + R: Rng + ?Sized, + { + let indices = index::sample_array(rng, self.len())?; + Some(indices.map(|index| self[index].clone())) + } + + /// Biased sampling for one element + /// + /// Returns a reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see the [`WeightedIndex`] distribution. + /// + /// See also [`choose_weighted_mut`]. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)]; + /// let mut rng = rand::rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c', + /// // and 'd' will never be printed + /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); + /// ``` + /// [`choose`]: IndexedRandom::choose + /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex + #[cfg(feature = "alloc")] + fn choose_weighted( + &self, + rng: &mut R, + weight: F, + ) -> Result<&Self::Output, WeightError> + where + R: Rng + ?Sized, + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + PartialOrd, + { + use crate::distr::{weighted::WeightedIndex, Distribution}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + Ok(&self[distr.sample(rng)]) + } + + /// Biased sampling of `amount` distinct elements + /// + /// Similar to [`choose_multiple`], but where the likelihood of each + /// element's inclusion in the output may be specified. Zero-weighted + /// elements are never returned; the result may therefore contain fewer + /// elements than `amount` even when `self.len() >= amount`. The elements + /// are returned in an arbitrary, unspecified order. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// This implementation uses `O(length + amount)` space and `O(length)` time. + /// See [`index::sample_weighted`] for details. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = rand::rng(); + /// // First Draw * Second Draw = total odds + /// // ----------------------- + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. + /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. + /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); + /// ``` + /// [`choose_multiple`]: IndexedRandom::choose_multiple + // Note: this is feature-gated on std due to usage of f64::powf. + // If necessary, we may use alloc+libm as an alternative (see PR #1089). + #[cfg(feature = "std")] + fn choose_multiple_weighted( + &self, + rng: &mut R, + amount: usize, + weight: F, + ) -> Result, WeightError> + where + Self::Output: Sized, + R: Rng + ?Sized, + F: Fn(&Self::Output) -> X, + X: Into, + { + let amount = core::cmp::min(amount, self.len()); + Ok(SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample_weighted( + rng, + self.len(), + |idx| weight(&self[idx]).into(), + amount, + )? + .into_iter(), + }) + } +} + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented automatically for every type implementing +/// [`IndexedRandom`] and [`std::ops::IndexMut`]. +pub trait IndexedMutRandom: IndexedRandom + IndexMut { + /// Uniformly sample one element (mut) + /// + /// Returns a mutable reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[rng.random_range(..len)]) + } + } + + /// Biased sampling for one element (mut) + /// + /// Returns a mutable reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see the [`WeightedIndex`] distribution. + /// + /// See also [`choose_weighted`]. + /// + /// [`choose_mut`]: IndexedMutRandom::choose_mut + /// [`choose_weighted`]: IndexedRandom::choose_weighted + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex + #[cfg(feature = "alloc")] + fn choose_weighted_mut( + &mut self, + rng: &mut R, + weight: F, + ) -> Result<&mut Self::Output, WeightError> + where + R: Rng + ?Sized, + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + PartialOrd, + { + use crate::distr::{weighted::WeightedIndex, Distribution}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + let index = distr.sample(rng); + Ok(&mut self[index]) + } +} + +/// Extension trait on slices, providing shuffling methods. +/// +/// This trait is implemented on all `[T]` slice types, providing several +/// methods for choosing and shuffling elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let mut rng = rand::rng(); +/// let mut bytes = "Hello, random!".to_string().into_bytes(); +/// bytes.shuffle(&mut rng); +/// let str = String::from_utf8(bytes).unwrap(); +/// println!("{}", str); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// l,nmroHado !le +/// ``` +pub trait SliceRandom: IndexedMutRandom { + /// Shuffle a mutable slice in place. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// The resulting permutation is picked uniformly from the set of all possible permutations. + /// + /// # Example + /// + /// ``` + /// use rand::seq::SliceRandom; + /// + /// let mut rng = rand::rng(); + /// let mut y = [1, 2, 3, 4, 5]; + /// println!("Unshuffled: {:?}", y); + /// y.shuffle(&mut rng); + /// println!("Shuffled: {:?}", y); + /// ``` + fn shuffle(&mut self, rng: &mut R) + where + R: Rng + ?Sized; + + /// Shuffle a slice in place, but exit early. + /// + /// Returns two mutable slices from the source slice. The first contains + /// `amount` elements randomly permuted. The second has the remaining + /// elements that are not fully shuffled. + /// + /// This is an efficient method to select `amount` elements at random from + /// the slice, provided the slice may be mutated. + /// + /// If you only need to choose elements randomly and `amount > self.len()/2` + /// then you may improve performance by taking + /// `amount = self.len() - amount` and using only the second slice. + /// + /// If `amount` is greater than the number of elements in the slice, this + /// will perform a full shuffle. + /// + /// For slices, complexity is `O(m)` where `m = amount`. + fn partial_shuffle( + &mut self, + rng: &mut R, + amount: usize, + ) -> (&mut [Self::Output], &mut [Self::Output]) + where + Self::Output: Sized, + R: Rng + ?Sized; +} + +impl IndexedRandom for [T] { + fn len(&self) -> usize { + self.len() + } +} + +impl + ?Sized> IndexedMutRandom for IR {} + +impl SliceRandom for [T] { + fn shuffle(&mut self, rng: &mut R) + where + R: Rng + ?Sized, + { + if self.len() <= 1 { + // There is no need to shuffle an empty or single element slice + return; + } + self.partial_shuffle(rng, self.len()); + } + + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T]) + where + R: Rng + ?Sized, + { + let m = self.len().saturating_sub(amount); + + // The algorithm below is based on Durstenfeld's algorithm for the + // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation. + // It ensures that the last `amount` elements of the slice + // are randomly selected from the whole slice. + + // `IncreasingUniform::next_index()` is faster than `Rng::random_range` + // but only works for 32 bit integers + // So we must use the slow method if the slice is longer than that. + if self.len() < (u32::MAX as usize) { + let mut chooser = IncreasingUniform::new(rng, m as u32); + for i in m..self.len() { + let index = chooser.next_index(); + self.swap(i, index); + } + } else { + for i in m..self.len() { + let index = rng.random_range(..i + 1); + self.swap(i, index); + } + } + let r = self.split_at_mut(m); + (r.1, r.0) + } +} + +/// An iterator over multiple slice elements. +/// +/// This struct is created by +/// [`IndexedRandom::choose_multiple`](trait.IndexedRandom.html#tymethod.choose_multiple). +#[cfg(feature = "alloc")] +#[derive(Debug)] +pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { + slice: &'a S, + _phantom: core::marker::PhantomData, + indices: index::IndexVecIntoIter, +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + // TODO: investigate using SliceIndex::get_unchecked when stable + self.indices.next().map(|i| &self.slice[i]) + } + + fn size_hint(&self) -> (usize, Option) { + (self.indices.len(), Some(self.indices.len())) + } +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator + for SliceChooseIter<'a, S, T> +{ + fn len(&self) -> usize { + self.indices.len() + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "alloc")] + use alloc::vec::Vec; + + #[test] + fn test_slice_choose() { + let mut r = crate::test::rng(107); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut chosen = [0i32; 14]; + // The below all use a binomial distribution with n=1000, p=1/14. + // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 + for _ in 0..1000 { + let picked = *chars.choose(&mut r).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *chosen.choose_mut(&mut r).unwrap() += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + let mut v: [isize; 0] = []; + assert_eq!(v.choose(&mut r), None); + assert_eq!(v.choose_mut(&mut r), None); + } + + #[test] + fn value_stability_slice() { + let mut r = crate::test::rng(413); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + + assert_eq!(chars.choose(&mut r), Some(&'l')); + assert_eq!(nums.choose_mut(&mut r), Some(&mut 3)); + + assert_eq!( + &chars.choose_multiple_array(&mut r), + &Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k']) + ); + + #[cfg(feature = "alloc")] + assert_eq!( + &chars + .choose_multiple(&mut r, 8) + .cloned() + .collect::>(), + &['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f'] + ); + + #[cfg(feature = "alloc")] + assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i')); + #[cfg(feature = "alloc")] + assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2)); + + let mut r = crate::test::rng(414); + nums.shuffle(&mut r); + assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]); + nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let res = nums.partial_shuffle(&mut r, 6); + assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]); + assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_shuffle() { + let mut r = crate::test::rng(108); + let empty: &mut [isize] = &mut []; + empty.shuffle(&mut r); + let mut one = [1]; + one.shuffle(&mut r); + let b: &[_] = &[1]; + assert_eq!(one, b); + + let mut two = [1, 2]; + two.shuffle(&mut r); + assert!(two == [1, 2] || two == [2, 1]); + + fn move_last(slice: &mut [usize], pos: usize) { + // use slice[pos..].rotate_left(1); once we can use that + let last_val = slice[pos]; + for i in pos..slice.len() - 1 { + slice[i] = slice[i + 1]; + } + *slice.last_mut().unwrap() = last_val; + } + let mut counts = [0i32; 24]; + for _ in 0..10000 { + let mut arr: [usize; 4] = [0, 1, 2, 3]; + arr.shuffle(&mut r); + let mut permutation = 0usize; + let mut pos_value = counts.len(); + for i in 0..4 { + pos_value /= 4 - i; + let pos = arr.iter().position(|&x| x == i).unwrap(); + assert!(pos < (4 - i)); + permutation += pos * pos_value; + move_last(&mut arr, pos); + assert_eq!(arr[3], i); + } + for (i, &a) in arr.iter().enumerate() { + assert_eq!(a, i); + } + counts[permutation] += 1; + } + for count in counts.iter() { + // Binomial(10000, 1/24) with average 416.667 + // Octave: binocdf(n, 10000, 1/24) + // 99.9% chance samples lie within this range: + assert!(352 <= *count && *count <= 483, "count: {}", count); + } + } + + #[test] + fn test_partial_shuffle() { + let mut r = crate::test::rng(118); + + let mut empty: [u32; 0] = []; + let res = empty.partial_shuffle(&mut r, 10); + assert_eq!((res.0.len(), res.1.len()), (0, 0)); + + let mut v = [1, 2, 3, 4, 5]; + let res = v.partial_shuffle(&mut r, 2); + assert_eq!((res.0.len(), res.1.len()), (2, 3)); + assert!(res.0[0] != res.0[1]); + // First elements are only modified if selected, so at least one isn't modified: + assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); + } + + #[test] + #[cfg(feature = "alloc")] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weighted() { + let mut r = crate::test::rng(406); + const N_REPS: u32 = 3000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // choose_weighted + fn get_weight(item: &(u32, T)) -> u32 { + item.0 + } + let mut chosen = [0i32; 14]; + let mut items = [(0u32, 0usize); 14]; // (weight, index) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], i); + } + for _ in 0..N_REPS { + let item = items.choose_weighted(&mut r, get_weight).unwrap(); + chosen[item.1] += 1; + } + verify(chosen); + + // choose_weighted_mut + let mut items = [(0u32, 0i32); 14]; // (weight, count) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], 0); + } + for _ in 0..N_REPS { + items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; + } + for (ch, item) in chosen.iter_mut().zip(items.iter()) { + *ch = item.1; + } + verify(chosen); + + // Check error cases + let empty_slice = &mut [10][0..0]; + assert_eq!( + empty_slice.choose_weighted(&mut r, |_| 1), + Err(WeightError::InvalidInput) + ); + assert_eq!( + empty_slice.choose_weighted_mut(&mut r, |_| 1), + Err(WeightError::InvalidInput) + ); + assert_eq!( + ['x'].choose_weighted_mut(&mut r, |_| 0), + Err(WeightError::InsufficientNonZero) + ); + assert_eq!( + [0, -1].choose_weighted_mut(&mut r, |x| *x), + Err(WeightError::InvalidWeight) + ); + assert_eq!( + [-1, 0].choose_weighted_mut(&mut r, |x| *x), + Err(WeightError::InvalidWeight) + ); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_edge_cases() { + use super::*; + + let mut rng = crate::test::rng(413); + + // Case 1: One of the weights is 0 + let choices = [('a', 2), ('b', 1), ('c', 0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + assert!(!result.iter().any(|val| val.0 == 'c')); + } + + // Case 2: All of the weights are 0 + let choices = [('a', 0), ('b', 0), ('c', 0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap().len(), 0); + + // Case 3: Negative weights + let choices = [('a', -1), ('b', 1), ('c', 1)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 4: Empty list + let choices = []; + let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0); + assert_eq!(r.unwrap().count(), 0); + + // Case 5: NaN weights + let choices = [('a', f64::NAN), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 6: +infinity weights + let choices = [('a', f64::INFINITY), ('b', 1.0), ('c', 1.0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + assert_eq!(result.len(), 2); + assert!(result.iter().any(|val| val.0 == 'a')); + } + + // Case 7: -infinity weights + let choices = [('a', f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 8: -0 weights + let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert!(r.is_ok()); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_distributions() { + use super::*; + + // The theoretical probabilities of the different outcomes are: + // AB: 0.5 * 0.667 = 0.3333 + // AC: 0.5 * 0.333 = 0.1667 + // BA: 0.333 * 0.75 = 0.25 + // BC: 0.333 * 0.25 = 0.0833 + // CA: 0.167 * 0.6 = 0.1 + // CB: 0.167 * 0.4 = 0.0667 + let choices = [('a', 3), ('b', 2), ('c', 1)]; + let mut rng = crate::test::rng(414); + + let mut results = [0i32; 3]; + let expected_results = [5833, 2667, 1500]; + for _ in 0..10000 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + + match (result[0].0, result[1].0) { + ('a', 'b') | ('b', 'a') => { + results[0] += 1; + } + ('a', 'c') | ('c', 'a') => { + results[1] += 1; + } + ('b', 'c') | ('c', 'b') => { + results[2] += 1; + } + (_, _) => panic!("unexpected result"), + } + } + + let mut diffs = results + .iter() + .zip(&expected_results) + .map(|(a, b)| (a - b).abs()); + assert!(!diffs.any(|deviation| deviation > 100)); + } +} diff --git a/utils/ziggurat_tables.py b/utils/ziggurat_tables.py deleted file mode 100755 index 88cfdab6ba2..00000000000 --- a/utils/ziggurat_tables.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2018 Developers of the Rand project. -# Copyright 2013 The Rust Project Developers. -# -# Licensed under the Apache License, Version 2.0 or the MIT license -# , at your -# option. This file may not be copied, modified, or distributed -# except according to those terms. - -# This creates the tables used for distributions implemented using the -# ziggurat algorithm in `rand::distributions;`. They are -# (basically) the tables as used in the ZIGNOR variant (Doornik 2005). -# They are changed rarely, so the generated file should be checked in -# to git. -# -# It creates 3 tables: X as in the paper, F which is f(x_i), and -# F_DIFF which is f(x_i) - f(x_{i-1}). The latter two are just cached -# values which is not done in that paper (but is done in other -# variants). Note that the adZigR table is unnecessary because of -# algebra. -# -# It is designed to be compatible with Python 2 and 3. - -from math import exp, sqrt, log, floor -import random - -# The order should match the return value of `tables` -TABLE_NAMES = ['X', 'F'] - -# The actual length of the table is 1 more, to stop -# index-out-of-bounds errors. This should match the bitwise operation -# to find `i` in `zigurrat` in `libstd/rand/mod.rs`. Also the *_R and -# *_V constants below depend on this value. -TABLE_LEN = 256 - -# equivalent to `zigNorInit` in Doornik2005, but generalised to any -# distribution. r = dR, v = dV, f = probability density function, -# f_inv = inverse of f -def tables(r, v, f, f_inv): - # compute the x_i - xvec = [0]*(TABLE_LEN+1) - - xvec[0] = v / f(r) - xvec[1] = r - - for i in range(2, TABLE_LEN): - last = xvec[i-1] - xvec[i] = f_inv(v / last + f(last)) - - # cache the f's - fvec = [0]*(TABLE_LEN+1) - for i in range(TABLE_LEN+1): - fvec[i] = f(xvec[i]) - - return xvec, fvec - -# Distributions -# N(0, 1) -def norm_f(x): - return exp(-x*x/2.0) -def norm_f_inv(y): - return sqrt(-2.0*log(y)) - -NORM_R = 3.6541528853610088 -NORM_V = 0.00492867323399 - -NORM = tables(NORM_R, NORM_V, - norm_f, norm_f_inv) - -# Exp(1) -def exp_f(x): - return exp(-x) -def exp_f_inv(y): - return -log(y) - -EXP_R = 7.69711747013104972 -EXP_V = 0.0039496598225815571993 - -EXP = tables(EXP_R, EXP_V, - exp_f, exp_f_inv) - - -# Output the tables/constants/types - -def render_static(name, type, value): - # no space or - return 'pub static %s: %s =%s;\n' % (name, type, value) - -# static `name`: [`type`, .. `len(values)`] = -# [values[0], ..., values[3], -# values[4], ..., values[7], -# ... ]; -def render_table(name, values): - rows = [] - # 4 values on each row - for i in range(0, len(values), 4): - row = values[i:i+4] - rows.append(', '.join('%.18f' % f for f in row)) - - rendered = '\n [%s]' % ',\n '.join(rows) - return render_static(name, '[f64, .. %d]' % len(values), rendered) - - -with open('ziggurat_tables.rs', 'w') as f: - f.write('''// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// Tables for distributions which are sampled using the ziggurat -// algorithm. Autogenerated by `ziggurat_tables.py`. - -pub type ZigTable = &\'static [f64, .. %d]; -''' % (TABLE_LEN + 1)) - for name, tables, r in [('NORM', NORM, NORM_R), - ('EXP', EXP, EXP_R)]: - f.write(render_static('ZIG_%s_R' % name, 'f64', ' %.18f' % r)) - for (tabname, table) in zip(TABLE_NAMES, tables): - f.write(render_table('ZIG_%s_%s' % (name, tabname), table))