diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000000..3e0dd6b4f5b --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: dhardy diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000000..093433dffb3 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,20 @@ +--- +name: Bug report +about: Something doesn't work as expected +title: '' +labels: X-bug +assignees: '' + +--- + +## Summary + +A clear and concise description of what the bug is. + +What behaviour is expected, and why? + +## Code sample + +```rust +// Code demonstrating the problem +``` diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000000..02ac88f0673 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +- [ ] Added a `CHANGELOG.md` entry + +# Summary + +# Motivation + +# Details diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..22b1e8da2f5 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "monthly" + open-pull-requests-limit: 10 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml new file mode 100644 index 00000000000..22b4baa8dce --- /dev/null +++ b/.github/workflows/benches.yml @@ -0,0 +1,44 @@ +name: Benches + +on: + push: + branches: [ master ] + paths-ignore: + - "**.md" + - "distr_test/**" + - "examples/**" + pull_request: + branches: [ master ] + paths-ignore: + - "**.md" + - "distr_test/**" + - "examples/**" + +defaults: + run: + working-directory: ./benches + +jobs: + clippy-fmt: + name: "Benches: Check Clippy and rustfmt" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Rustfmt + run: cargo fmt -- --check + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + benches: + name: "Benches: Test" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - name: Test + run: RUSTFLAGS=-Dwarnings cargo test --benches diff --git a/.github/workflows/distr_test.yml b/.github/workflows/distr_test.yml new file mode 100644 index 00000000000..f2b7f814c98 --- /dev/null +++ b/.github/workflows/distr_test.yml @@ -0,0 +1,43 @@ +name: distr_test + +on: + push: + branches: [ master ] + paths-ignore: + - "**.md" + - "benches/**" + - "examples/**" + pull_request: + branches: [ master ] + paths-ignore: + - "**.md" + - "benches/**" + - "examples/**" + +defaults: + run: + working-directory: ./distr_test + +jobs: + clippy-fmt: + name: "distr_test: Check Clippy and rustfmt" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Rustfmt + run: cargo fmt -- --check + - name: Clippy + run: cargo clippy --all-targets -- -D warnings + ks-tests: + name: "distr_test: Run Komogorov Smirnov tests" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + - run: cargo test --release diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 80c0ec3d965..1d83a77bd7f 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -1,5 +1,10 @@ name: gh-pages +permissions: + contents: read + pages: write + id-token: write + on: push: branches: @@ -9,23 +14,34 @@ jobs: deploy: name: GH-pages documentation runs-on: ubuntu-latest + environment: + name: github-pages + url: https://rust-random.github.io/rand/ + steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v4 + - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true - - name: doc (rand) + uses: dtolnay/rust-toolchain@nightly + + - name: Build docs env: RUSTDOCFLAGS: --cfg doc_cfg # --all builds all crates, but with default features for other crates (okay in this case) run: | - cargo doc --all --features nightly,serde1,getrandom,small_rng + cargo doc --all --all-features --no-deps cp utils/redirect.html target/doc/index.html - - name: Deploy - uses: peaceiris/actions-gh-pages@v3 + rm target/doc/.lock + + - name: Setup Pages + uses: actions/configure-pages@v5 + + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./target/doc + path: './target/doc' + + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 740cfc8b872..293d5f4942d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,29 +1,58 @@ -name: Tests +name: Main tests on: push: branches: [ master, '0.[0-9]+' ] + paths-ignore: + - "**.md" + - "benches/**" + - "distr_test/**" pull_request: branches: [ master, '0.[0-9]+' ] + paths-ignore: + - "**.md" + - "benches/**" + - "distr_test/**" + +permissions: + contents: read # to fetch code (actions/checkout) jobs: + clippy-fmt: + name: Check Clippy and rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Check Clippy + run: cargo clippy --workspace -- -D warnings + - name: Check rustfmt + run: cargo fmt --all -- --check + check-doc: name: Check doc runs-on: ubuntu-latest + env: + RUSTDOCFLAGS: "-Dwarnings --cfg docsrs -Zunstable-options --generate-link-to-definition" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal toolchain: nightly - override: true - - run: cargo install cargo-deadlinks - - name: doc (rand) - env: - RUSTDOCFLAGS: --cfg doc_cfg - # --all builds all crates, but with default features for other crates (okay in this case) - run: cargo deadlinks --ignore-fragments -- --all --features nightly,serde1,getrandom,small_rng + - name: rand + run: cargo doc --all-features --no-deps + - name: rand_core + run: cargo doc --all-features --package rand_core --no-deps + - name: rand_distr + run: cargo doc --all-features --package rand_distr --no-deps + - name: rand_chacha + run: cargo doc --all-features --package rand_chacha --no-deps + - name: rand_pcg + run: cargo doc --all-features --package rand_pcg --no-deps test: runs-on: ${{ matrix.os }} @@ -47,7 +76,8 @@ jobs: # Test both windows-gnu and windows-msvc; use beta rust on one - os: ubuntu-latest target: x86_64-unknown-linux-gnu - toolchain: 1.36.0 # MSRV + variant: MSRV + toolchain: 1.63.0 - os: ubuntu-latest deps: sudo apt-get update ; sudo apt install gcc-multilib target: i686-unknown-linux-gnu @@ -58,45 +88,49 @@ jobs: variant: minimal_versions steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: MSRV + if: ${{ matrix.variant == 'MSRV' }} + run: cp Cargo.lock.msrv Cargo.lock - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal target: ${{ matrix.target }} toolchain: ${{ matrix.toolchain }} - override: true - run: ${{ matrix.deps }} - name: Maybe minimal versions if: ${{ matrix.variant == 'minimal_versions' }} - run: cargo generate-lockfile -Z minimal-versions + run: | + cargo generate-lockfile -Z minimal-versions - name: Maybe nightly if: ${{ matrix.toolchain == 'nightly' }} run: | - cargo test --target ${{ matrix.target }} --tests --features=nightly + cargo test --target ${{ matrix.target }} --features=nightly cargo test --target ${{ matrix.target }} --all-features - cargo test --target ${{ matrix.target }} --benches --features=nightly - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --benches + cargo test --target ${{ matrix.target }} --lib --tests --no-default-features - name: Test rand run: | - cargo test --target ${{ matrix.target }} --tests --no-default-features - cargo test --target ${{ matrix.target }} --tests --no-default-features --features=alloc,getrandom,small_rng - # all stable features: - cargo test --target ${{ matrix.target }} --features=serde1,log,small_rng + cargo test --target ${{ matrix.target }} --lib --tests --no-default-features + cargo build --target ${{ matrix.target }} --no-default-features --features alloc,os_rng,small_rng,unbiased + cargo test --target ${{ matrix.target }} --lib --tests --no-default-features --features=alloc,os_rng,small_rng cargo test --target ${{ matrix.target }} --examples + - name: Test rand (all stable features) + run: | + cargo test --target ${{ matrix.target }} --features=serde,log,small_rng - name: Test rand_core run: | cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features - cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=alloc,getrandom + cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=os_rng - name: Test rand_distr - run: cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml + run: | + cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde + cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features + cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features --features=std,std_math - name: Test rand_pcg - run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 + run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde - name: Test rand_chacha - run: cargo test --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml - - name: Test rand_hc - run: cargo test --target ${{ matrix.target }} --manifest-path rand_hc/Cargo.toml + run: cargo test --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml --features=serde test-cross: runs-on: ${{ matrix.os }} @@ -105,20 +139,18 @@ jobs: matrix: include: - os: ubuntu-latest - target: mips-unknown-linux-gnu + target: powerpc-unknown-linux-gnu toolchain: stable steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal target: ${{ matrix.target }} toolchain: ${{ matrix.toolchain }} - override: true - name: Cache cargo plugins - uses: actions/cache@v1 + uses: actions/cache@v4 with: path: ~/.cargo/bin/ key: ${{ runner.os }}-cargo-plugins @@ -127,59 +159,63 @@ jobs: - name: Test run: | # all stable features: - cross test --no-fail-fast --target ${{ matrix.target }} --features=serde1,log,small_rng + cross test --no-fail-fast --target ${{ matrix.target }} --features=serde,log,small_rng cross test --no-fail-fast --target ${{ matrix.target }} --examples cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde1 + cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde + cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_hc/Cargo.toml test-miri: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain run: | - MIRI_NIGHTLY=nightly-$(curl -s https://rust-lang.github.io/rustup-components-history/x86_64-unknown-linux-gnu/miri) - rustup default "$MIRI_NIGHTLY" - rustup component add miri + rustup toolchain install nightly --component miri + rustup override set nightly + cargo miri setup - name: Test rand run: | - cargo miri test --no-default-features + cargo miri test --no-default-features --lib --tests cargo miri test --features=log,small_rng cargo miri test --manifest-path rand_core/Cargo.toml - cargo miri test --manifest-path rand_core/Cargo.toml --features=serde1 + cargo miri test --manifest-path rand_core/Cargo.toml --features=serde cargo miri test --manifest-path rand_core/Cargo.toml --no-default-features #cargo miri test --manifest-path rand_distr/Cargo.toml # no unsafe and lots of slow tests - cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde1 + cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde cargo miri test --manifest-path rand_chacha/Cargo.toml --no-default-features - cargo miri test --manifest-path rand_hc/Cargo.toml test-no-std: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly target: thumbv6m-none-eabi - override: true - name: Build top-level only run: cargo build --target=thumbv6m-none-eabi --no-default-features + # Disabled due to lack of known working compiler versions (not older than our MSRV) + # test-avr: + # runs-on: ubuntu-latest + # steps: + # - uses: actions/checkout@v4 + # - name: Install toolchain + # uses: dtolnay/rust-toolchain@nightly + # with: + # components: rust-src + # - name: Build top-level only + # run: cargo build -Z build-std=core --target=avr-unknown-gnu-atmega328 --no-default-features + test-ios: runs-on: macos-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Install toolchain - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - profile: minimal - toolchain: nightly target: aarch64-apple-ios - override: true - name: Build top-level only run: cargo build --target=aarch64-apple-ios diff --git a/CHANGELOG.md b/CHANGELOG.md index c4815bbb83c..fded9d79aca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,152 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md). You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful. +## [0.9.0] - 2025-01-27 +### Security and unsafe +- Policy: "rand is not a crypto library" (#1514) +- Remove fork-protection from `ReseedingRng` and `ThreadRng`. Instead, it is recommended to call `ThreadRng::reseed` on fork. (#1379) +- Use `zerocopy` to replace some `unsafe` code (#1349, #1393, #1446, #1502) + +### Dependencies +- Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `rand_core` v0.9.0 (#1558) + +### Features +- Support `std` feature without `getrandom` or `rand_chacha` (#1354) +- Enable feature `small_rng` by default (#1455) +- Remove implicit feature `rand_chacha`; use `std_rng` instead. (#1473) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) +- Add feature `thread_rng` (#1547) + +### API changes: rand_core traits +- Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) +- Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) +- Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) +- Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) +- Add bounds `Clone` and `AsRef` to associated type `SeedableRng::Seed` (#1491) + +### API changes: Rng trait and top-level fns +- Rename fn `rand::thread_rng()` to `rand::rng()` and remove from the prelude (#1506) +- Remove fn `rand::random()` from the prelude (#1506) +- Add top-level fns `random_iter`, `random_range`, `random_bool`, `random_ratio`, `fill` (#1488) +- Re-introduce fn `Rng::gen_iter` as `random_iter` (#1305, #1500) +- Rename fn `Rng::gen` to `random` to avoid conflict with the new `gen` keyword in Rust 2024 (#1438) +- Rename fns `Rng::gen_range` to `random_range`, `gen_bool` to `random_bool`, `gen_ratio` to `random_ratio` (#1505) +- Annotate panicking methods with `#[track_caller]` (#1442, #1447) + +### API changes: RNGs +- Fix `::Seed` size to 256 bits (#1455) +- Remove first parameter (`rng`) of `ReseedingRng::new` (#1533) + +### API changes: Sequences +- Split trait `SliceRandom` into `IndexedRandom`, `IndexedMutRandom`, `SliceRandom` (#1382) +- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453, #1469) + +### API changes: Distributions: renames +- Rename module `rand::distributions` to `rand::distr` (#1470) +- Rename distribution `Standard` to `StandardUniform` (#1526) +- Move `distr::Slice` -> `distr::slice::Choose`, `distr::EmptySlice` -> `distr::slice::Empty` (#1548) +- Rename trait `distr::DistString` -> `distr::SampleString` (#1548) +- Rename `distr::DistIter` -> `distr::Iter`, `distr::DistMap` -> `distr::Map` (#1548) + +### API changes: Distributions +- Relax `Sized` bound on `Distribution for &D` (#1278) +- Remove impl of `Distribution>` for `StandardUniform` (#1526) +- Let distribution `StandardUniform` support all `NonZero*` types (#1332) +- Fns `{Uniform, UniformSampler}::{new, new_inclusive}` return a `Result` (instead of potentially panicking) (#1229) +- Distribution `Uniform` implements `TryFrom` instead of `From` for ranges (#1229) +- Add `UniformUsize` (#1487) +- Remove support for generating `isize` and `usize` values with `StandardUniform`, `Uniform` (except via `UniformUsize`) and `Fill` and usage as a `WeightedAliasIndex` weight (#1487) +- Add impl `DistString` for distributions `Slice` and `Uniform` (#1315) +- Add fn `Slice::num_choices` (#1402) +- Add fn `p()` for distribution `Bernoulli` to access probability (#1481) + +### API changes: Weighted distributions +- Add `pub` module `rand::distr::weighted`, moving `WeightedIndex` there (#1548) +- Add trait `weighted::Weight`, allowing `WeightedIndex` to trap overflow (#1353) +- Add fns `weight, weights, total_weight` to distribution `WeightedIndex` (#1420) +- Rename enum `WeightedError` to `weighted::Error`, revising variants (#1382) and mark as `#[non_exhaustive]` (#1480) + +### API changes: SIMD +- Switch to `std::simd`, expand SIMD & docs (#1239) + +### Reproducibility-breaking changes +- Make `ReseedingRng::reseed` discard remaining data from the last block generated (#1379) +- Change fn `SmallRng::seed_from_u64` implementation (#1203) +- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462) +- Fix portability of distribution `Slice` (#1469) +- Make `Uniform` for `usize` portable via `UniformUsize` (#1487) +- Fix `IndexdRandom::choose_multiple_weighted` for very small seeds and optimize for large input length / low memory (#1530) + +### Reproducibility-breaking optimisations +- Optimize fn `sample_floyd`, affecting output of `rand::seq::index::sample` and `rand::seq::SliceRandom::choose_multiple` (#1277) +- New, faster algorithms for `IteratorRandom::choose` and `choose_stable` (#1268) +- New, faster algorithms for `SliceRandom::shuffle` and `partial_shuffle` (#1272) +- Optimize distribution `Uniform`: use Canon's method (single sampling) / Lemire's method (distribution sampling) for faster sampling (breaks value stability; #1287) +- Optimize fn `sample_single_inclusive` for floats (+~20% perf) (#1289) + +### Other optimisations +- Improve `SmallRng` initialization performance (#1482) +- Optimise SIMD widening multiply (#1247) + +### Other +- Add `Cargo.lock.msrv` file (#1275) +- Reformat with `rustfmt` and enforce (#1448) +- Apply Clippy suggestions and enforce (#1448, #1474) +- Move all benchmarks to new `benches` crate (#1329, #1439) and migrate to Criterion (#1490) + +### Documentation +- Improve `ThreadRng` related docs (#1257) +- Docs: enable experimental `--generate-link-to-definition` feature (#1327) +- Better doc of crate features, use `doc_auto_cfg` (#1411, #1450) + +## [0.8.5] - 2021-08-20 +### Fixes +- Fix build on non-32/64-bit architectures (#1144) +- Fix "min_const_gen" feature for `no_std` (#1173) +- Check `libc::pthread_atfork` return value with panic on error (#1178) +- More robust reseeding in case `ReseedingRng` is used from a fork handler (#1178) +- Fix nightly: remove unused `slice_partition_at_index` feature (#1215) +- Fix nightly + `simd_support`: update `packed_simd` (#1216) + +### Rngs +- `StdRng`: Switch from HC128 to ChaCha12 on emscripten (#1142). + We now use ChaCha12 on all platforms. + +### Documentation +- Added docs about rand's use of const generics (#1150) +- Better random chars example (#1157) + + +## [0.8.4] - 2021-06-15 +### Additions +- Use const-generics to support arrays of all sizes (#1104) +- Implement `Clone` and `Copy` for `Alphanumeric` (#1126) +- Add `Distribution::map` to derive a distribution using a closure (#1129) +- Add `Slice` distribution (#1107) +- Add `DistString` trait with impls for `Standard` and `Alphanumeric` (#1133) + +### Other +- Reorder asserts in `Uniform` float distributions for easier debugging of non-finite arguments + (#1094, #1108) +- Add range overflow check in `Uniform` float distributions (#1108) +- Deprecate `rngs::adapter::ReadRng` (#1130) + +## [0.8.3] - 2021-01-25 +### Fixes +- Fix `no-std` + `alloc` build by gating `choose_multiple_weighted` on `std` (#1088) + +## [0.8.2] - 2021-01-12 +### Fixes +- Fix panic in `UniformInt::sample_single_inclusive` and `Rng::gen_range` when + providing a full integer range (eg `0..=MAX`) (#1087) + +## [0.8.1] - 2020-12-31 +### Other +- Enable all stable features in the playground (#1081) + ## [0.8.0] - 2020-12-18 ### Platform support - The minimum supported Rust version is now 1.36 (#1011) @@ -651,4 +797,4 @@ when updating from `rand 0.7.0` without also updating `rand_core`. ## [0.10-pre] - 2014-03-02 ### Added -- Seperate `rand` out of the standard library +- Separate `rand` out of the standard library diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv new file mode 100644 index 00000000000..66921820c1e --- /dev/null +++ b/Cargo.lock.msrv @@ -0,0 +1,728 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "average" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a237a6822e1c3c98e700b6db5b293eb341b7524dcb8d227941245702b7431dc" +dependencies = [ + "easy-cast", + "float-ord", + "num-traits", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bumpalo" +version = "3.15.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" + +[[package]] +name = "cc" +version = "1.0.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-targets", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "crossbeam-channel" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "darling" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +dependencies = [ + "darling_core", + "quote", + "syn", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", + "serde", +] + +[[package]] +name = "easy-cast" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10936778145f3bea71fd9bf61332cce28c28e96a380714f7ab34838b80733fd6" +dependencies = [ + "libm", +] + +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + +[[package]] +name = "fast_polynomial" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62eea6ee590b08a5f8b1139f4d6caee195b646d0c07e4b1808fbd5c4dea4829a" +dependencies = [ + "num-traits", +] + +[[package]] +name = "float-ord" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce81f49ae8a0482e4c55ea62ebbd7e5a686af544c00b9d090bba3ff9be97b3d" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lambert_w" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8852c2190439a46c77861aca230080cc9db4064be7f9de8ee81816d6c72c25" +dependencies = [ + "fast_polynomial", + "libm", +] + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.9.0-beta.0" +dependencies = [ + "bincode", + "log", + "rand_chacha", + "rand_core", + "rand_pcg", + "rayon", + "serde", + "zerocopy", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0-beta.0" +dependencies = [ + "ppv-lite86", + "rand_core", + "serde", + "serde_json", +] + +[[package]] +name = "rand_core" +version = "0.9.0-beta.0" +dependencies = [ + "getrandom", + "serde", + "zerocopy", +] + +[[package]] +name = "rand_distr" +version = "0.5.0-beta.0" +dependencies = [ + "average", + "num-traits", + "rand", + "rand_pcg", + "serde", + "serde_with", + "special", +] + +[[package]] +name = "rand_pcg" +version = "0.9.0-beta.0" +dependencies = [ + "bincode", + "rand_core", + "serde", +] + +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "ryu" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" + +[[package]] +name = "serde" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.197" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_with" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f02d8aa6e3c385bf084924f660ce2a3a6bd333ba55b35e8590b321f35d88513" +dependencies = [ + "base64", + "chrono", + "hex", + "indexmap", + "serde", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc7d5d3932fb12ce722ee5e64dd38c504efba37567f0c402f6ca728c3b8b070" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "special" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98d279079c3ddec4e7851337070c1055a18b8f606bba0b1aeb054bc059fc2e27" +dependencies = [ + "lambert_w", + "libm", +] + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "syn" +version = "2.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "time" +version = "0.3.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" + +[[package]] +name = "zerocopy" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a65238aacd5fb83fb03fcaf94823e71643e937000ec03c46e7da94234b10c870" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ca22c4ad176b37bd81a565f66635bde3d654fe6832730c3e52e1018ae1655ee" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index aee917c1226..956f12741fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand" -version = "0.8.0" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -13,75 +13,73 @@ Random number generators and other randomness functionality. keywords = ["random", "rng"] categories = ["algorithms", "no-std"] autobenches = true -edition = "2018" +edition = "2021" +rust-version = "1.63" include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] +[package.metadata.docs.rs] +# To build locally: +# RUSTDOCFLAGS="--cfg docsrs -Zunstable-options --generate-link-to-definition" cargo +nightly doc --all --all-features --no-deps --open +all-features = true +rustdoc-args = ["--generate-link-to-definition"] + +[package.metadata.playground] +features = ["small_rng", "serde"] + [features] # Meta-features: -default = ["std", "std_rng"] -nightly = [] # enables performance optimizations requiring nightly rust -serde1 = ["serde"] +default = ["std", "std_rng", "os_rng", "small_rng", "thread_rng"] +nightly = [] # some additions requiring nightly Rust +serde = ["dep:serde", "rand_core/serde"] # Option (enabled by default): without "std" rand uses libcore; this option # enables functionality expected to be available on a standard platform. -std = ["rand_core/std", "rand_chacha/std", "alloc", "getrandom", "libc"] +std = ["rand_core/std", "rand_chacha?/std", "alloc"] # Option: "alloc" enables support for Vec and Box when not using "std" -alloc = ["rand_core/alloc"] +alloc = [] -# Option: use getrandom package for seeding -getrandom = ["rand_core/getrandom"] +# Option: enable OsRng +os_rng = ["rand_core/os_rng"] -# Option (requires nightly): experimental SIMD support -simd_support = ["packed_simd"] +# Option (requires nightly Rust): experimental SIMD support +simd_support = ["zerocopy/simd-nightly"] # Option (enabled by default): enable StdRng -std_rng = ["rand_chacha", "rand_hc"] +std_rng = ["dep:rand_chacha"] # Option: enable SmallRng small_rng = [] +# Option: enable ThreadRng and rng() +thread_rng = ["std", "std_rng", "os_rng"] + +# Option: use unbiased sampling for algorithms supporting this option: Uniform distribution. +# By default, bias affecting no more than one in 2^48 samples is accepted. +# Note: enabling this option is expected to affect reproducibility of results. +unbiased = [] + +# Option: enable logging +log = ["dep:log"] + [workspace] members = [ "rand_core", "rand_distr", "rand_chacha", - "rand_hc", "rand_pcg", ] +exclude = ["benches", "distr_test"] [dependencies] -rand_core = { path = "rand_core", version = "0.6.0" } +rand_core = { path = "rand_core", version = "0.9.0", default-features = false } log = { version = "0.4.4", optional = true } serde = { version = "1.0.103", features = ["derive"], optional = true } - -[dependencies.packed_simd] -# NOTE: so far no version works reliably due to dependence on unstable features -package = "packed_simd_2" -version = "0.3.4" -optional = true -features = ["into_bits"] - -[target.'cfg(unix)'.dependencies] -# Used for fork protection (reseeding.rs) -libc = { version = "0.2.22", optional = true, default-features = false } - -# Emscripten does not support 128-bit integers, which are used by ChaCha code. -# We work around this by using a different RNG. -[target.'cfg(not(target_os = "emscripten"))'.dependencies] -rand_chacha = { path = "rand_chacha", version = "0.3.0", default-features = false, optional = true } -[target.'cfg(target_os = "emscripten")'.dependencies] -rand_hc = { path = "rand_hc", version = "0.3.0", optional = true } +rand_chacha = { path = "rand_chacha", version = "0.9.0", default-features = false, optional = true } +zerocopy = { version = "0.8.0", default-features = false, features = ["simd"] } [dev-dependencies] -rand_pcg = { path = "rand_pcg", version = "0.3.0" } -# Only for benches: -rand_hc = { path = "rand_hc", version = "0.3.0" } -# Only to test serde1 +rand_pcg = { path = "rand_pcg", version = "0.9.0" } +# Only to test serde bincode = "1.2.1" - -[package.metadata.docs.rs] -# To build locally: -# RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --open -all-features = true -rustdoc-args = ["--cfg", "doc_cfg"] +rayon = "1.7" diff --git a/LICENSE-APACHE b/LICENSE-APACHE index 17d74680f8c..494ad3bfdfe 100644 --- a/LICENSE-APACHE +++ b/LICENSE-APACHE @@ -174,28 +174,3 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/README.md b/README.md index aaf6df1d584..740807a9669 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,50 @@ # Rand -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Crate](https://img.shields.io/crates/v/rand.svg)](https://crates.io/crates/rand) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand) [![API](https://docs.rs/rand/badge.svg)](https://docs.rs/rand) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) - -A Rust library for random number generation, featuring: - -- Easy random value generation and usage via the [`Rng`](https://docs.rs/rand/*/rand/trait.Rng.html), - [`SliceRandom`](https://docs.rs/rand/*/rand/seq/trait.SliceRandom.html) and - [`IteratorRandom`](https://docs.rs/rand/*/rand/seq/trait.IteratorRandom.html) traits -- Secure seeding via the [`getrandom` crate](https://crates.io/crates/getrandom) - and fast, convenient generation via [`thread_rng`](https://docs.rs/rand/*/rand/fn.thread_rng.html) -- A modular design built over [`rand_core`](https://crates.io/crates/rand_core) - ([see the book](https://rust-random.github.io/book/crates.html)) -- Fast implementations of the best-in-class [cryptographic](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and - [non-cryptographic](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators -- A flexible [`distributions`](https://docs.rs/rand/*/rand/distributions/index.html) module -- Samplers for a large number of random number distributions via our own + +Rand is a set of crates supporting (pseudo-)random generators: + +- Built over a standard RNG trait: [`rand_core::RngCore`](https://docs.rs/rand_core/latest/rand_core/trait.RngCore.html) +- With fast implementations of both [strong](https://rust-random.github.io/book/guide-rngs.html#cryptographically-secure-pseudo-random-number-generators-csprngs) and + [small](https://rust-random.github.io/book/guide-rngs.html#basic-pseudo-random-number-generators-prngs) generators: [`rand::rngs`](https://docs.rs/rand/latest/rand/rngs/index.html), and more RNGs: [`rand_chacha`](https://docs.rs/rand_chacha), [`rand_xoshiro`](https://docs.rs/rand_xoshiro/), [`rand_pcg`](https://docs.rs/rand_pcg/), [rngs repo](https://github.com/rust-random/rngs/) +- [`rand::rng`](https://docs.rs/rand/latest/rand/fn.rng.html) is an asymptotically-fast, automatically-seeded and reasonably strong generator available on all `std` targets +- Direct support for seeding generators from the [getrandom] crate + +With broad support for random value generation and random processes: + +- [`StandardUniform`](https://docs.rs/rand/latest/rand/distributions/struct.StandardUniform.html) random value sampling, + [`Uniform`](https://docs.rs/rand/latest/rand/distributions/struct.Uniform.html)-ranged value sampling + and [more](https://docs.rs/rand/latest/rand/distr/index.html) +- Samplers for a large number of non-uniform random number distributions via our own [`rand_distr`](https://docs.rs/rand_distr) and via the [`statrs`](https://docs.rs/statrs/0.13.0/statrs/) +- Random processes (mostly choose and shuffle) via [`rand::seq`](https://docs.rs/rand/latest/rand/seq/index.html) traits + +All with: + - [Portably reproducible output](https://rust-random.github.io/book/portability.html) - `#[no_std]` compatibility (partial) -- *Many* performance optimisations +- *Many* performance optimisations thanks to contributions from the wide + user-base -It's also worth pointing out what `rand` *is not*: +Rand **is not**: -- Small. Most low-level crates are small, but the higher-level `rand` and - `rand_distr` each contain a lot of functionality. +- Small (LoC). Most low-level crates are small, but the higher-level `rand` + and `rand_distr` each contain a lot of functionality. - Simple (implementation). We have a strong focus on correctness, speed and flexibility, but not simplicity. If you prefer a small-and-simple library, there are alternatives including [fastrand](https://crates.io/crates/fastrand) and [oorandom](https://crates.io/crates/oorandom). -- Slow. We take performance seriously, with considerations also for set-up - time of new distributions, commonly-used parameters, and parameters of the - current sampler. +- A cryptography library. Rand provides functionality for generating + unpredictable random data (potentially applicable depending on requirements) + but does not provide high-level cryptography functionality. + +Rand is a community project and cannot provide legally-binding guarantees of +security. Documentation: @@ -45,67 +53,14 @@ Documentation: - [API reference (docs.rs)](https://docs.rs/rand) -## Usage - -Add this to your `Cargo.toml`: - -```toml -[dependencies] -rand = "0.8.0" -``` - -To get started using Rand, see [The Book](https://rust-random.github.io/book). - - ## Versions Rand is *mature* (suitable for general usage, with infrequent breaking releases -which minimise breakage) but not yet at 1.0. We maintain compatibility with -pinned versions of the Rust compiler (see below). - -Current Rand versions are: - -- Version 0.7 was released in June 2019, moving most non-uniform distributions - to an external crate, moving `from_entropy` to `SeedableRng`, and many small - changes and fixes. -- Version 0.8 was released in December 2020 with many small changes. +which minimise breakage) but not yet at 1.0. Current versions are: -A detailed [changelog](CHANGELOG.md) is available for releases. +- Version 0.9 was released in January 2025. -When upgrading to the next minor series (especially 0.4 → 0.5), we recommend -reading the [Upgrade Guide](https://rust-random.github.io/book/update.html). - -Rand has not yet reached 1.0 implying some breaking changes may arrive in the -future ([SemVer](https://semver.org/) allows each 0.x.0 release to include -breaking changes), but is considered *mature*: breaking changes are minimised -and breaking releases are infrequent. - -Rand libs have inter-dependencies and make use of the -[semver trick](https://github.com/dtolnay/semver-trick/) in order to make traits -compatible across crate versions. (This is especially important for `RngCore` -and `SeedableRng`.) A few crate releases are thus compatibility shims, -depending on the *next* lib version (e.g. `rand_core` versions `0.2.2` and -`0.3.1`). This means, for example, that `rand_core_0_4_0::SeedableRng` and -`rand_core_0_3_0::SeedableRng` are distinct, incompatible traits, which can -cause build errors. Usually, running `cargo update` is enough to fix any issues. - -### Yanked versions - -Some versions of Rand crates have been yanked ("unreleased"). Where this occurs, -the crate's CHANGELOG *should* be updated with a rationale, and a search on the -issue tracker with the keyword `yank` *should* uncover the motivation. - -### Rust version requirements - -Since version 0.8, Rand requires **Rustc version 1.36 or greater**. -Rand 0.7 requires Rustc 1.32 or greater while versions 0.5 require Rustc 1.22 or -greater, and 0.4 and 0.3 (since approx. June 2017) require Rustc version 1.15 or -greater. Subsets of the Rand code may work with older Rust versions, but this is -not supported. - -Continuous Integration (CI) will always test the minimum supported Rustc version -(the MSRV). The current policy is that this can be updated in any -Rand release if required, but the change must be noted in the changelog. +See the [CHANGELOG](CHANGELOG.md) or [Upgrade Guide](https://rust-random.github.io/book/update.html) for more details. ## Crate Features @@ -113,19 +68,17 @@ Rand is built with these features enabled by default: - `std` enables functionality dependent on the `std` lib - `alloc` (implied by `std`) enables functionality requiring an allocator -- `getrandom` (implied by `std`) is an optional dependency providing the code - behind `rngs::OsRng` -- `std_rng` enables inclusion of `StdRng`, `thread_rng` and `random` - (the latter two *also* require that `std` be enabled) +- `os_rng` (implied by `std`) enables `rngs::OsRng`, using the [getrandom] crate +- `std_rng` enables inclusion of `StdRng`, `ThreadRng` Optionally, the following dependencies can be enabled: -- `log` enables logging via the `log` crate` crate +- `log` enables logging via [log](https://crates.io/crates/log) Additionally, these features configure Rand: - `small_rng` enables inclusion of the `SmallRng` PRNG -- `nightly` enables some optimizations requiring nightly Rust +- `nightly` includes some additions requiring nightly Rust - `simd_support` (experimental) enables sampling of SIMD values (uniformly random SIMD integers and floats), requiring nightly Rust @@ -134,10 +87,26 @@ compiler versions will be compatible. This is especially true of Rand's experimental `simd_support` feature. Rand supports limited functionality in `no_std` mode (enabled via -`default-features = false`). In this case, `OsRng` and `from_entropy` are -unavailable (unless `getrandom` is enabled), large parts of `seq` are -unavailable (unless `alloc` is enabled), and `thread_rng` and `random` are -unavailable. +`default-features = false`). In this case, `OsRng` and `from_os_rng` are +unavailable (unless `os_rng` is enabled), large parts of `seq` are +unavailable (unless `alloc` is enabled), and `ThreadRng` is unavailable. + +## Portability and platform support + +Many (but not all) algorithms are intended to have reproducible output. Read more in the book: [Portability](https://rust-random.github.io/book/portability.html). + +The Rand library supports a variety of CPU architectures. Platform integration is outsourced to [getrandom]. + +### WASM support + +Seeding entropy from OS on WASM target `wasm32-unknown-unknown` is not +*automatically* supported by `rand` or `getrandom`. If you are fine with +seeding the generator manually, you can disable the `os_rng` feature +and use the methods on the `SeedableRng` trait. To enable seeding from OS, +either use a different target such as `wasm32-wasi` or add a direct +dependency on [getrandom] with the `js` feature (if the target supports +JavaScript). See +[getrandom#WebAssembly support](https://docs.rs/getrandom/latest/getrandom/#webassembly-support). # License @@ -146,3 +115,5 @@ Apache License (Version 2.0). See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and [COPYRIGHT](COPYRIGHT) for details. + +[getrandom]: https://crates.io/crates/getrandom diff --git a/SECURITY.md b/SECURITY.md index 0da2bf0fed6..26cf7c12fc5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,34 +1,46 @@ # Security Policy -## No guarantees +## Disclaimer -Support is provided on a best-effort bases only. -No binding guarantees can be provided. +Rand is a community project and cannot provide legally-binding guarantees of +security. ## Security premises -Rand provides the trait `rand_core::CryptoRng` aka `rand::CryptoRng` as a marker -trait. Generators implementating `RngCore` *and* `CryptoRng`, and given the -additional constraints that: +### Marker traits + +Rand provides the marker traits `CryptoRng`, `TryCryptoRng` and +`CryptoBlockRng`. Generators implementing one of these traits and used in a way +which meets the following additional constraints: - Instances of seedable RNGs (those implementing `SeedableRng`) are constructed with cryptographically secure seed values -- The state (memory) of the RNG and its seed value are not be exposed +- The state (memory) of the RNG and its seed value are not exposed are expected to provide the following: -- An attacker can gain no advantage over chance (50% for each bit) in - predicting the RNG output, even with full knowledge of all prior outputs. +- An attacker cannot predict the output with more accuracy than what would be + expected through pure chance since each possible output value of any method + under the above traits which generates output bytes (including + `RngCore::next_u32`, `RngCore::next_u64`, `RngCore::fill_bytes`, + `TryRngCore::try_next_u32`, `TryRngCore::try_next_u64`, + `TryRngCore::try_fill_bytes` and `BlockRngCore::generate`) should be equally + likely +- Knowledge of prior outputs from the generator does not aid an attacker in + predicting future outputs + +### Specific generators + +`OsRng` is a stateless "generator" implemented via [getrandom]. As such, it has +no possible state to leak and cannot be improperly seeded. + +`ThreadRng` will periodically reseed itself, thus placing an upper bound on the +number of bits of output from an instance before any advantage an attacker may +have gained through state-compromising side-channel attacks is lost. -For some RNGs, notably `OsRng`, `ThreadRng` and those wrapped by `ReseedingRng`, -we provide limited mitigations against side-channel attacks: +[getrandom]: https://crates.io/crates/getrandom -- After a process fork on Unix, there is an upper-bound on the number of bits - output by the RNG before the processes diverge, after which outputs from - each process's RNG are uncorrelated -- After the state (memory) of an RNG is leaked, there is an upper-bound on the - number of bits of output by the RNG before prediction of output by an - observer again becomes computationally-infeasible +### Distributions Additionally, derivations from such an RNG (including the `Rng` trait, implementations of the `Distribution` trait, and `seq` algorithms) should not @@ -49,19 +61,18 @@ exceptions for theoretical issues without a known exploit: | `rand` | 0.4 | Jitter, ISAAC | | `rand_core` | 0.2 - 0.6 | | | `rand_chacha` | 0.1 - 0.3 | | -| `rand_hc` | 0.1 - 0.3 | | Explanation of exceptions: - Jitter: `JitterRng` is used as an entropy source when the primary source fails; this source may not be secure against side-channel attacks, see #699. - ISAAC: the [ISAAC](https://burtleburtle.net/bob/rand/isaacafa.html) RNG used - to implement `thread_rng` is difficult to analyse and thus cannot provide + to implement `ThreadRng` is difficult to analyse and thus cannot provide strong assertions of security. ## Known issues -In `rand` version 0.3 (0.3.18 and later), if `OsRng` fails, `thread_rng` is +In `rand` version 0.3 (0.3.18 and later), if `OsRng` fails, `ThreadRng` is seeded from the system time in an insecure manner. ## Reporting a Vulnerability diff --git a/benches/Cargo.toml b/benches/Cargo.toml new file mode 100644 index 00000000000..a143bff3c02 --- /dev/null +++ b/benches/Cargo.toml @@ -0,0 +1,55 @@ +[package] +name = "benches" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] + +[dev-dependencies] +rand = { path = "..", features = ["small_rng", "nightly"] } +rand_pcg = { path = "../rand_pcg" } +rand_chacha = { path = "../rand_chacha" } +rand_distr = { path = "../rand_distr" } +criterion = "0.5" +criterion-cycles-per-byte = "0.6" + +[[bench]] +name = "array" +harness = false + +[[bench]] +name = "bool" +harness = false + +[[bench]] +name = "distr" +harness = false + +[[bench]] +name = "generators" +harness = false + +[[bench]] +name = "seq_choose" +harness = false + +[[bench]] +name = "shuffle" +harness = false + +[[bench]] +name = "standard" +harness = false + +[[bench]] +name = "uniform" +harness = false + +[[bench]] +name = "uniform_float" +harness = false + +[[bench]] +name = "weighted" +harness = false diff --git a/benches/benches/array.rs b/benches/benches/array.rs new file mode 100644 index 00000000000..063516337bf --- /dev/null +++ b/benches/benches/array.rs @@ -0,0 +1,94 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating/filling arrays and iterators of output + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::distr::StandardUniform; +use rand::prelude::*; +use rand_pcg::Pcg64Mcg; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("random_1kb"); + g.throughput(criterion::Throughput::Bytes(1024)); + + g.bench_function("u16_iter_repeat", |b| { + use core::iter; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = iter::repeat(()).map(|()| rng.random()).take(512).collect(); + v + }); + }); + + g.bench_function("u16_sample_iter", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = StandardUniform.sample_iter(&mut rng).take(512).collect(); + v + }); + }); + + g.bench_function("u16_gen_array", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: [u16; 512] = rng.random(); + v + }); + }); + + g.bench_function("u16_fill", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let mut buf = [0u16; 512]; + b.iter(|| { + rng.fill(&mut buf[..]); + buf + }); + }); + + g.bench_function("u64_iter_repeat", |b| { + use core::iter; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = iter::repeat(()).map(|()| rng.random()).take(128).collect(); + v + }); + }); + + g.bench_function("u64_sample_iter", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: Vec = StandardUniform.sample_iter(&mut rng).take(128).collect(); + v + }); + }); + + g.bench_function("u64_gen_array", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| { + let v: [u64; 128] = rng.random(); + v + }); + }); + + g.bench_function("u64_fill", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let mut buf = [0u64; 128]; + b.iter(|| { + rng.fill(&mut buf[..]); + buf + }); + }); +} diff --git a/benches/benches/bool.rs b/benches/benches/bool.rs new file mode 100644 index 00000000000..8ff8c676024 --- /dev/null +++ b/benches/benches/bool.rs @@ -0,0 +1,69 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating/filling arrays and iterators of output + +use criterion::{criterion_group, criterion_main, Criterion}; +use rand::distr::Bernoulli; +use rand::prelude::*; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("random_bool"); + g.sample_size(1000); + g.warm_up_time(core::time::Duration::from_millis(500)); + g.measurement_time(core::time::Duration::from_millis(1000)); + + g.bench_function("standard", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.sample::(rand::distr::StandardUniform)) + }); + + g.bench_function("const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.random_bool(0.18)) + }); + + g.bench_function("var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let p = rng.random(); + b.iter(|| rng.random_bool(p)) + }); + + g.bench_function("ratio_const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| rng.random_ratio(2, 3)) + }); + + g.bench_function("ratio_var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let d = rng.random_range(1..=100); + let n = rng.random_range(0..=d); + b.iter(|| rng.random_ratio(n, d)); + }); + + g.bench_function("bernoulli_const", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let d = Bernoulli::new(0.18).unwrap(); + b.iter(|| rng.sample(d)) + }); + + g.bench_function("bernoulli_var", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let p = rng.random(); + let d = Bernoulli::new(p).unwrap(); + b.iter(|| rng.sample(d)) + }); +} diff --git a/benches/benches/distr.rs b/benches/benches/distr.rs new file mode 100644 index 00000000000..3a76211972d --- /dev/null +++ b/benches/benches/distr.rs @@ -0,0 +1,194 @@ +// Copyright 2018-2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use criterion_cycles_per_byte::CyclesPerByte; + +use rand::prelude::*; +use rand_distr::weighted::*; +use rand_distr::*; + +// At this time, distributions are optimised for 64-bit platforms. +use rand_pcg::Pcg64Mcg; + +const ITER_ELTS: u64 = 100; + +macro_rules! distr_int { + ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { + $group.bench_function($fnn, |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = $distr; + + c.iter(|| distr.sample(&mut rng)); + }); + }; +} + +macro_rules! distr_float { + ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { + $group.bench_function($fnn, |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = $distr; + + c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng)); + }); + }; +} + +macro_rules! distr_arr { + ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { + $group.bench_function($fnn, |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = $distr; + + c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng)); + }); + }; +} + +macro_rules! sample_binomial { + ($group:ident, $name:expr, $n:expr, $p:expr) => { + distr_int!($group, $name, u64, Binomial::new($n, $p).unwrap()) + }; +} + +fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("exp"); + distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap()); + distr_float!(g, "exp1_specialized", f64, Exp1); + distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("normal"); + distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap()); + distr_float!(g, "standardnormal_specialized", f64, StandardNormal); + distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap()); + distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap()); + g.throughput(Throughput::Elements(ITER_ELTS)); + g.bench_function("iter", |c| { + use core::f64::consts::{E, PI}; + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = Normal::new(-E, PI).unwrap(); + + c.iter(|| { + distr + .sample_iter(&mut rng) + .take(ITER_ELTS as usize) + .fold(0.0, |a, r| a + r) + }); + }); + g.finish(); + + let mut g = c.benchmark_group("skew_normal"); + distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap()); + distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap()); + distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("gamma"); + distr_float!(g, "large_shape", f64, Gamma::new(10., 1.0).unwrap()); + distr_float!(g, "small_shape", f64, Gamma::new(0.1, 1.0).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("beta"); + distr_float!(g, "small_param", f64, Beta::new(0.1, 0.1).unwrap()); + distr_float!(g, "large_param_similar", f64, Beta::new(101., 95.).unwrap()); + distr_float!(g, "large_param_different", f64, Beta::new(10., 1000.).unwrap()); + distr_float!(g, "mixed_param", f64, Beta::new(0.5, 100.).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("cauchy"); + distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("triangular"); + distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("geometric"); + distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap()); + distr_int!(g, "standard_geometric", u64, StandardGeometric); + g.finish(); + + let mut g = c.benchmark_group("weighted"); + distr_int!(g, "i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!(g, "u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!(g, "f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0 / 3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); + distr_int!(g, "large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); + distr_int!(g, "alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!(g, "alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); + distr_int!( + g, + "alias_method_f64", + usize, + WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0 / 3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap() + ); + distr_int!( + g, + "alias_method_large_set", + usize, + WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap() + ); + g.finish(); + + let mut g = c.benchmark_group("binomial"); + sample_binomial!(g, "small", 1_000_000, 1e-30); + sample_binomial!(g, "1", 1, 0.9); + sample_binomial!(g, "10", 10, 0.9); + sample_binomial!(g, "100", 100, 0.99); + sample_binomial!(g, "1000", 1000, 0.01); + sample_binomial!(g, "1e12", 1_000_000_000_000, 0.2); + g.finish(); + + let mut g = c.benchmark_group("poisson"); + for lambda in [1f64, 4.0, 10.0, 100.0].into_iter() { + let name = format!("{lambda}"); + distr_float!(g, name, f64, Poisson::new(lambda).unwrap()); + } + g.throughput(Throughput::Elements(ITER_ELTS)); + g.bench_function("variable", |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let ldistr = Uniform::new(0.1, 10.0).unwrap(); + + c.iter(|| { + let l = rng.sample(ldistr); + let distr = Poisson::new(l * l).unwrap(); + Distribution::::sample_iter(&distr, &mut rng) + .take(ITER_ELTS as usize) + .fold(0.0, |a, r| a + r) + }) + }); + g.finish(); + + let mut g = c.benchmark_group("zipf"); + distr_float!(g, "zipf", f64, Zipf::new(10.0, 1.5).unwrap()); + distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap()); + g.finish(); + + let mut g = c.benchmark_group("bernoulli"); + g.bench_function("bernoulli", |c| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + let distr = Bernoulli::new(0.18).unwrap(); + c.iter(|| distr.sample(&mut rng)) + }); + g.finish(); + + let mut g = c.benchmark_group("unit"); + distr_arr!(g, "circle", [f64; 2], UnitCircle); + distr_arr!(g, "sphere", [f64; 3], UnitSphere); + g.finish(); +} + +criterion_group!( + name = benches; + config = Criterion::default().with_measurement(CyclesPerByte) + .warm_up_time(core::time::Duration::from_secs(1)) + .measurement_time(core::time::Duration::from_secs(2)); + targets = bench +); +criterion_main!(benches); diff --git a/benches/benches/generators.rs b/benches/benches/generators.rs new file mode 100644 index 00000000000..64325ceb9ee --- /dev/null +++ b/benches/benches/generators.rs @@ -0,0 +1,221 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::time::Duration; +use criterion::measurement::WallTime; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use rand::prelude::*; +use rand::rngs::ReseedingRng; +use rand::rngs::{mock::StepRng, OsRng}; +use rand_chacha::rand_core::UnwrapErr; +use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; +use rand_pcg::{Pcg32, Pcg64, Pcg64Dxsm, Pcg64Mcg}; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = random_bytes, random_u32, random_u64, init_gen, init_from_u64, init_from_seed, reseeding_bytes +); +criterion_main!(benches); + +pub fn random_bytes(c: &mut Criterion) { + let mut g = c.benchmark_group("random_bytes"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(1024)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + let mut buf = [0u8; 1024]; + b.iter(|| { + rng.fill_bytes(&mut buf); + black_box(buf); + }); + }); + } + + bench(&mut g, "step", StepRng::new(0, 1)); + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn random_u32(c: &mut Criterion) { + let mut g = c.benchmark_group("random_u32"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(4)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + b.iter(|| rng.random::()); + }); + } + + bench(&mut g, "step", StepRng::new(0, 1)); + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn random_u64(c: &mut Criterion) { + let mut g = c.benchmark_group("random_u64"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + g.throughput(criterion::Throughput::Bytes(8)); + + fn bench(g: &mut BenchmarkGroup, name: &str, mut rng: impl Rng) { + g.bench_function(name, |b| { + b.iter(|| rng.random::()); + }); + } + + bench(&mut g, "step", StepRng::new(0, 1)); + bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); + bench(&mut g, "pcg64dxsm", Pcg64Dxsm::from_rng(&mut rand::rng())); + bench(&mut g, "chacha8", ChaCha8Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha12", ChaCha12Rng::from_rng(&mut rand::rng())); + bench(&mut g, "chacha20", ChaCha20Rng::from_rng(&mut rand::rng())); + bench(&mut g, "std", StdRng::from_rng(&mut rand::rng())); + bench(&mut g, "small", SmallRng::from_rng(&mut rand::rng())); + bench(&mut g, "os", UnwrapErr(OsRng)); + bench(&mut g, "thread", rand::rng()); + + g.finish() +} + +pub fn init_gen(c: &mut Criterion) { + let mut g = c.benchmark_group("init_gen"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + b.iter(|| R::from_rng(&mut rng)); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn init_from_u64(c: &mut Criterion) { + let mut g = c.benchmark_group("init_from_u64"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let seed = rng.random(); + b.iter(|| R::seed_from_u64(black_box(seed))); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn init_from_seed(c: &mut Criterion) { + let mut g = c.benchmark_group("init_from_seed"); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + fn bench(g: &mut BenchmarkGroup, name: &str) + where + rand::distr::StandardUniform: Distribution<::Seed>, + { + g.bench_function(name, |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let seed = rng.random(); + b.iter(|| R::from_seed(black_box(seed.clone()))); + }); + } + + bench::(&mut g, "pcg32"); + bench::(&mut g, "pcg64"); + bench::(&mut g, "pcg64mcg"); + bench::(&mut g, "pcg64dxsm"); + bench::(&mut g, "chacha8"); + bench::(&mut g, "chacha12"); + bench::(&mut g, "chacha20"); + bench::(&mut g, "std"); + bench::(&mut g, "small"); + + g.finish() +} + +pub fn reseeding_bytes(c: &mut Criterion) { + let mut g = c.benchmark_group("reseeding_bytes"); + g.warm_up_time(Duration::from_millis(500)); + g.throughput(criterion::Throughput::Bytes(1024 * 1024)); + + fn bench(g: &mut BenchmarkGroup, thresh: u64) { + let name = format!("chacha20_{}k", thresh); + g.bench_function(name.as_str(), |b| { + let mut rng = ReseedingRng::::new(thresh * 1024, OsRng).unwrap(); + let mut buf = [0u8; 1024 * 1024]; + b.iter(|| { + rng.fill_bytes(&mut buf); + black_box(&buf); + }); + }); + } + + bench(&mut g, 4); + bench(&mut g, 16); + bench(&mut g, 32); + bench(&mut g, 64); + bench(&mut g, 256); + bench(&mut g, 1024); + + g.finish() +} diff --git a/benches/benches/seq_choose.rs b/benches/benches/seq_choose.rs new file mode 100644 index 00000000000..56223dd0a62 --- /dev/null +++ b/benches/benches/seq_choose.rs @@ -0,0 +1,180 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("seq_slice_choose_1_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&mut buf); + + b.iter(|| x.choose(&mut rng).unwrap()); + }); + + let lens = [(1, 1000), (950, 1000), (10, 100), (90, 100)]; + for (amount, len) in lens { + let name = format!("seq_slice_choose_multiple_{}_of_{}", amount, len); + c.bench_function(name.as_str(), |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 1000]; + rng.fill(&mut buf); + let x = black_box(&buf[..len]); + + let mut results_buf = [0i32; 950]; + let y = black_box(&mut results_buf[..amount]); + let amount = black_box(amount); + + b.iter(|| { + // Collect full result to prevent unwanted shortcuts getting + // first element (in case sample_indices returns an iterator). + for (slot, sample) in y.iter_mut().zip(x.choose_multiple(&mut rng, amount)) { + *slot = *sample; + } + y[amount - 1] + }) + }); + } + + let lens = [(1, 1000), (950, 1000), (10, 100), (90, 100)]; + for (amount, len) in lens { + let name = format!("seq_slice_choose_multiple_weighted_{}_of_{}", amount, len); + c.bench_function(name.as_str(), |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 1000]; + rng.fill(&mut buf); + let x = black_box(&buf[..len]); + + let mut results_buf = [0i32; 950]; + let y = black_box(&mut results_buf[..amount]); + let amount = black_box(amount); + + b.iter(|| { + // Collect full result to prevent unwanted shortcuts getting + // first element (in case sample_indices returns an iterator). + let samples_iter = x.choose_multiple_weighted(&mut rng, amount, |_| 1.0).unwrap(); + for (slot, sample) in y.iter_mut().zip(samples_iter) { + *slot = *sample; + } + y[amount - 1] + }) + }); + } + + c.bench_function("seq_iter_choose_multiple_10_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&buf); + b.iter(|| x.iter().cloned().choose_multiple(&mut rng, 10)) + }); + + c.bench_function("seq_iter_choose_multiple_fill_10_of_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&buf); + let mut buf = [0; 10]; + b.iter(|| x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf)) + }); + + bench_rng::(c, "ChaCha20"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000].map(black_box) { + let name = format!("choose_size-hinted_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_size_hinted(length, &mut rng)) + }); + + let name = format!("choose_stable_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_stable(length, &mut rng)) + }); + + let name = format!("choose_unhinted_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_unhinted(length, &mut rng)) + }); + + let name = format!("choose_windowed_from_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + b.iter(|| choose_windowed(length, 7, &mut rng)) + }); + } +} + +fn choose_size_hinted(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose(rng) +} + +fn choose_stable(max: usize, rng: &mut R) -> Option { + let iterator = 0..max; + iterator.choose_stable(rng) +} + +fn choose_unhinted(max: usize, rng: &mut R) -> Option { + let iterator = UnhintedIterator { iter: (0..max) }; + iterator.choose(rng) +} + +fn choose_windowed(max: usize, window_size: usize, rng: &mut R) -> Option { + let iterator = WindowHintedIterator { + iter: (0..max), + window_size, + }; + iterator.choose(rng) +} + +#[derive(Clone)] +struct UnhintedIterator { + iter: I, +} +impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +#[derive(Clone)] +struct WindowHintedIterator { + iter: I, + window_size: usize, +} +impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + (core::cmp::min(self.iter.len(), self.window_size), None) + } +} diff --git a/benches/benches/shuffle.rs b/benches/benches/shuffle.rs new file mode 100644 index 00000000000..c2f37daaeab --- /dev/null +++ b/benches/benches/shuffle.rs @@ -0,0 +1,61 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::*; +use rand::SeedableRng; +use rand_pcg::Pcg32; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("seq_shuffle_100", |b| { + let mut rng = Pcg32::from_rng(&mut rand::rng()); + let mut buf = [0i32; 100]; + rng.fill(&mut buf); + let x = black_box(&mut buf); + b.iter(|| { + x.shuffle(&mut rng); + x[0] + }) + }); + + bench_rng::(c, "ChaCha12"); + bench_rng::(c, "Pcg32"); + bench_rng::(c, "Pcg64"); +} + +fn bench_rng(c: &mut Criterion, rng_name: &'static str) { + for length in [1, 2, 3, 10, 100, 1000, 10000].map(black_box) { + c.bench_function(format!("shuffle_{length}_{rng_name}").as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.shuffle(&mut rng); + vec[0] + }) + }); + + if length >= 10 { + let name = format!("partial_shuffle_{length}_{rng_name}"); + c.bench_function(name.as_str(), |b| { + let mut rng = Rng::seed_from_u64(123); + let mut vec: Vec = (0..length).collect(); + b.iter(|| { + vec.partial_shuffle(&mut rng, length / 2); + vec[0] + }) + }); + } + } +} diff --git a/benches/benches/standard.rs b/benches/benches/standard.rs new file mode 100644 index 00000000000..ac38f0225f8 --- /dev/null +++ b/benches/benches/standard.rs @@ -0,0 +1,64 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::time::Duration; +use criterion::measurement::WallTime; +use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; +use rand::distr::{Alphanumeric, StandardUniform}; +use rand::prelude::*; +use rand_distr::{Open01, OpenClosed01}; +use rand_pcg::Pcg64Mcg; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +fn bench_ty(g: &mut BenchmarkGroup, name: &str) +where + D: Distribution + Default, +{ + g.throughput(criterion::Throughput::Bytes(size_of::() as u64)); + g.bench_function(name, |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + + b.iter(|| rng.sample::(D::default())); + }); +} + +pub fn bench(c: &mut Criterion) { + let mut g = c.benchmark_group("StandardUniform"); + g.sample_size(1000); + g.warm_up_time(Duration::from_millis(500)); + g.measurement_time(Duration::from_millis(1000)); + + macro_rules! do_ty { + ($t:ty) => { + bench_ty::<$t, StandardUniform>(&mut g, stringify!($t)); + }; + ($t:ty, $($tt:ty),*) => { + do_ty!($t); + do_ty!($($tt),*); + }; + } + + do_ty!(i8, i16, i32, i64, i128); + do_ty!(f32, f64); + do_ty!(char); + + bench_ty::(&mut g, "Alphanumeric"); + + bench_ty::(&mut g, "Open01/f32"); + bench_ty::(&mut g, "Open01/f64"); + bench_ty::(&mut g, "OpenClosed01/f32"); + bench_ty::(&mut g, "OpenClosed01/f64"); + + g.finish(); +} diff --git a/benches/benches/uniform.rs b/benches/benches/uniform.rs new file mode 100644 index 00000000000..ab1b0ed4149 --- /dev/null +++ b/benches/benches/uniform.rs @@ -0,0 +1,78 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Implement benchmarks for uniform distributions over integer types + +use core::time::Duration; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::distr::uniform::{SampleRange, Uniform}; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_pcg::{Pcg32, Pcg64}; + +const WARM_UP_TIME: Duration = Duration::from_millis(1000); +const MEASUREMENT_TIME: Duration = Duration::from_secs(3); +const SAMPLE_SIZE: usize = 100_000; +const N_RESAMPLES: usize = 10_000; + +macro_rules! sample { + ($R:ty, $T:ty, $U:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($R), "single"), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let x = rng.random::<$U>(); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let range = (x >> bits) * (x & mask); + let low = <$T>::MIN; + let high = low.wrapping_add(range as $T); + + b.iter(|| (low..=high).sample_single(&mut rng)); + }); + + $g.bench_function(BenchmarkId::new(stringify!($R), "distr"), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let x = rng.random::<$U>(); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let range = (x >> bits) * (x & mask); + let low = <$T>::MIN; + let high = low.wrapping_add(range as $T); + let dist = Uniform::<$T>::new_inclusive(<$T>::MIN, high).unwrap(); + + b.iter(|| dist.sample(&mut rng)); + }); + }; + + ($c:expr, $T:ty, $U:ty) => {{ + let mut g = $c.benchmark_group(concat!("sample", stringify!($T))); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + sample!(SmallRng, $T, $U, g); + sample!(ChaCha8Rng, $T, $U, g); + sample!(Pcg32, $T, $U, g); + sample!(Pcg64, $T, $U, g); + g.finish(); + }}; +} + +fn sample(c: &mut Criterion) { + sample!(c, i8, u8); + sample!(c, i16, u16); + sample!(c, i32, u32); + sample!(c, i64, u64); + sample!(c, i128, u128); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = sample +} +criterion_main!(benches); diff --git a/benches/benches/uniform_float.rs b/benches/benches/uniform_float.rs new file mode 100644 index 00000000000..03a434fc228 --- /dev/null +++ b/benches/benches/uniform_float.rs @@ -0,0 +1,103 @@ +// Copyright 2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Implement benchmarks for uniform distributions over FP types +//! +//! Sampling methods compared: +//! +//! - sample: current method: (x12 - 1.0) * (b - a) + a + +use core::time::Duration; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rand::distr::uniform::{SampleUniform, Uniform, UniformSampler}; +use rand::prelude::*; +use rand_chacha::ChaCha8Rng; +use rand_pcg::{Pcg32, Pcg64}; + +const WARM_UP_TIME: Duration = Duration::from_millis(1000); +const MEASUREMENT_TIME: Duration = Duration::from_secs(3); +const SAMPLE_SIZE: usize = 100_000; +const N_RESAMPLES: usize = 10_000; + +macro_rules! single_random { + ($R:ty, $T:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let (mut low, mut high); + loop { + low = <$T>::from_bits(rng.random()); + high = <$T>::from_bits(rng.random()); + if (low < high) && (high - low).is_normal() { + break; + } + } + + b.iter(|| <$T as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng)); + }); + }; + + ($c:expr, $T:ty) => {{ + let mut g = $c.benchmark_group("uniform_single"); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + single_random!(SmallRng, $T, g); + single_random!(ChaCha8Rng, $T, g); + single_random!(Pcg32, $T, g); + single_random!(Pcg64, $T, g); + g.finish(); + }}; +} + +fn single_random(c: &mut Criterion) { + single_random!(c, f32); + single_random!(c, f64); +} + +macro_rules! distr_random { + ($R:ty, $T:ty, $g:expr) => { + $g.bench_function(BenchmarkId::new(stringify!($T), stringify!($R)), |b| { + let mut rng = <$R>::from_rng(&mut rand::rng()); + let dist = loop { + let low = <$T>::from_bits(rng.random()); + let high = <$T>::from_bits(rng.random()); + if let Ok(dist) = Uniform::<$T>::new_inclusive(low, high) { + break dist; + } + }; + + b.iter(|| dist.sample(&mut rng)); + }); + }; + + ($c:expr, $T:ty) => {{ + let mut g = $c.benchmark_group("uniform_distribution"); + g.sample_size(SAMPLE_SIZE); + g.warm_up_time(WARM_UP_TIME); + g.measurement_time(MEASUREMENT_TIME); + g.nresamples(N_RESAMPLES); + distr_random!(SmallRng, $T, g); + distr_random!(ChaCha8Rng, $T, g); + distr_random!(Pcg32, $T, g); + distr_random!(Pcg64, $T, g); + g.finish(); + }}; +} + +fn distr_random(c: &mut Criterion) { + distr_random!(c, f32); + distr_random!(c, f64); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = single_random, distr_random +} +criterion_main!(benches); diff --git a/benches/benches/weighted.rs b/benches/benches/weighted.rs new file mode 100644 index 00000000000..69576b3608d --- /dev/null +++ b/benches/benches/weighted.rs @@ -0,0 +1,60 @@ +// Copyright 2019 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::distr::weighted::WeightedIndex; +use rand::prelude::*; +use rand::seq::index::sample_weighted; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = bench +); +criterion_main!(benches); + +pub fn bench(c: &mut Criterion) { + c.bench_function("weighted_index_creation", |b| { + let mut rng = rand::rng(); + let weights = black_box([1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]); + b.iter(|| { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + rng.sample(distr) + }) + }); + + c.bench_function("weighted_index_modification", |b| { + let mut rng = rand::rng(); + let weights = black_box([1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + b.iter(|| { + distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); + rng.sample(&distr) + }) + }); + + let lens = [ + (1, 1000, "1k"), + (10, 1000, "1k"), + (100, 1000, "1k"), + (100, 1_000_000, "1M"), + (200, 1_000_000, "1M"), + (400, 1_000_000, "1M"), + (600, 1_000_000, "1M"), + (1000, 1_000_000, "1M"), + ]; + for (amount, length, len_name) in lens { + let name = format!("weighted_sample_indices_{}_of_{}", amount, len_name); + c.bench_function(name.as_str(), |b| { + let length = black_box(length); + let amount = black_box(amount); + let mut rng = SmallRng::from_rng(&mut rand::rng()); + b.iter(|| sample_weighted(&mut rng, length, |idx| (1 + (idx % 100)) as u32, amount)) + }); + } +} diff --git a/benches/distributions.rs b/benches/distributions.rs deleted file mode 100644 index 7d8ac94c37b..00000000000 --- a/benches/distributions.rs +++ /dev/null @@ -1,440 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(custom_inner_attributes)] -#![feature(test)] - -// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable -#![rustfmt::skip] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use rand::distributions::{Alphanumeric, Open01, OpenClosed01, Standard, Uniform}; -use rand::distributions::uniform::{UniformInt, UniformSampler}; -use std::mem::size_of; -use std::num::{NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8}; -use std::time::Duration; -use test::{Bencher, black_box}; - -use rand::prelude::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -macro_rules! distr_int { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_nz_int { - ($fnn:ident, $tynz:ty, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $tynz = distr.sample(&mut rng); - accum = accum.wrapping_add(x.get()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_float { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum += x; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_duration { - ($fnn:ident, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = Duration::new(0, 0); - for _ in 0..RAND_BENCH_N { - let x: Duration = distr.sample(&mut rng); - accum = accum - .checked_add(x) - .unwrap_or(Duration::new(u64::max_value(), 999_999_999)); - } - accum - }); - b.bytes = size_of::() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0u32; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x as u32); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -// uniform -distr_int!(distr_uniform_i8, i8, Uniform::new(20i8, 100)); -distr_int!(distr_uniform_i16, i16, Uniform::new(-500i16, 2000)); -distr_int!(distr_uniform_i32, i32, Uniform::new(-200_000_000i32, 800_000_000)); -distr_int!(distr_uniform_i64, i64, Uniform::new(3i64, 123_456_789_123)); -distr_int!(distr_uniform_i128, i128, Uniform::new(-123_456_789_123i128, 123_456_789_123_456_789)); -distr_int!(distr_uniform_usize16, usize, Uniform::new(0usize, 0xb9d7)); -distr_int!(distr_uniform_usize32, usize, Uniform::new(0usize, 0x548c0f43)); -#[cfg(target_pointer_width = "64")] -distr_int!(distr_uniform_usize64, usize, Uniform::new(0usize, 0x3a42714f2bf927a8)); -distr_int!(distr_uniform_isize, isize, Uniform::new(-1060478432isize, 1858574057)); - -distr_float!(distr_uniform_f32, f32, Uniform::new(2.26f32, 2.319)); -distr_float!(distr_uniform_f64, f64, Uniform::new(2.26f64, 2.319)); - -const LARGE_SEC: u64 = u64::max_value() / 1000; - -distr_duration!(distr_uniform_duration_largest, - Uniform::new_inclusive(Duration::new(0, 0), Duration::new(u64::max_value(), 999_999_999)) -); -distr_duration!(distr_uniform_duration_large, - Uniform::new(Duration::new(0, 0), Duration::new(LARGE_SEC, 1_000_000_000 / 2)) -); -distr_duration!(distr_uniform_duration_one, - Uniform::new(Duration::new(0, 0), Duration::new(1, 0)) -); -distr_duration!(distr_uniform_duration_variety, - Uniform::new(Duration::new(10000, 423423), Duration::new(200000, 6969954)) -); -distr_duration!(distr_uniform_duration_edge, - Uniform::new_inclusive(Duration::new(LARGE_SEC, 999_999_999), Duration::new(LARGE_SEC + 1, 1)) -); - -// standard -distr_int!(distr_standard_i8, i8, Standard); -distr_int!(distr_standard_i16, i16, Standard); -distr_int!(distr_standard_i32, i32, Standard); -distr_int!(distr_standard_i64, i64, Standard); -distr_int!(distr_standard_i128, i128, Standard); -distr_nz_int!(distr_standard_nz8, NonZeroU8, u8, Standard); -distr_nz_int!(distr_standard_nz16, NonZeroU16, u16, Standard); -distr_nz_int!(distr_standard_nz32, NonZeroU32, u32, Standard); -distr_nz_int!(distr_standard_nz64, NonZeroU64, u64, Standard); -distr_nz_int!(distr_standard_nz128, NonZeroU128, u128, Standard); - -distr!(distr_standard_bool, bool, Standard); -distr!(distr_standard_alphanumeric, u8, Alphanumeric); -distr!(distr_standard_codepoint, char, Standard); - -distr_float!(distr_standard_f32, f32, Standard); -distr_float!(distr_standard_f64, f64, Standard); -distr_float!(distr_open01_f32, f32, Open01); -distr_float!(distr_open01_f64, f64, Open01); -distr_float!(distr_openclosed01_f32, f32, OpenClosed01); -distr_float!(distr_openclosed01_f64, f64, OpenClosed01); - -// construct and sample from a range -macro_rules! gen_range_int { - ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - - b.iter(|| { - let mut high = $high; - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen_range($low..high)); - // force recalculation of range each time - high = high.wrapping_add(1) & std::$ty::MAX; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -// Algorithms such as Fisher–Yates shuffle often require uniform values from an -// incrementing range 0..n. We use -1..n here to prevent wrapping in the test -// from generating a 0-sized range. -gen_range_int!(gen_range_i8_low, i8, -1i8, 0); -gen_range_int!(gen_range_i16_low, i16, -1i16, 0); -gen_range_int!(gen_range_i32_low, i32, -1i32, 0); -gen_range_int!(gen_range_i64_low, i64, -1i64, 0); -gen_range_int!(gen_range_i128_low, i128, -1i128, 0); - -// These were the initially tested ranges. They are likely to see fewer -// rejections than the low tests. -gen_range_int!(gen_range_i8_high, i8, -20i8, 100); -gen_range_int!(gen_range_i16_high, i16, -500i16, 2000); -gen_range_int!(gen_range_i32_high, i32, -200_000_000i32, 800_000_000); -gen_range_int!(gen_range_i64_high, i64, 3i64, 123_456_789_123); -gen_range_int!(gen_range_i128_high, i128, -12345678901234i128, 123_456_789_123_456_789); - -// construct and sample from a floating-point range -macro_rules! gen_range_float { - ($fnn:ident, $ty:ident, $low:expr, $high:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - - b.iter(|| { - let mut high = $high; - let mut low = $low; - let mut accum: $ty = 0.0; - for _ in 0..RAND_BENCH_N { - accum += rng.gen_range(low..high); - // force recalculation of range each time - low += 0.9; - high += 1.1; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -gen_range_float!(gen_range_f32, f32, -20000.0f32, 100000.0); -gen_range_float!(gen_range_f64, f64, 123.456f64, 7890.12); - - -// In src/distributions/uniform.rs, we say: -// Implementation of [`uniform_single`] is optional, and is only useful when -// the implementation can be faster than `Self::new(low, high).sample(rng)`. - -// `UniformSampler::uniform_single` compromises on the rejection range to be -// faster. This benchmark demonstrates both the speed gain of doing this, and -// the worst case behavior. - -/// Sample random values from a pre-existing distribution. This uses the -/// half open `new` to be equivalent to the behavior of `uniform_single`. -macro_rules! uniform_sample { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..10 { - let dist = UniformInt::<$type>::new(low, high); - for _ in 0..$count { - black_box(dist.sample(&mut rng)); - } - } - }); - } - }; -} - -macro_rules! uniform_inclusive { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..10 { - let dist = UniformInt::<$type>::new_inclusive(low, high); - for _ in 0..$count { - black_box(dist.sample(&mut rng)); - } - } - }); - } - }; -} - -/// Use `uniform_single` to create a one-off random value -macro_rules! uniform_single { - ($fnn:ident, $type:ident, $low:expr, $high:expr, $count:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let low = black_box($low); - let high = black_box($high); - b.iter(|| { - for _ in 0..(10 * $count) { - black_box(UniformInt::<$type>::sample_single(low, high, &mut rng)); - } - }); - } - }; -} - - -// Benchmark: -// n: can use the full generated range -// (n-1): only the max value is rejected: expect this to be fast -// n/2+1: almost half of the values are rejected, and we can do no better -// n/2: approximation rejects half the values but powers of 2 could have no rejection -// n/2-1: only a few values are rejected: expect this to be fast -// 6: approximation rejects 25% of values but could be faster. However modulo by -// low numbers is typically more expensive - -// With the use of u32 as the minimum generated width, the worst-case u16 range -// (32769) will only reject 32769 / 4294967296 samples. -const HALF_16_BIT_UNSIGNED: u16 = 1 << 15; - -uniform_sample!(uniform_u16x1_allm1_new, u16, 0, u16::max_value(), 1); -uniform_sample!(uniform_u16x1_halfp1_new, u16, 0, HALF_16_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u16x1_half_new, u16, 0, HALF_16_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u16x1_halfm1_new, u16, 0, HALF_16_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u16x1_6_new, u16, 0, 6u16, 1); - -uniform_single!(uniform_u16x1_allm1_single, u16, 0, u16::max_value(), 1); -uniform_single!(uniform_u16x1_halfp1_single, u16, 0, HALF_16_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u16x1_half_single, u16, 0, HALF_16_BIT_UNSIGNED, 1); -uniform_single!(uniform_u16x1_halfm1_single, u16, 0, HALF_16_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u16x1_6_single, u16, 0, 6u16, 1); - -uniform_inclusive!(uniform_u16x10_all_new_inclusive, u16, 0, u16::max_value(), 10); -uniform_sample!(uniform_u16x10_allm1_new, u16, 0, u16::max_value(), 10); -uniform_sample!(uniform_u16x10_halfp1_new, u16, 0, HALF_16_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u16x10_half_new, u16, 0, HALF_16_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u16x10_halfm1_new, u16, 0, HALF_16_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u16x10_6_new, u16, 0, 6u16, 10); - -uniform_single!(uniform_u16x10_allm1_single, u16, 0, u16::max_value(), 10); -uniform_single!(uniform_u16x10_halfp1_single, u16, 0, HALF_16_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u16x10_half_single, u16, 0, HALF_16_BIT_UNSIGNED, 10); -uniform_single!(uniform_u16x10_halfm1_single, u16, 0, HALF_16_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u16x10_6_single, u16, 0, 6u16, 10); - - -const HALF_32_BIT_UNSIGNED: u32 = 1 << 31; - -uniform_sample!(uniform_u32x1_allm1_new, u32, 0, u32::max_value(), 1); -uniform_sample!(uniform_u32x1_halfp1_new, u32, 0, HALF_32_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u32x1_half_new, u32, 0, HALF_32_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u32x1_halfm1_new, u32, 0, HALF_32_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u32x1_6_new, u32, 0, 6u32, 1); - -uniform_single!(uniform_u32x1_allm1_single, u32, 0, u32::max_value(), 1); -uniform_single!(uniform_u32x1_halfp1_single, u32, 0, HALF_32_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u32x1_half_single, u32, 0, HALF_32_BIT_UNSIGNED, 1); -uniform_single!(uniform_u32x1_halfm1_single, u32, 0, HALF_32_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u32x1_6_single, u32, 0, 6u32, 1); - -uniform_inclusive!(uniform_u32x10_all_new_inclusive, u32, 0, u32::max_value(), 10); -uniform_sample!(uniform_u32x10_allm1_new, u32, 0, u32::max_value(), 10); -uniform_sample!(uniform_u32x10_halfp1_new, u32, 0, HALF_32_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u32x10_half_new, u32, 0, HALF_32_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u32x10_halfm1_new, u32, 0, HALF_32_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u32x10_6_new, u32, 0, 6u32, 10); - -uniform_single!(uniform_u32x10_allm1_single, u32, 0, u32::max_value(), 10); -uniform_single!(uniform_u32x10_halfp1_single, u32, 0, HALF_32_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u32x10_half_single, u32, 0, HALF_32_BIT_UNSIGNED, 10); -uniform_single!(uniform_u32x10_halfm1_single, u32, 0, HALF_32_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u32x10_6_single, u32, 0, 6u32, 10); - -const HALF_64_BIT_UNSIGNED: u64 = 1 << 63; - -uniform_sample!(uniform_u64x1_allm1_new, u64, 0, u64::max_value(), 1); -uniform_sample!(uniform_u64x1_halfp1_new, u64, 0, HALF_64_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u64x1_half_new, u64, 0, HALF_64_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u64x1_halfm1_new, u64, 0, HALF_64_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u64x1_6_new, u64, 0, 6u64, 1); - -uniform_single!(uniform_u64x1_allm1_single, u64, 0, u64::max_value(), 1); -uniform_single!(uniform_u64x1_halfp1_single, u64, 0, HALF_64_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u64x1_half_single, u64, 0, HALF_64_BIT_UNSIGNED, 1); -uniform_single!(uniform_u64x1_halfm1_single, u64, 0, HALF_64_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u64x1_6_single, u64, 0, 6u64, 1); - -uniform_inclusive!(uniform_u64x10_all_new_inclusive, u64, 0, u64::max_value(), 10); -uniform_sample!(uniform_u64x10_allm1_new, u64, 0, u64::max_value(), 10); -uniform_sample!(uniform_u64x10_halfp1_new, u64, 0, HALF_64_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u64x10_half_new, u64, 0, HALF_64_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u64x10_halfm1_new, u64, 0, HALF_64_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u64x10_6_new, u64, 0, 6u64, 10); - -uniform_single!(uniform_u64x10_allm1_single, u64, 0, u64::max_value(), 10); -uniform_single!(uniform_u64x10_halfp1_single, u64, 0, HALF_64_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u64x10_half_single, u64, 0, HALF_64_BIT_UNSIGNED, 10); -uniform_single!(uniform_u64x10_halfm1_single, u64, 0, HALF_64_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u64x10_6_single, u64, 0, 6u64, 10); - -const HALF_128_BIT_UNSIGNED: u128 = 1 << 127; - -uniform_sample!(uniform_u128x1_allm1_new, u128, 0, u128::max_value(), 1); -uniform_sample!(uniform_u128x1_halfp1_new, u128, 0, HALF_128_BIT_UNSIGNED + 1, 1); -uniform_sample!(uniform_u128x1_half_new, u128, 0, HALF_128_BIT_UNSIGNED, 1); -uniform_sample!(uniform_u128x1_halfm1_new, u128, 0, HALF_128_BIT_UNSIGNED - 1, 1); -uniform_sample!(uniform_u128x1_6_new, u128, 0, 6u128, 1); - -uniform_single!(uniform_u128x1_allm1_single, u128, 0, u128::max_value(), 1); -uniform_single!(uniform_u128x1_halfp1_single, u128, 0, HALF_128_BIT_UNSIGNED + 1, 1); -uniform_single!(uniform_u128x1_half_single, u128, 0, HALF_128_BIT_UNSIGNED, 1); -uniform_single!(uniform_u128x1_halfm1_single, u128, 0, HALF_128_BIT_UNSIGNED - 1, 1); -uniform_single!(uniform_u128x1_6_single, u128, 0, 6u128, 1); - -uniform_inclusive!(uniform_u128x10_all_new_inclusive, u128, 0, u128::max_value(), 10); -uniform_sample!(uniform_u128x10_allm1_new, u128, 0, u128::max_value(), 10); -uniform_sample!(uniform_u128x10_halfp1_new, u128, 0, HALF_128_BIT_UNSIGNED + 1, 10); -uniform_sample!(uniform_u128x10_half_new, u128, 0, HALF_128_BIT_UNSIGNED, 10); -uniform_sample!(uniform_u128x10_halfm1_new, u128, 0, HALF_128_BIT_UNSIGNED - 1, 10); -uniform_sample!(uniform_u128x10_6_new, u128, 0, 6u128, 10); - -uniform_single!(uniform_u128x10_allm1_single, u128, 0, u128::max_value(), 10); -uniform_single!(uniform_u128x10_halfp1_single, u128, 0, HALF_128_BIT_UNSIGNED + 1, 10); -uniform_single!(uniform_u128x10_half_single, u128, 0, HALF_128_BIT_UNSIGNED, 10); -uniform_single!(uniform_u128x10_halfm1_single, u128, 0, HALF_128_BIT_UNSIGNED - 1, 10); -uniform_single!(uniform_u128x10_6_single, u128, 0, 6u128, 10); diff --git a/benches/generators.rs b/benches/generators.rs deleted file mode 100644 index 3e264083d7d..00000000000 --- a/benches/generators.rs +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] -#![allow(non_snake_case)] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; -const BYTES_LEN: usize = 1024; - -use std::mem::size_of; -use test::{black_box, Bencher}; - -use rand::prelude::*; -use rand::rngs::adapter::ReseedingRng; -use rand::rngs::{mock::StepRng, OsRng}; -use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; -use rand_hc::Hc128Rng; -use rand_pcg::{Pcg32, Pcg64, Pcg64Mcg}; - -macro_rules! gen_bytes { - ($fnn:ident, $gen:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = $gen; - let mut buf = [0u8; BYTES_LEN]; - b.iter(|| { - for _ in 0..RAND_BENCH_N { - rng.fill_bytes(&mut buf); - black_box(buf); - } - }); - b.bytes = BYTES_LEN as u64 * RAND_BENCH_N; - } - }; -} - -gen_bytes!(gen_bytes_step, StepRng::new(0, 1)); -gen_bytes!(gen_bytes_pcg32, Pcg32::from_entropy()); -gen_bytes!(gen_bytes_pcg64, Pcg64::from_entropy()); -gen_bytes!(gen_bytes_pcg64mcg, Pcg64Mcg::from_entropy()); -gen_bytes!(gen_bytes_chacha8, ChaCha8Rng::from_entropy()); -gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_entropy()); -gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_entropy()); -gen_bytes!(gen_bytes_hc128, Hc128Rng::from_entropy()); -gen_bytes!(gen_bytes_std, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_bytes!(gen_bytes_small, SmallRng::from_entropy()); -gen_bytes!(gen_bytes_os, OsRng); - -macro_rules! gen_uint { - ($fnn:ident, $ty:ty, $gen:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = $gen; - b.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen::<$ty>()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -gen_uint!(gen_u32_step, u32, StepRng::new(0, 1)); -gen_uint!(gen_u32_pcg32, u32, Pcg32::from_entropy()); -gen_uint!(gen_u32_pcg64, u32, Pcg64::from_entropy()); -gen_uint!(gen_u32_pcg64mcg, u32, Pcg64Mcg::from_entropy()); -gen_uint!(gen_u32_chacha8, u32, ChaCha8Rng::from_entropy()); -gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_entropy()); -gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_entropy()); -gen_uint!(gen_u32_hc128, u32, Hc128Rng::from_entropy()); -gen_uint!(gen_u32_std, u32, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_uint!(gen_u32_small, u32, SmallRng::from_entropy()); -gen_uint!(gen_u32_os, u32, OsRng); - -gen_uint!(gen_u64_step, u64, StepRng::new(0, 1)); -gen_uint!(gen_u64_pcg32, u64, Pcg32::from_entropy()); -gen_uint!(gen_u64_pcg64, u64, Pcg64::from_entropy()); -gen_uint!(gen_u64_pcg64mcg, u64, Pcg64Mcg::from_entropy()); -gen_uint!(gen_u64_chacha8, u64, ChaCha8Rng::from_entropy()); -gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_entropy()); -gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_entropy()); -gen_uint!(gen_u64_hc128, u64, Hc128Rng::from_entropy()); -gen_uint!(gen_u64_std, u64, StdRng::from_entropy()); -#[cfg(feature = "small_rng")] -gen_uint!(gen_u64_small, u64, SmallRng::from_entropy()); -gen_uint!(gen_u64_os, u64, OsRng); - -macro_rules! init_gen { - ($fnn:ident, $gen:ident) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg32::from_entropy(); - b.iter(|| { - let r2 = $gen::from_rng(&mut rng).unwrap(); - r2 - }); - } - }; -} - -init_gen!(init_pcg32, Pcg32); -init_gen!(init_pcg64, Pcg64); -init_gen!(init_pcg64mcg, Pcg64Mcg); -init_gen!(init_hc128, Hc128Rng); -init_gen!(init_chacha, ChaCha20Rng); - -const RESEEDING_BYTES_LEN: usize = 1024 * 1024; -const RESEEDING_BENCH_N: u64 = 16; - -macro_rules! reseeding_bytes { - ($fnn:ident, $thresh:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = ReseedingRng::new(ChaCha20Core::from_entropy(), $thresh * 1024, OsRng); - let mut buf = [0u8; RESEEDING_BYTES_LEN]; - b.iter(|| { - for _ in 0..RESEEDING_BENCH_N { - rng.fill_bytes(&mut buf); - black_box(&buf); - } - }); - b.bytes = RESEEDING_BYTES_LEN as u64 * RESEEDING_BENCH_N; - } - }; -} - -reseeding_bytes!(reseeding_chacha20_4k, 4); -reseeding_bytes!(reseeding_chacha20_16k, 16); -reseeding_bytes!(reseeding_chacha20_32k, 32); -reseeding_bytes!(reseeding_chacha20_64k, 64); -reseeding_bytes!(reseeding_chacha20_256k, 256); -reseeding_bytes!(reseeding_chacha20_1M, 1024); - - -macro_rules! threadrng_uint { - ($fnn:ident, $ty:ty) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = thread_rng(); - b.iter(|| { - let mut accum: $ty = 0; - for _ in 0..RAND_BENCH_N { - accum = accum.wrapping_add(rng.gen::<$ty>()); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -threadrng_uint!(thread_rng_u32, u32); -threadrng_uint!(thread_rng_u64, u64); diff --git a/benches/misc.rs b/benches/misc.rs deleted file mode 100644 index 11d12eb24ad..00000000000 --- a/benches/misc.rs +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use test::Bencher; - -use rand::distributions::{Bernoulli, Distribution, Standard}; -use rand::prelude::*; -use rand_pcg::{Pcg32, Pcg64Mcg}; - -#[bench] -fn misc_gen_bool_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_bool(0.18); - } - accum - }) -} - -#[bench] -fn misc_gen_bool_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - let mut p = 0.18; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_bool(p); - p += 0.0001; - } - accum - }) -} - -#[bench] -fn misc_gen_ratio_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.gen_ratio(2, 3); - } - accum - }) -} - -#[bench] -fn misc_gen_ratio_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - for i in 2..(crate::RAND_BENCH_N as u32 + 2) { - accum ^= rng.gen_ratio(i, i + 1); - } - accum - }) -} - -#[bench] -fn misc_bernoulli_const(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let d = rand::distributions::Bernoulli::new(0.18).unwrap(); - let mut accum = true; - for _ in 0..crate::RAND_BENCH_N { - accum ^= rng.sample(d); - } - accum - }) -} - -#[bench] -fn misc_bernoulli_var(b: &mut Bencher) { - let mut rng = Pcg32::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let mut accum = true; - let mut p = 0.18; - for _ in 0..crate::RAND_BENCH_N { - let d = Bernoulli::new(p).unwrap(); - accum ^= rng.sample(d); - p += 0.0001; - } - accum - }) -} - -#[bench] -fn gen_1kb_u16_iter_repeat(b: &mut Bencher) { - use std::iter; - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = iter::repeat(()).map(|()| rng.gen()).take(512).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_sample_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = Standard.sample_iter(&mut rng).take(512).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_gen_array(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - // max supported array length is 32! - let v: [[u16; 32]; 16] = rng.gen(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u16_fill(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let mut buf = [0u16; 512]; - b.iter(|| { - rng.fill(&mut buf[..]); - buf - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_iter_repeat(b: &mut Bencher) { - use std::iter; - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = iter::repeat(()).map(|()| rng.gen()).take(128).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_sample_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - let v: Vec = Standard.sample_iter(&mut rng).take(128).collect(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_gen_array(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - b.iter(|| { - // max supported array length is 32! - let v: [[u64; 32]; 4] = rng.gen(); - v - }); - b.bytes = 1024; -} - -#[bench] -fn gen_1kb_u64_fill(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let mut buf = [0u64; 128]; - b.iter(|| { - rng.fill(&mut buf[..]); - buf - }); - b.bytes = 1024; -} diff --git a/benches/rustfmt.toml b/benches/rustfmt.toml new file mode 100644 index 00000000000..b64fd7ad0e6 --- /dev/null +++ b/benches/rustfmt.toml @@ -0,0 +1,2 @@ +max_width = 120 +fn_call_width = 108 diff --git a/benches/seq.rs b/benches/seq.rs deleted file mode 100644 index 5b6fccf51ee..00000000000 --- a/benches/seq.rs +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] -#![allow(non_snake_case)] - -extern crate test; - -use test::Bencher; - -use rand::prelude::*; -use rand::seq::*; -use std::mem::size_of; - -// We force use of 32-bit RNG since seq code is optimised for use with 32-bit -// generators on all platforms. -use rand_pcg::Pcg32 as SmallRng; - -const RAND_BENCH_N: u64 = 1000; - -#[bench] -fn seq_shuffle_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 100]; - b.iter(|| { - x.shuffle(&mut rng); - x[0] - }) -} - -#[bench] -fn seq_slice_choose_1_of_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 1000]; - for i in 0..1000 { - x[i] = i; - } - b.iter(|| { - let mut s = 0; - for _ in 0..RAND_BENCH_N { - s += x.choose(&mut rng).unwrap(); - } - s - }); - b.bytes = size_of::() as u64 * crate::RAND_BENCH_N; -} - -macro_rules! seq_slice_choose_multiple { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[i32] = &[$amount; $length]; - let mut result = [0i32; $amount]; - b.iter(|| { - // Collect full result to prevent unwanted shortcuts getting - // first element (in case sample_indices returns an iterator). - for (slot, sample) in result.iter_mut().zip(x.choose_multiple(&mut rng, $amount)) { - *slot = *sample; - } - result[$amount - 1] - }) - } - }; -} - -seq_slice_choose_multiple!(seq_slice_choose_multiple_1_of_1000, 1, 1000); -seq_slice_choose_multiple!(seq_slice_choose_multiple_950_of_1000, 950, 1000); -seq_slice_choose_multiple!(seq_slice_choose_multiple_10_of_100, 10, 100); -seq_slice_choose_multiple!(seq_slice_choose_multiple_90_of_100, 90, 100); - -#[bench] -fn seq_iter_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &mut [usize] = &mut [1; 1000]; - for i in 0..1000 { - x[i] = i; - } - b.iter(|| { - let mut s = 0; - for _ in 0..RAND_BENCH_N { - s += x.iter().choose(&mut rng).unwrap(); - } - s - }); - b.bytes = size_of::() as u64 * crate::RAND_BENCH_N; -} - -#[derive(Clone)] -struct UnhintedIterator { - iter: I, -} -impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } -} - -#[derive(Clone)] -struct WindowHintedIterator { - iter: I, - window_size: usize, -} -impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - (std::cmp::min(self.iter.len(), self.window_size), None) - } -} - -#[bench] -fn seq_iter_unhinted_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 1000]; - b.iter(|| { - UnhintedIterator { iter: x.iter() } - .choose(&mut rng) - .unwrap() - }) -} - -#[bench] -fn seq_iter_window_hinted_choose_from_1000(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 1000]; - b.iter(|| { - WindowHintedIterator { - iter: x.iter(), - window_size: 7, - } - .choose(&mut rng) - }) -} - -#[bench] -fn seq_iter_choose_multiple_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 100]; - b.iter(|| x.iter().cloned().choose_multiple(&mut rng, 10)) -} - -#[bench] -fn seq_iter_choose_multiple_fill_10_of_100(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - let x: &[usize] = &[1; 100]; - let mut buf = [0; 10]; - b.iter(|| x.iter().cloned().choose_multiple_fill(&mut rng, &mut buf)) -} - -macro_rules! sample_indices { - ($name:ident, $fn:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| index::$fn(&mut rng, $length, $amount)) - } - }; -} - -sample_indices!(misc_sample_indices_1_of_1k, sample, 1, 1000); -sample_indices!(misc_sample_indices_10_of_1k, sample, 10, 1000); -sample_indices!(misc_sample_indices_100_of_1k, sample, 100, 1000); -sample_indices!(misc_sample_indices_100_of_1M, sample, 100, 1000_000); -sample_indices!(misc_sample_indices_100_of_1G, sample, 100, 1000_000_000); -sample_indices!(misc_sample_indices_200_of_1G, sample, 200, 1000_000_000); -sample_indices!(misc_sample_indices_400_of_1G, sample, 400, 1000_000_000); -sample_indices!(misc_sample_indices_600_of_1G, sample, 600, 1000_000_000); - -macro_rules! sample_indices_rand_weights { - ($name:ident, $amount:expr, $length:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = SmallRng::from_rng(thread_rng()).unwrap(); - b.iter(|| { - index::sample_weighted(&mut rng, $length, |idx| (1 + (idx % 100)) as u32, $amount) - }) - } - }; -} - -sample_indices_rand_weights!(misc_sample_weighted_indices_1_of_1k, 1, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_10_of_1k, 10, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1k, 100, 1000); -sample_indices_rand_weights!(misc_sample_weighted_indices_100_of_1M, 100, 1000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_200_of_1M, 200, 1000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_400_of_1M, 400, 1000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_600_of_1M, 600, 1000_000); -sample_indices_rand_weights!(misc_sample_weighted_indices_1k_of_1M, 1000, 1000_000); diff --git a/benches/weighted.rs b/benches/weighted.rs deleted file mode 100644 index 68722908a9e..00000000000 --- a/benches/weighted.rs +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(test)] - -extern crate test; - -use rand::distributions::WeightedIndex; -use rand::Rng; -use test::Bencher; - -#[bench] -fn weighted_index_creation(b: &mut Bencher) { - let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 4, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 7]; - b.iter(|| { - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - rng.sample(distr) - }) -} - -#[bench] -fn weighted_index_modification(b: &mut Bencher) { - let mut rng = rand::thread_rng(); - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - b.iter(|| { - distr.update_weights(&[(2, &4), (5, &1)]).unwrap(); - rng.sample(&distr) - }) -} diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000000..14793c52048 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,2 @@ +# Don't warn about these identifiers when using clippy::doc_markdown. +doc-valid-idents = ["ChaCha", "ChaCha12", "SplitMix64", "ZiB", ".."] diff --git a/distr_test/Cargo.toml b/distr_test/Cargo.toml new file mode 100644 index 00000000000..d9d7fe2c274 --- /dev/null +++ b/distr_test/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "distr_test" +version = "0.1.0" +edition = "2021" +publish = false + +[dev-dependencies] +rand_distr = { path = "../rand_distr", version = "0.5.0", default-features = false, features = ["alloc"] } +rand = { path = "..", version = "0.9.0", features = ["small_rng"] } +num-traits = "0.2.19" +# Special functions for testing distributions +special = "0.11.0" +spfunc = "0.1.0" +# Cdf implementation +statrs = "0.17.1" diff --git a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs new file mode 100644 index 00000000000..f417c630ae2 --- /dev/null +++ b/distr_test/tests/cdf.rs @@ -0,0 +1,454 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use core::f64; + +use special::{Beta, Gamma, Primitive}; +use statrs::distribution::ContinuousCDF; +use statrs::distribution::DiscreteCDF; + +mod ks; +use ks::test_continuous; +use ks::test_discrete; + +#[test] +fn normal() { + let parameters = [ + (0.0, 1.0), + (0.0, 0.1), + (1.0, 10.0), + (1.0, 100.0), + (-1.0, 0.00001), + (-1.0, 0.0000001), + ]; + + for (seed, (mean, std_dev)) in parameters.into_iter().enumerate() { + test_continuous( + seed as u64, + rand_distr::Normal::new(mean, std_dev).unwrap(), + |x| { + statrs::distribution::Normal::new(mean, std_dev) + .unwrap() + .cdf(x) + }, + ); + } +} + +#[test] +fn cauchy() { + let parameters = [ + (0.0, 1.0), + (0.0, 0.1), + (1.0, 10.0), + (1.0, 100.0), + (-1.0, 0.00001), + (-1.0, 0.0000001), + ]; + + for (seed, (median, scale)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Cauchy::new(median, scale).unwrap(); + test_continuous(seed as u64, dist, |x| { + statrs::distribution::Cauchy::new(median, scale) + .unwrap() + .cdf(x) + }); + } +} + +#[test] +fn uniform() { + fn cdf(x: f64, a: f64, b: f64) -> f64 { + if x < a { + 0.0 + } else if x < b { + (x - a) / (b - a) + } else { + 1.0 + } + } + + let parameters = [(0.0, 1.0), (-1.0, 1.0), (0.0, 100.0), (-100.0, 100.0)]; + + for (seed, (a, b)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Uniform::new(a, b).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, a, b)); + } +} + +#[test] +fn log_normal() { + let parameters = [ + (0.0, 1.0), + (0.0, 0.1), + (0.5, 0.7), + (1.0, 10.0), + (1.0, 100.0), + ]; + + for (seed, (mean, std_dev)) in parameters.into_iter().enumerate() { + let dist = rand_distr::LogNormal::new(mean, std_dev).unwrap(); + test_continuous(seed as u64, dist, |x| { + statrs::distribution::LogNormal::new(mean, std_dev) + .unwrap() + .cdf(x) + }); + } +} + +#[test] +fn pareto() { + let parameters = [ + (1.0, 1.0), + (1.0, 0.1), + (1.0, 10.0), + (1.0, 100.0), + (0.1, 1.0), + (10.0, 1.0), + (100.0, 1.0), + ]; + + for (seed, (scale, alpha)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Pareto::new(scale, alpha).unwrap(); + test_continuous(seed as u64, dist, |x| { + statrs::distribution::Pareto::new(scale, alpha) + .unwrap() + .cdf(x) + }); + } +} + +#[test] +fn exp() { + fn cdf(x: f64, lambda: f64) -> f64 { + 1.0 - (-lambda * x).exp() + } + + let parameters = [0.5, 1.0, 7.5, 32.0, 100.0]; + + for (seed, lambda) in parameters.into_iter().enumerate() { + let dist = rand_distr::Exp::new(lambda).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, lambda)); + } +} + +#[test] +fn weibull() { + fn cdf(x: f64, lambda: f64, k: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + + 1.0 - (-(x / lambda).powf(k)).exp() + } + + let parameters = [ + (0.5, 1.0), + (1.0, 1.0), + (10.0, 0.1), + (0.1, 10.0), + (15.0, 20.0), + (1000.0, 0.01), + ]; + + for (seed, (lambda, k)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Weibull::new(lambda, k).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, lambda, k)); + } +} + +#[test] +fn gumbel() { + fn cdf(x: f64, mu: f64, beta: f64) -> f64 { + (-(-(x - mu) / beta).exp()).exp() + } + + let parameters = [ + (0.0, 1.0), + (1.0, 2.0), + (-1.0, 0.5), + (10.0, 0.1), + (100.0, 0.0001), + ]; + + for (seed, (mu, beta)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Gumbel::new(mu, beta).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, mu, beta)); + } +} + +#[test] +fn frechet() { + fn cdf(x: f64, alpha: f64, s: f64, m: f64) -> f64 { + if x < m { + return 0.0; + } + + (-((x - m) / s).powf(-alpha)).exp() + } + + let parameters = [ + (0.5, 2.0, 1.0), + (1.0, 1.0, 1.0), + (10.0, 0.1, 1.0), + (100.0, 0.0001, 1.0), + (0.9999, 2.0, 1.0), + ]; + + for (seed, (alpha, s, m)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Frechet::new(m, s, alpha).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, alpha, s, m)); + } +} + +#[test] +fn gamma() { + fn cdf(x: f64, shape: f64, scale: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + + (x / scale).inc_gamma(shape) + } + + let parameters = [ + (0.5, 2.0), + (1.0, 1.0), + (10.0, 0.1), + (100.0, 0.0001), + (0.9999, 2.0), + ]; + + for (seed, (shape, scale)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Gamma::new(shape, scale).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, shape, scale)); + } +} + +#[test] +fn chi_squared() { + fn cdf(x: f64, k: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + + (x / 2.0).inc_gamma(k / 2.0) + } + + let parameters = [0.1, 1.0, 2.0, 10.0, 100.0, 1000.0]; + + for (seed, k) in parameters.into_iter().enumerate() { + let dist = rand_distr::ChiSquared::new(k).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, k)); + } +} +#[test] +fn studend_t() { + fn cdf(x: f64, df: f64) -> f64 { + let h = df / (df + x.powi(2)); + let ib = 0.5 * h.inc_beta(df / 2.0, 0.5, 0.5.ln_beta(df / 2.0)); + if x < 0.0 { + ib + } else { + 1.0 - ib + } + } + + let parameters = [1.0, 10.0, 50.0]; + + for (seed, df) in parameters.into_iter().enumerate() { + let dist = rand_distr::StudentT::new(df).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, df)); + } +} + +#[test] +fn fisher_f() { + fn cdf(x: f64, m: f64, n: f64) -> f64 { + if (m == 1.0 && x <= 0.0) || x < 0.0 { + 0.0 + } else { + let k = m * x / (m * x + n); + let d1 = m / 2.0; + let d2 = n / 2.0; + k.inc_beta(d1, d2, d1.ln_beta(d2)) + } + } + + let parameters = [(1.0, 1.0), (1.0, 2.0), (2.0, 1.0), (50.0, 1.0)]; + + for (seed, (m, n)) in parameters.into_iter().enumerate() { + let dist = rand_distr::FisherF::new(m, n).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, m, n)); + } +} + +#[test] +fn beta() { + fn cdf(x: f64, alpha: f64, beta: f64) -> f64 { + if x < 0.0 { + return 0.0; + } + if x > 1.0 { + return 1.0; + } + let ln_beta_ab = alpha.ln_beta(beta); + x.inc_beta(alpha, beta, ln_beta_ab) + } + + let parameters = [(0.5, 0.5), (2.0, 3.5), (10.0, 1.0), (100.0, 50.0)]; + + for (seed, (alpha, beta)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Beta::new(alpha, beta).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, alpha, beta)); + } +} + +#[test] +fn triangular() { + fn cdf(x: f64, a: f64, b: f64, c: f64) -> f64 { + if x <= a { + 0.0 + } else if a < x && x <= c { + (x - a).powi(2) / ((b - a) * (c - a)) + } else if c < x && x < b { + 1.0 - (b - x).powi(2) / ((b - a) * (b - c)) + } else { + 1.0 + } + } + + let parameters = [ + (0.0, 1.0, 0.0001), + (0.0, 1.0, 0.9999), + (0.0, 1.0, 0.5), + (0.0, 100.0, 50.0), + (-100.0, 100.0, 0.0), + ]; + + for (seed, (a, b, c)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Triangular::new(a, b, c).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, a, b, c)); + } +} + +fn binomial_cdf(k: i64, p: f64, n: u64) -> f64 { + if k < 0 { + return 0.0; + } + let k = k as u64; + if k >= n { + return 1.0; + } + + let a = (n - k) as f64; + let b = k as f64 + 1.0; + + let q = 1.0 - p; + + let ln_beta_ab = a.ln_beta(b); + + q.inc_beta(a, b, ln_beta_ab) +} + +#[test] +fn binomial() { + let parameters = [ + (0.5, 10), + (0.5, 100), + (0.1, 10), + (0.0000001, 1000000), + (0.0000001, 10), + (0.9999, 2), + ]; + + for (seed, (p, n)) in parameters.into_iter().enumerate() { + test_discrete(seed as u64, rand_distr::Binomial::new(n, p).unwrap(), |k| { + binomial_cdf(k, p, n) + }); + } +} + +#[test] +fn geometric() { + fn cdf(k: i64, p: f64) -> f64 { + if k < 0 { + 0.0 + } else { + 1.0 - (1.0 - p).powi(1 + k as i32) + } + } + + let parameters = [0.3, 0.5, 0.7, 0.0000001, 0.9999]; + + for (seed, p) in parameters.into_iter().enumerate() { + let dist = rand_distr::Geometric::new(p).unwrap(); + test_discrete(seed as u64, dist, |k| cdf(k, p)); + } +} + +#[test] +fn hypergeometric() { + fn cdf(x: i64, n: u64, k: u64, n_: u64) -> f64 { + let min = if n_ + k > n { n_ + k - n } else { 0 }; + let max = k.min(n_); + if x < min as i64 { + return 0.0; + } else if x >= max as i64 { + return 1.0; + } + + (min..x as u64 + 1).fold(0.0, |acc, k_| { + acc + (ln_binomial(k, k_) + ln_binomial(n - k, n_ - k_) - ln_binomial(n, n_)).exp() + }) + } + + let parameters = [ + (15, 13, 10), + (25, 15, 5), + (60, 10, 7), + (70, 20, 50), + (100, 50, 10), + (100, 50, 49), + ]; + + for (seed, (n, k, n_)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Hypergeometric::new(n, k, n_).unwrap(); + test_discrete(seed as u64, dist, |x| cdf(x, n, k, n_)); + } +} + +#[test] +fn poisson() { + use rand_distr::Poisson; + let parameters = [ + 0.1, 1.0, 7.5, + 45.0, // 1e9, passed case but too slow + // 1.844E+19, // fail case + ]; + + for (seed, lambda) in parameters.into_iter().enumerate() { + let dist = Poisson::new(lambda).unwrap(); + let analytic = statrs::distribution::Poisson::new(lambda).unwrap(); + test_discrete::, _>(seed as u64, dist, |k| { + if k < 0 { + 0.0 + } else { + analytic.cdf(k as u64) + } + }); + } +} + +fn ln_factorial(n: u64) -> f64 { + (n as f64 + 1.0).lgamma().0 +} + +fn ln_binomial(n: u64, k: u64) -> f64 { + ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k) +} diff --git a/distr_test/tests/ks/mod.rs b/distr_test/tests/ks/mod.rs new file mode 100644 index 00000000000..ab94db6e1f4 --- /dev/null +++ b/distr_test/tests/ks/mod.rs @@ -0,0 +1,137 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +// [1] Nonparametric Goodness-of-Fit Tests for Discrete Null Distributions +// by Taylor B. Arnold and John W. Emerson +// http://www.stat.yale.edu/~jay/EmersonMaterials/DiscreteGOF.pdf + +#![allow(dead_code)] + +use num_traits::AsPrimitive; +use rand::SeedableRng; +use rand_distr::Distribution; + +/// Empirical Cumulative Distribution Function (ECDF) +struct Ecdf { + sorted_samples: Vec, +} + +impl Ecdf { + fn new(mut samples: Vec) -> Self { + samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); + Self { + sorted_samples: samples, + } + } + + /// Returns the step points of the ECDF + /// The ECDF is a step function that increases by 1/n at each sample point + /// The function is continuous from the right, so we give the bigger value at the step points + /// First point is (-inf, 0.0), last point is (max(samples), 1.0) + fn step_points(&self) -> Vec<(f64, f64)> { + let mut points = Vec::with_capacity(self.sorted_samples.len() + 1); + let mut last = f64::NEG_INFINITY; + let mut count = 0; + let n = self.sorted_samples.len() as f64; + for &x in &self.sorted_samples { + if x != last { + points.push((last, count as f64 / n)); + last = x; + } + count += 1; + } + points.push((last, count as f64 / n)); + points + } +} + +fn kolmogorov_smirnov_statistic_continuous(ecdf: Ecdf, cdf: impl Fn(f64) -> f64) -> f64 { + // We implement equation (3) from [1] + + let mut max_diff: f64 = 0.; + + let step_points = ecdf.step_points(); // x_i in the paper + for i in 1..step_points.len() { + let (x_i, f_i) = step_points[i]; + let (_, f_i_1) = step_points[i - 1]; + let cdf_i = cdf(x_i); + let max_1 = (cdf_i - f_i).abs(); + let max_2 = (cdf_i - f_i_1).abs(); + + max_diff = max_diff.max(max_1).max(max_2); + } + max_diff +} + +fn kolmogorov_smirnov_statistic_discrete(ecdf: Ecdf, cdf: impl Fn(i64) -> f64) -> f64 { + // We implement equation (4) from [1] + + let mut max_diff: f64 = 0.; + + let step_points = ecdf.step_points(); // x_i in the paper + for i in 1..step_points.len() { + let (x_i, f_i) = step_points[i]; + let (_, f_i_1) = step_points[i - 1]; + let max_1 = (cdf(x_i as i64) - f_i).abs(); + let max_2 = (cdf(x_i as i64 - 1) - f_i_1).abs(); // -1 is the same as -epsilon, because we have integer support + + max_diff = max_diff.max(max_1).max(max_2); + } + max_diff +} + +const SAMPLE_SIZE: u64 = 1_000_000; + +fn critical_value() -> f64 { + // If the sampler is correct, we expect less than 0.001 false positives (alpha = 0.001). + // Passing this does not prove that the sampler is correct but is a good indication. + 1.95 / (SAMPLE_SIZE as f64).sqrt() +} + +fn sample_ecdf(seed: u64, dist: impl Distribution) -> Ecdf +where + T: AsPrimitive, +{ + let mut rng = rand::rngs::SmallRng::seed_from_u64(seed); + let samples = (0..SAMPLE_SIZE) + .map(|_| dist.sample(&mut rng).as_()) + .collect(); + Ecdf::new(samples) +} + +/// Tests a distribution against an analytical CDF. +/// The CDF has to be continuous. +pub fn test_continuous(seed: u64, dist: impl Distribution, cdf: impl Fn(f64) -> f64) { + let ecdf = sample_ecdf(seed, dist); + let ks_statistic = kolmogorov_smirnov_statistic_continuous(ecdf, cdf); + + let critical_value = critical_value(); + + println!("KS statistic: {}", ks_statistic); + println!("Critical value: {}", critical_value); + assert!(ks_statistic < critical_value); +} + +/// Tests a distribution over integers against an analytical CDF. +/// The analytical CDF must not have jump points which are not integers. +pub fn test_discrete(seed: u64, dist: D, cdf: F) +where + I: AsPrimitive, + D: Distribution, + F: Fn(i64) -> f64, +{ + let ecdf = sample_ecdf(seed, dist); + let ks_statistic = kolmogorov_smirnov_statistic_discrete(ecdf, cdf); + + // This critical value is bigger than it could be for discrete distributions, but because of large sample sizes this should not matter too much + let critical_value = critical_value(); + + println!("KS statistic: {}", ks_statistic); + println!("Critical value: {}", critical_value); + assert!(ks_statistic < critical_value); +} diff --git a/distr_test/tests/skew_normal.rs b/distr_test/tests/skew_normal.rs new file mode 100644 index 00000000000..0e6b7b3a028 --- /dev/null +++ b/distr_test/tests/skew_normal.rs @@ -0,0 +1,266 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod ks; +use ks::test_continuous; +use special::Primitive; + +#[test] +fn skew_normal() { + fn cdf(x: f64, location: f64, scale: f64, shape: f64) -> f64 { + let norm = (x - location) / scale; + phi(norm) - 2.0 * owen_t(norm, shape) + } + + let parameters = [(0.0, 1.0, 5.0), (1.0, 10.0, -5.0), (-1.0, 0.00001, 0.0)]; + + for (seed, (location, scale, shape)) in parameters.into_iter().enumerate() { + let dist = rand_distr::SkewNormal::new(location, scale, shape).unwrap(); + test_continuous(seed as u64, dist, |x| cdf(x, location, scale, shape)); + } +} + +/// [1] Patefield, M. (2000). Fast and Accurate Calculation of Owen’s T Function. +/// Journal of Statistical Software, 5(5), 1–25. +/// https://doi.org/10.18637/jss.v005.i05 +/// +/// This function is ported to Rust from the Fortran code provided in the paper +fn owen_t(h: f64, a: f64) -> f64 { + let absh = h.abs(); + let absa = a.abs(); + let ah = absa * absh; + + let mut t; + if absa <= 1.0 { + t = tf(absh, absa, ah); + } else if absh <= 0.67 { + t = 0.25 - znorm1(absh) * znorm1(ah) - tf(ah, 1.0 / absa, absh); + } else { + let normh = znorm2(absh); + let normah = znorm2(ah); + t = 0.5 * (normh + normah) - normh * normah - tf(ah, 1.0 / absa, absh); + } + + if a < 0.0 { + t = -t; + } + + fn tf(h: f64, a: f64, ah: f64) -> f64 { + let rtwopi = 0.159_154_943_091_895_35; + let rrtpi = 0.398_942_280_401_432_7; + + let c2 = [ + 0.999_999_999_999_999_9, + -0.999_999_999_999_888, + 0.999_999_999_982_907_5, + -0.999_999_998_962_825, + 0.999_999_966_604_593_7, + -0.999_999_339_862_724_7, + 0.999_991_256_111_369_6, + -0.999_917_776_244_633_8, + 0.999_428_355_558_701_4, + -0.996_973_117_207_23, + 0.987_514_480_372_753, + -0.959_158_579_805_728_8, + 0.892_463_055_110_067_1, + -0.768_934_259_904_64, + 0.588_935_284_684_846_9, + -0.383_803_451_604_402_55, + 0.203_176_017_010_453, + -8.281_363_160_700_499e-2, + 2.416_798_473_575_957_8e-2, + -4.467_656_666_397_183e-3, + 3.914_116_940_237_383_6e-4, + ]; + + let pts = [ + 3.508_203_967_645_171_6e-3, + 3.127_904_233_803_075_6e-2, + 8.526_682_628_321_945e-2, + 0.162_450_717_308_122_77, + 0.258_511_960_491_254_36, + 0.368_075_538_406_975_3, + 0.485_010_929_056_047, + 0.602_775_141_526_185_7, + 0.714_778_842_177_532_3, + 0.814_755_109_887_601, + 0.897_110_297_559_489_7, + 0.957_238_080_859_442_6, + 0.991_788_329_746_297, + ]; + + let wts = [ + 1.883_143_811_532_350_3e-2, + 1.856_708_624_397_765e-2, + 1.804_209_346_122_338_5e-2, + 1.726_382_960_639_875_2e-2, + 1.624_321_997_598_985_8e-2, + 1.499_459_203_411_670_5e-2, + 1.353_547_446_966_209e-2, + 1.188_635_160_582_016_5e-2, + 1.007_037_724_277_743_2e-2, + 8.113_054_574_229_958e-3, + 6.041_900_952_847_024e-3, + 3.886_221_701_074_205_7e-3, + 1.679_303_108_454_609e-3, + ]; + + let hrange = [ + 0.02, 0.06, 0.09, 0.125, 0.26, 0.4, 0.6, 1.6, 1.7, 2.33, 2.4, 3.36, 3.4, 4.8, + ]; + let arange = [0.025, 0.09, 0.15, 0.36, 0.5, 0.9, 0.99999]; + + let select = [ + [1, 1, 2, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 9], + [1, 2, 2, 3, 3, 5, 5, 14, 14, 15, 15, 16, 16, 16, 9], + [2, 2, 3, 3, 3, 5, 5, 15, 15, 15, 15, 16, 16, 16, 10], + [2, 2, 3, 5, 5, 5, 5, 7, 7, 16, 16, 16, 16, 16, 10], + [2, 3, 3, 5, 5, 6, 6, 8, 8, 17, 17, 17, 12, 12, 11], + [2, 3, 5, 5, 5, 6, 6, 8, 8, 17, 17, 17, 12, 12, 12], + [2, 3, 4, 4, 6, 6, 8, 8, 17, 17, 17, 17, 17, 12, 12], + [2, 3, 4, 4, 6, 6, 18, 18, 18, 18, 17, 17, 17, 12, 12], + ]; + + let ihint = hrange.iter().position(|&r| h < r).unwrap_or(14); + + let iaint = arange.iter().position(|&r| a < r).unwrap_or(7); + + let icode = select[iaint][ihint]; + let m = [ + 2, 3, 4, 5, 7, 10, 12, 18, 10, 20, 30, 20, 4, 7, 8, 20, 13, 0, + ][icode - 1]; + let method = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 6][icode - 1]; + + match method { + 1 => { + let hs = -0.5 * h * h; + let dhs = hs.exp(); + let as_ = a * a; + let mut j = 1; + let mut jj = 1; + let mut aj = rtwopi * a; + let mut tf = rtwopi * a.atan(); + let mut dj = dhs - 1.0; + let mut gj = hs * dhs; + loop { + tf += dj * aj / (jj as f64); + if j >= m { + return tf; + } + j += 1; + jj += 2; + aj *= as_; + dj = gj - dj; + gj *= hs / (j as f64); + } + } + 2 => { + let maxii = m + m + 1; + let mut ii = 1; + let mut tf = 0.0; + let hs = h * h; + let as_ = -a * a; + let mut vi = rrtpi * a * (-0.5 * ah * ah).exp(); + let mut z = znorm1(ah) / h; + let y = 1.0 / hs; + loop { + tf += z; + if ii >= maxii { + tf *= rrtpi * (-0.5 * hs).exp(); + return tf; + } + z = y * (vi - (ii as f64) * z); + vi *= as_; + ii += 2; + } + } + 3 => { + let mut i = 1; + let mut ii = 1; + let mut tf = 0.0; + let hs = h * h; + let as_ = a * a; + let mut vi = rrtpi * a * (-0.5 * ah * ah).exp(); + let mut zi = znorm1(ah) / h; + let y = 1.0 / hs; + loop { + tf += zi * c2[i - 1]; + if i > m { + tf *= rrtpi * (-0.5 * hs).exp(); + return tf; + } + zi = y * ((ii as f64) * zi - vi); + vi *= as_; + i += 1; + ii += 2; + } + } + 4 => { + let maxii = m + m + 1; + let mut ii = 1; + let mut tf = 0.0; + let hs = h * h; + let as_ = -a * a; + let mut ai = rtwopi * a * (-0.5 * hs * (1.0 - as_)).exp(); + let mut yi = 1.0; + loop { + tf += ai * yi; + if ii >= maxii { + return tf; + } + ii += 2; + yi = (1.0 - hs * yi) / (ii as f64); + ai *= as_; + } + } + 5 => { + let mut tf = 0.0; + let as_ = a * a; + let hs = -0.5 * h * h; + for i in 0..m { + let r = 1.0 + as_ * pts[i]; + tf += wts[i] * (hs * r).exp() / r; + } + tf *= a; + tf + } + 6 => { + let normh = znorm2(h); + let mut tf = 0.5 * normh * (1.0 - normh); + let y = 1.0 - a; + let r = (y / (1.0 + a)).atan(); + if r != 0.0 { + tf -= rtwopi * r * (-0.5 * y * h * h / r).exp(); + } + tf + } + _ => 0.0, + } + } + + // P(0 ≤ Z ≤ x) + fn znorm1(x: f64) -> f64 { + phi(x) - 0.5 + } + + // P(x ≤ Z < ∞) + fn znorm2(x: f64) -> f64 { + 1.0 - phi(x) + } + + t +} + +fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> f64 { + 0.5 * ((mean - x) / (std_dev * core::f64::consts::SQRT_2)).erfc() +} + +/// standard normal cdf +fn phi(x: f64) -> f64 { + normal_cdf(x, 0.0, 1.0) +} diff --git a/distr_test/tests/weighted.rs b/distr_test/tests/weighted.rs new file mode 100644 index 00000000000..73df7beb9bc --- /dev/null +++ b/distr_test/tests/weighted.rs @@ -0,0 +1,235 @@ +// Copyright 2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod ks; +use ks::test_discrete; +use rand::distr::Distribution; +use rand::seq::{IndexedRandom, IteratorRandom}; +use rand_distr::weighted::*; + +/// Takes the unnormalized pdf and creates the cdf of a discrete distribution +fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 { + let mut cdf = Vec::with_capacity(num); + let mut ac = 0.0; + for i in 0..num { + ac += f(i as i64); + cdf.push(ac); + } + + let frac = 1.0 / ac; + for x in &mut cdf { + *x *= frac; + } + + move |i| { + if i < 0 { + 0.0 + } else { + cdf[i as usize] + } + } +} + +#[test] +fn weighted_index() { + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = WeightedIndex::new((0..num).map(|i| weight(i as i64))).unwrap(); + test_discrete(0, distr, make_cdf(num, weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn weighted_alias_index() { + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let weights = (0..num).map(|i| weight(i as i64)).collect(); + let distr = WeightedAliasIndex::new(weights).unwrap(); + test_discrete(0, distr, make_cdf(num, weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn weighted_tree_index() { + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = WeightedTreeIndex::new((0..num).map(|i| weight(i as i64))).unwrap(); + test_discrete(0, distr, make_cdf(num, weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn choose_weighted_indexed() { + struct Adapter f64>(Vec, F); + impl f64> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + *IndexedRandom::choose_weighted(&self.0[..], rng, |i| (self.1)(*i)).unwrap() + } + } + + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); + test_discrete(0, distr, make_cdf(num, &weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn choose_one_weighted_indexed() { + struct Adapter f64>(Vec, F); + impl f64> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + *IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 1, |i| (self.1)(*i)) + .unwrap() + .next() + .unwrap() + } + } + + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); + test_discrete(0, distr, make_cdf(num, &weight)); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); +} + +#[test] +fn choose_two_weighted_indexed() { + struct Adapter f64>(Vec, F); + impl f64> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + let mut iter = + IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 2, |i| (self.1)(*i)) + .unwrap(); + let mut a = *iter.next().unwrap(); + let mut b = *iter.next().unwrap(); + assert!(iter.next().is_none()); + if b < a { + std::mem::swap(&mut a, &mut b); + } + a * self.0.len() as i64 + b + } + } + + fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { + let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); + + let pmf1 = (0..num).map(|i| weight(i as i64)).collect::>(); + let sum: f64 = pmf1.iter().sum(); + let frac = 1.0 / sum; + + let mut ac = 0.0; + let mut cdf = Vec::with_capacity(num * num); + for a in 0..num { + for b in 0..num { + if a < b { + let pa = pmf1[a] * frac; + let pab = pa * pmf1[b] / (sum - pmf1[a]); + + let pb = pmf1[b] * frac; + let pba = pb * pmf1[a] / (sum - pmf1[b]); + + ac += pab + pba; + } + cdf.push(ac); + } + } + assert!((cdf.last().unwrap() - 1.0).abs() < 1e-9); + + let cdf = |i| { + if i < 0 { + 0.0 + } else { + cdf[i as usize] + } + }; + + test_discrete(0, distr, cdf); + } + + test_weights(100, |_| 1.0); + test_weights(100, |i| ((i + 1) as f64).ln()); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); + test_weights(10, |i| ((i + 1) as f64).powi(-8)); +} + +#[test] +fn choose_iterator() { + struct Adapter(I); + impl> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + IteratorRandom::choose(self.0.clone(), rng).unwrap() + } + } + + let distr = Adapter((0..100).map(|i| i as i64)); + test_discrete(0, distr, make_cdf(100, |_| 1.0)); +} + +#[test] +fn choose_stable_iterator() { + struct Adapter(I); + impl> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + IteratorRandom::choose_stable(self.0.clone(), rng).unwrap() + } + } + + let distr = Adapter((0..100).map(|i| i as i64)); + test_discrete(0, distr, make_cdf(100, |_| 1.0)); +} + +#[test] +fn choose_two_iterator() { + struct Adapter(I); + impl> Distribution for Adapter { + fn sample(&self, rng: &mut R) -> i64 { + let mut buf = [0; 2]; + IteratorRandom::choose_multiple_fill(self.0.clone(), rng, &mut buf); + buf.sort_unstable(); + assert!(buf[0] < 99 && buf[1] >= 1); + let a = buf[0]; + 4950 - (99 - a) * (100 - a) / 2 + buf[1] - a - 1 + } + } + + let distr = Adapter((0..100).map(|i| i as i64)); + + test_discrete( + 0, + distr, + |i| if i < 0 { 0.0 } else { (i + 1) as f64 / 4950.0 }, + ); +} diff --git a/distr_test/tests/zeta.rs b/distr_test/tests/zeta.rs new file mode 100644 index 00000000000..6e5ab1f594e --- /dev/null +++ b/distr_test/tests/zeta.rs @@ -0,0 +1,56 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +mod ks; +use ks::test_discrete; + +#[test] +fn zeta() { + fn cdf(k: i64, s: f64) -> f64 { + use spfunc::zeta::zeta as zeta_func; + if k < 1 { + return 0.0; + } + + gen_harmonic(k as u64, s) / zeta_func(s) + } + + let parameters = [2.0, 3.7, 5.0, 100.0]; + + for (seed, s) in parameters.into_iter().enumerate() { + let dist = rand_distr::Zeta::new(s).unwrap(); + test_discrete(seed as u64, dist, |k| cdf(k, s)); + } +} + +#[test] +fn zipf() { + fn cdf(k: i64, n: u64, s: f64) -> f64 { + if k < 1 { + return 0.0; + } + if k > n as i64 { + return 1.0; + } + gen_harmonic(k as u64, s) / gen_harmonic(n, s) + } + + let parameters = [(1000, 1.0), (500, 2.0), (1000, 0.5)]; + + for (seed, (n, x)) in parameters.into_iter().enumerate() { + let dist = rand_distr::Zipf::new(n as f64, x).unwrap(); + test_discrete(seed as u64, dist, |k| cdf(k, n, x)); + } +} + +fn gen_harmonic(n: u64, m: f64) -> f64 { + match n { + 0 => 1.0, + _ => (0..n).fold(0.0, |acc, x| acc + (x as f64 + 1.0).powf(-m)), + } +} diff --git a/examples/monte-carlo.rs b/examples/monte-carlo.rs index 70560d0fab9..d5b898f17f0 100644 --- a/examples/monte-carlo.rs +++ b/examples/monte-carlo.rs @@ -23,14 +23,11 @@ //! We can use the above fact to estimate the value of π: pick many points in //! the square at random, calculate the fraction that fall within the circle, //! and multiply this fraction by 4. - -#![cfg(feature = "std")] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; fn main() { - let range = Uniform::new(-1.0f64, 1.0); - let mut rng = rand::thread_rng(); + let range = Uniform::new(-1.0f64, 1.0).unwrap(); + let mut rng = rand::rng(); let total = 1_000_000; let mut in_circle = 0; diff --git a/examples/monty-hall.rs b/examples/monty-hall.rs index 30e2f44d154..0a6d033739c 100644 --- a/examples/monty-hall.rs +++ b/examples/monty-hall.rs @@ -26,9 +26,7 @@ //! //! [Monty Hall Problem]: https://en.wikipedia.org/wiki/Monty_Hall_problem -#![cfg(feature = "std")] - -use rand::distributions::{Distribution, Uniform}; +use rand::distr::{Distribution, Uniform}; use rand::Rng; struct SimulationResult { @@ -47,7 +45,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) -> SimulationResult let open = game_host_open(car, choice, rng); // Shall we switch? - let switch = rng.gen(); + let switch = rng.random(); if switch { choice = switch_door(choice, open); } @@ -61,7 +59,7 @@ fn simulate(random_door: &Uniform, rng: &mut R) -> SimulationResult // Returns the door the game host opens given our choice and knowledge of // where the car is. The game host will never open the door with the car. fn game_host_open(car: u32, choice: u32, rng: &mut R) -> u32 { - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; *free_doors(&[car, choice]).choose(rng).unwrap() } @@ -79,8 +77,8 @@ fn main() { // The estimation will be more accurate with more simulations let num_simulations = 10000; - let mut rng = rand::thread_rng(); - let random_door = Uniform::new(0u32, 3); + let mut rng = rand::rng(); + let random_door = Uniform::new(0u32, 3).unwrap(); let (mut switch_wins, mut switch_losses) = (0, 0); let (mut keep_wins, mut keep_losses) = (0, 0); diff --git a/examples/rayon-monte-carlo.rs b/examples/rayon-monte-carlo.rs new file mode 100644 index 00000000000..31d8e681067 --- /dev/null +++ b/examples/rayon-monte-carlo.rs @@ -0,0 +1,80 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2018 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! # Monte Carlo estimation of π with a chosen seed and rayon for parallelism +//! +//! Imagine that we have a square with sides of length 2 and a unit circle +//! (radius = 1), both centered at the origin. The areas are: +//! +//! ```text +//! area of circle = πr² = π * r * r = π +//! area of square = 2² = 4 +//! ``` +//! +//! The circle is entirely within the square, so if we sample many points +//! randomly from the square, roughly π / 4 of them should be inside the circle. +//! +//! We can use the above fact to estimate the value of π: pick many points in +//! the square at random, calculate the fraction that fall within the circle, +//! and multiply this fraction by 4. +//! +//! Note on determinism: +//! It's slightly tricky to build a parallel simulation using Rayon +//! which is both efficient *and* reproducible. +//! +//! Rayon's ParallelIterator api does not guarantee that the work will be +//! batched into identical batches on every run, so we can't simply use +//! map_init to construct one RNG per Rayon batch. +//! +//! Instead, we do our own batching, so that a Rayon work item becomes a +//! batch. Then we can fix our rng stream to the batched work item. +//! Batching amortizes the cost of constructing the Rng from a fixed seed +//! over BATCH_SIZE trials. Manually batching also turns out to be faster +//! for the nondeterministic version of this program as well. + +use rand::distr::{Distribution, Uniform}; +use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng}; +use rayon::prelude::*; + +static SEED: u64 = 0; +static BATCH_SIZE: u64 = 10_000; +static BATCHES: u64 = 1000; + +fn main() { + let range = Uniform::new(-1.0f64, 1.0).unwrap(); + + let in_circle = (0..BATCHES) + .into_par_iter() + .map(|i| { + let mut rng = ChaCha8Rng::seed_from_u64(SEED); + // We chose ChaCha because it's fast, has suitable statistical properties for simulation, + // and because it supports this set_stream() api, which lets us choose a different stream + // per work item. ChaCha supports 2^64 independent streams. + rng.set_stream(i); + let mut count = 0; + for _ in 0..BATCH_SIZE { + let a = range.sample(&mut rng); + let b = range.sample(&mut rng); + if a * a + b * b <= 1.0 { + count += 1; + } + } + count + }) + .sum::(); + + // assert this is deterministic + assert_eq!(in_circle, 7852263); + + // prints something close to 3.14159... + println!( + "π is approximately {}", + 4. * (in_circle as f64) / ((BATCH_SIZE * BATCHES) as f64) + ); +} diff --git a/rand_chacha/CHANGELOG.md b/rand_chacha/CHANGELOG.md index 8a073900765..7965cf7640e 100644 --- a/rand_chacha/CHANGELOG.md +++ b/rand_chacha/CHANGELOG.md @@ -4,6 +4,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Update to `rand_core` v0.9.0 (#1558) +- Feature `std` now implies feature `rand_core/std` (#1153) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### Other changes +- Remove usage of `unsafe` in `fn generate` (#1181) then optimise for AVX2 (~4-7%) (#1192) +- Revise crate docs (#1454) + +## [0.3.1] - 2021-06-09 +- add getters corresponding to existing setters: `get_seed`, `get_stream` (#1124) +- add serde support, gated by the `serde1` feature (#1124) +- ensure expected layout via `repr(transparent)` (#1120) + ## [0.3.0] - 2020-12-08 - Bump `rand_core` version to 0.6.0 - Bump MSRV to 1.36 (#1011) diff --git a/rand_chacha/Cargo.toml b/rand_chacha/Cargo.toml index 0a653113d85..7052dd48e4b 100644 --- a/rand_chacha/Cargo.toml +++ b/rand_chacha/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_chacha" -version = "0.3.0" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers", "The CryptoCorrosion Contributors"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,13 +12,25 @@ ChaCha random number generator """ keywords = ["random", "rng", "chacha"] categories = ["algorithms", "no-std"] -edition = "2018" +edition = "2021" +rust-version = "1.63" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--generate-link-to-definition"] [dependencies] -rand_core = { path = "../rand_core", version = "0.6.0" } -ppv-lite86 = { version = "0.2.8", default-features = false, features = ["simd"] } +rand_core = { path = "../rand_core", version = "0.9.0" } +ppv-lite86 = { version = "0.2.14", default-features = false, features = ["simd"] } +serde = { version = "1.0", features = ["derive"], optional = true } + +[dev-dependencies] +# Only to test serde +serde_json = "1.0" +rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } [features] default = ["std"] -std = ["ppv-lite86/std"] -simd = [] # deprecated +os_rng = ["rand_core/os_rng"] +std = ["ppv-lite86/std", "rand_core/std"] +serde = ["dep:serde"] diff --git a/rand_chacha/LICENSE-APACHE b/rand_chacha/LICENSE-APACHE index 17d74680f8c..494ad3bfdfe 100644 --- a/rand_chacha/LICENSE-APACHE +++ b/rand_chacha/LICENSE-APACHE @@ -174,28 +174,3 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rand_chacha/README.md b/rand_chacha/README.md index edd754d791e..167417f85c8 100644 --- a/rand_chacha/README.md +++ b/rand_chacha/README.md @@ -1,11 +1,10 @@ # rand_chacha -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_chacha.svg)](https://crates.io/crates/rand_chacha) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_chacha) [![API](https://docs.rs/rand_chacha/badge.svg)](https://docs.rs/rand_chacha) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) A cryptographically secure random number generator that uses the ChaCha algorithm. @@ -36,7 +35,8 @@ Links: `rand_chacha` is `no_std` compatible when disabling default features; the `std` feature can be explicitly required to re-enable `std` support. Using `std` -allows detection of CPU features and thus better optimisation. +allows detection of CPU features and thus better optimisation. Using `std` +also enables `os_rng` functionality, such as `ChaCha20Rng::from_os_rng()`. # License diff --git a/rand_chacha/src/chacha.rs b/rand_chacha/src/chacha.rs index 17bcc5528d6..91d3cd628d2 100644 --- a/rand_chacha/src/chacha.rs +++ b/rand_chacha/src/chacha.rs @@ -8,25 +8,24 @@ //! The ChaCha random number generator. -#[cfg(not(feature = "std"))] use core; -#[cfg(feature = "std")] use std as core; - -use self::core::fmt; use crate::guts::ChaCha; -use rand_core::block::{BlockRng, BlockRngCore}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use core::fmt; +use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; -const STREAM_PARAM_NONCE: u32 = 1; -const STREAM_PARAM_BLOCK: u32 = 0; +#[cfg(feature = "serde")] +use serde::{Deserialize, Deserializer, Serialize, Serializer}; // NB. this must remain consistent with some currently hard-coded numbers in this module const BUF_BLOCKS: u8 = 4; // number of 32-bit words per ChaCha block (fixed by algorithm definition) const BLOCK_WORDS: u8 = 16; +#[repr(transparent)] pub struct Array64([T; 64]); impl Default for Array64 -where T: Default +where + T: Default, { #[rustfmt::skip] fn default() -> Self { @@ -53,7 +52,8 @@ impl AsMut<[T]> for Array64 { } } impl Clone for Array64 -where T: Copy + Default +where + T: Copy + Default, { fn clone(&self) -> Self { let mut new = Self::default(); @@ -68,7 +68,7 @@ impl fmt::Debug for Array64 { } macro_rules! chacha_impl { - ($ChaChaXCore:ident, $ChaChaXRng:ident, $rounds:expr, $doc:expr) => { + ($ChaChaXCore:ident, $ChaChaXRng:ident, $rounds:expr, $doc:expr, $abst:ident,) => { #[doc=$doc] #[derive(Clone, PartialEq, Eq)] pub struct $ChaChaXCore { @@ -85,27 +85,25 @@ macro_rules! chacha_impl { impl BlockRngCore for $ChaChaXCore { type Item = u32; type Results = Array64; + #[inline] fn generate(&mut self, r: &mut Self::Results) { - // Fill slice of words by writing to equivalent slice of bytes, then fixing endianness. - self.state.refill4($rounds, unsafe { - &mut *(&mut *r as *mut Array64 as *mut [u8; 256]) - }); - for x in r.as_mut() { - *x = x.to_le(); - } + self.state.refill4($rounds, &mut r.0); } } impl SeedableRng for $ChaChaXCore { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { - $ChaChaXCore { state: ChaCha::new(&seed, &[0u8; 8]) } + $ChaChaXCore { + state: ChaCha::new(&seed, &[0u8; 8]), + } } } - impl CryptoRng for $ChaChaXCore {} + impl CryptoBlockRng for $ChaChaXCore {} /// A cryptographically secure random number generator that uses the ChaCha algorithm. /// @@ -152,6 +150,7 @@ macro_rules! chacha_impl { impl SeedableRng for $ChaChaXRng { type Seed = [u8; 32]; + #[inline] fn from_seed(seed: Self::Seed) -> Self { let core = $ChaChaXCore::from_seed(seed); @@ -166,23 +165,21 @@ macro_rules! chacha_impl { fn next_u32(&mut self) -> u32 { self.rng.next_u32() } + #[inline] fn next_u64(&mut self) -> u64 { self.rng.next_u64() } + #[inline] fn fill_bytes(&mut self, bytes: &mut [u8]) { self.rng.fill_bytes(bytes) } - #[inline] - fn try_fill_bytes(&mut self, bytes: &mut [u8]) -> Result<(), Error> { - self.rng.try_fill_bytes(bytes) - } } impl $ChaChaXRng { // The buffer is a 4-block window, i.e. it is always at a block-aligned position in the - // stream but if the stream has been seeked it may not be self-aligned. + // stream but if the stream has been sought it may not be self-aligned. /// Get the offset from the start of the stream, in 32-bit words. /// @@ -193,7 +190,7 @@ macro_rules! chacha_impl { #[inline] pub fn get_word_pos(&self) -> u128 { let buf_start_block = { - let buf_end_block = self.rng.core.state.get_stream_param(STREAM_PARAM_BLOCK); + let buf_end_block = self.rng.core.state.get_block_pos(); u64::wrapping_sub(buf_end_block, BUF_BLOCKS.into()) }; let (buf_offset_blocks, block_offset_words) = { @@ -215,11 +212,9 @@ macro_rules! chacha_impl { #[inline] pub fn set_word_pos(&mut self, word_offset: u128) { let block = (word_offset / u128::from(BLOCK_WORDS)) as u64; + self.rng.core.state.set_block_pos(block); self.rng - .core - .state - .set_stream_param(STREAM_PARAM_BLOCK, block); - self.rng.generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); + .generate_and_set((word_offset % u128::from(BLOCK_WORDS)) as usize); } /// Set the stream number. @@ -235,15 +230,24 @@ macro_rules! chacha_impl { /// indirectly via `set_word_pos`), but this is not directly supported. #[inline] pub fn set_stream(&mut self, stream: u64) { - self.rng - .core - .state - .set_stream_param(STREAM_PARAM_NONCE, stream); + self.rng.core.state.set_nonce(stream); if self.rng.index() != 64 { let wp = self.get_word_pos(); self.set_word_pos(wp); } } + + /// Get the stream number. + #[inline] + pub fn get_stream(&self) -> u64 { + self.rng.core.state.get_nonce() + } + + /// Get the seed. + #[inline] + pub fn get_seed(&self) -> [u8; 32] { + self.rng.core.state.get_seed() + } } impl CryptoRng for $ChaChaXRng {} @@ -258,24 +262,151 @@ macro_rules! chacha_impl { impl PartialEq<$ChaChaXRng> for $ChaChaXRng { fn eq(&self, rhs: &$ChaChaXRng) -> bool { - self.rng.core.state.stream64_eq(&rhs.rng.core.state) - && self.get_word_pos() == rhs.get_word_pos() + let a: $abst::$ChaChaXRng = self.into(); + let b: $abst::$ChaChaXRng = rhs.into(); + a == b } } impl Eq for $ChaChaXRng {} - } + + #[cfg(feature = "serde")] + impl Serialize for $ChaChaXRng { + fn serialize(&self, s: S) -> Result + where + S: Serializer, + { + $abst::$ChaChaXRng::from(self).serialize(s) + } + } + #[cfg(feature = "serde")] + impl<'de> Deserialize<'de> for $ChaChaXRng { + fn deserialize(d: D) -> Result + where + D: Deserializer<'de>, + { + $abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x)) + } + } + + mod $abst { + #[cfg(feature = "serde")] + use serde::{Deserialize, Serialize}; + + // The abstract state of a ChaCha stream, independent of implementation choices. The + // comparison and serialization of this object is considered a semver-covered part of + // the API. + #[derive(Debug, PartialEq, Eq)] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] + pub(crate) struct $ChaChaXRng { + seed: [u8; 32], + stream: u64, + word_pos: u128, + } + + impl From<&super::$ChaChaXRng> for $ChaChaXRng { + // Forget all information about the input except what is necessary to determine the + // outputs of any sequence of pub API calls. + fn from(r: &super::$ChaChaXRng) -> Self { + Self { + seed: r.get_seed(), + stream: r.get_stream(), + word_pos: r.get_word_pos(), + } + } + } + + impl From<&$ChaChaXRng> for super::$ChaChaXRng { + // Construct one of the possible concrete RNGs realizing an abstract state. + fn from(a: &$ChaChaXRng) -> Self { + use rand_core::SeedableRng; + let mut r = Self::from_seed(a.seed); + r.set_stream(a.stream); + r.set_word_pos(a.word_pos); + r + } + } + } + }; } -chacha_impl!(ChaCha20Core, ChaCha20Rng, 10, "ChaCha with 20 rounds"); -chacha_impl!(ChaCha12Core, ChaCha12Rng, 6, "ChaCha with 12 rounds"); -chacha_impl!(ChaCha8Core, ChaCha8Rng, 4, "ChaCha with 8 rounds"); +chacha_impl!( + ChaCha20Core, + ChaCha20Rng, + 10, + "ChaCha with 20 rounds", + abstract20, +); +chacha_impl!( + ChaCha12Core, + ChaCha12Rng, + 6, + "ChaCha with 12 rounds", + abstract12, +); +chacha_impl!( + ChaCha8Core, + ChaCha8Rng, + 4, + "ChaCha with 8 rounds", + abstract8, +); #[cfg(test)] mod test { use rand_core::{RngCore, SeedableRng}; + #[cfg(feature = "serde")] + use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng}; + type ChaChaRng = super::ChaCha20Rng; + #[cfg(feature = "serde")] + #[test] + fn test_chacha_serde_roundtrip() { + let seed = [ + 1, 0, 52, 0, 0, 0, 0, 0, 1, 0, 10, 0, 22, 32, 0, 0, 2, 0, 55, 49, 0, 11, 0, 0, 3, 0, 0, + 0, 0, 0, 2, 92, + ]; + let mut rng1 = ChaCha20Rng::from_seed(seed); + let mut rng2 = ChaCha12Rng::from_seed(seed); + let mut rng3 = ChaCha8Rng::from_seed(seed); + + let encoded1 = serde_json::to_string(&rng1).unwrap(); + let encoded2 = serde_json::to_string(&rng2).unwrap(); + let encoded3 = serde_json::to_string(&rng3).unwrap(); + + let mut decoded1: ChaCha20Rng = serde_json::from_str(&encoded1).unwrap(); + let mut decoded2: ChaCha12Rng = serde_json::from_str(&encoded2).unwrap(); + let mut decoded3: ChaCha8Rng = serde_json::from_str(&encoded3).unwrap(); + + assert_eq!(rng1, decoded1); + assert_eq!(rng2, decoded2); + assert_eq!(rng3, decoded3); + + assert_eq!(rng1.next_u32(), decoded1.next_u32()); + assert_eq!(rng2.next_u32(), decoded2.next_u32()); + assert_eq!(rng3.next_u32(), decoded3.next_u32()); + } + + // This test validates that: + // 1. a hard-coded serialization demonstrating the format at time of initial release can still + // be deserialized to a ChaChaRng + // 2. re-serializing the resultant object produces exactly the original string + // + // Condition 2 is stronger than necessary: an equivalent serialization (e.g. with field order + // permuted, or whitespace differences) would also be admissible, but would fail this test. + // However testing for equivalence of serialized data is difficult, and there shouldn't be any + // reason we need to violate the stronger-than-needed condition, e.g. by changing the field + // definition order. + #[cfg(feature = "serde")] + #[test] + fn test_chacha_serde_format_stability() { + let j = r#"{"seed":[4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8,15,16,23,42,4,8],"stream":27182818284,"word_pos":314159265359}"#; + let r: ChaChaRng = serde_json::from_str(j).unwrap(); + let j1 = serde_json::to_string(&r).unwrap(); + assert_eq!(j, j1); + } + #[test] fn test_chacha_construction() { let seed = [ @@ -285,7 +416,7 @@ mod test { let mut rng1 = ChaChaRng::from_seed(seed); assert_eq!(rng1.next_u32(), 137206642); - let mut rng2 = ChaChaRng::from_rng(rng1).unwrap(); + let mut rng2 = ChaChaRng::from_rng(&mut rng1); assert_eq!(rng2.next_u32(), 1325750369); } @@ -481,7 +612,7 @@ mod test { #[test] fn test_chacha_word_pos_wrap_exact() { - use super::{BUF_BLOCKS, BLOCK_WORDS}; + use super::{BLOCK_WORDS, BUF_BLOCKS}; let mut rng = ChaChaRng::from_seed(Default::default()); // refilling the buffer in set_word_pos will wrap the block counter to 0 let last_block = (1 << 68) - u128::from(BUF_BLOCKS * BLOCK_WORDS); @@ -506,4 +637,15 @@ mod test { rng.set_word_pos(0); assert_eq!(rng.get_word_pos(), 0); } + + #[test] + fn test_trait_objects() { + use rand_core::CryptoRng; + + let mut rng1 = ChaChaRng::from_seed(Default::default()); + let rng2 = &mut rng1.clone() as &mut dyn CryptoRng; + for _ in 0..1000 { + assert_eq!(rng1.next_u64(), rng2.next_u64()); + } + } } diff --git a/rand_chacha/src/guts.rs b/rand_chacha/src/guts.rs index 27ff957a92c..d077225c625 100644 --- a/rand_chacha/src/guts.rs +++ b/rand_chacha/src/guts.rs @@ -12,15 +12,20 @@ use ppv_lite86::{dispatch, dispatch_light128}; pub use ppv_lite86::Machine; -use ppv_lite86::{vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4}; +use ppv_lite86::{ + vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector, +}; -pub(crate) const BLOCK: usize = 64; +pub(crate) const BLOCK: usize = 16; pub(crate) const BLOCK64: u64 = BLOCK as u64; const LOG2_BUFBLOCKS: u64 = 2; const BUFBLOCKS: u64 = 1 << LOG2_BUFBLOCKS; pub(crate) const BUFSZ64: u64 = BLOCK64 * BUFBLOCKS; pub(crate) const BUFSZ: usize = BUFSZ64 as usize; +const STREAM_PARAM_NONCE: u32 = 1; +const STREAM_PARAM_BLOCK: u32 = 0; + #[derive(Clone, PartialEq, Eq)] pub struct ChaCha { pub(crate) b: vec128_storage, @@ -70,96 +75,115 @@ impl ChaCha { init_chacha(key, nonce) } + /// Produce 4 blocks of output, advancing the state #[inline(always)] - fn pos64(&self, m: M) -> u64 { - let d: M::u32x4 = m.unpack(self.d); - ((d.extract(1) as u64) << 32) | d.extract(0) as u64 + pub fn refill4(&mut self, drounds: u32, out: &mut [u32; BUFSZ]) { + refill_wide(self, drounds, out) } - /// Produce 4 blocks of output, advancing the state #[inline(always)] - pub fn refill4(&mut self, drounds: u32, out: &mut [u8; BUFSZ]) { - refill_wide(self, drounds, out) + pub fn set_block_pos(&mut self, value: u64) { + set_stream_param(self, STREAM_PARAM_BLOCK, value) } #[inline(always)] - pub fn set_stream_param(&mut self, param: u32, value: u64) { - set_stream_param(self, param, value) + pub fn get_block_pos(&self) -> u64 { + get_stream_param(self, STREAM_PARAM_BLOCK) } #[inline(always)] - pub fn get_stream_param(&self, param: u32) -> u64 { - get_stream_param(self, param) + pub fn set_nonce(&mut self, value: u64) { + set_stream_param(self, STREAM_PARAM_NONCE, value) } - /// Return whether rhs is equal in all parameters except current 64-bit position. - #[inline] - pub fn stream64_eq(&self, rhs: &Self) -> bool { - let self_d: [u32; 4] = self.d.into(); - let rhs_d: [u32; 4] = rhs.d.into(); - self.b == rhs.b && self.c == rhs.c && self_d[3] == rhs_d[3] && self_d[2] == rhs_d[2] + #[inline(always)] + pub fn get_nonce(&self) -> u64 { + get_stream_param(self, STREAM_PARAM_NONCE) + } + + #[inline(always)] + pub fn get_seed(&self) -> [u8; 32] { + get_seed(self) } } -#[allow(clippy::many_single_char_names)] +// This implementation is platform-independent. #[inline(always)] -fn refill_wide_impl( - m: Mach, state: &mut ChaCha, drounds: u32, out: &mut [u8; BUFSZ], -) { - let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]); - let mut pos = state.pos64(m); - let d0: Mach::u32x4 = m.unpack(state.d); +#[cfg(target_endian = "big")] +fn add_pos(_m: Mach, d0: Mach::u32x4, i: u64) -> Mach::u32x4 { + let pos0 = ((d0.extract(1) as u64) << 32) | d0.extract(0) as u64; + let pos = pos0.wrapping_add(i); + d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0) +} +#[inline(always)] +#[cfg(target_endian = "big")] +fn d0123(m: Mach, d: vec128_storage) -> Mach::u32x4x4 { + let d0: Mach::u32x4 = m.unpack(d); + let mut pos = ((d0.extract(1) as u64) << 32) | d0.extract(0) as u64; pos = pos.wrapping_add(1); let d1 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); pos = pos.wrapping_add(1); let d2 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); pos = pos.wrapping_add(1); let d3 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); + Mach::u32x4x4::from_lanes([d0, d1, d2, d3]) +} + +// Pos is packed into the state vectors as a little-endian u64, +// so on LE platforms we can use native vector ops to increment it. +#[inline(always)] +#[cfg(target_endian = "little")] +fn add_pos(m: Mach, d: Mach::u32x4, i: u64) -> Mach::u32x4 { + let d0: Mach::u64x2 = m.unpack(d.into()); + let incr = m.vec([i, 0]); + m.unpack((d0 + incr).into()) +} +#[inline(always)] +#[cfg(target_endian = "little")] +fn d0123(m: Mach, d: vec128_storage) -> Mach::u32x4x4 { + let d0: Mach::u64x2 = m.unpack(d); + let incr = + Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); + m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into()) +} +#[allow(clippy::many_single_char_names)] +#[inline(always)] +fn refill_wide_impl( + m: Mach, + state: &mut ChaCha, + drounds: u32, + out: &mut [u32; BUFSZ], +) { + let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]); let b = m.unpack(state.b); let c = m.unpack(state.c); let mut x = State { a: Mach::u32x4x4::from_lanes([k, k, k, k]), b: Mach::u32x4x4::from_lanes([b, b, b, b]), c: Mach::u32x4x4::from_lanes([c, c, c, c]), - d: m.unpack(Mach::u32x4x4::from_lanes([d0, d1, d2, d3]).into()), + d: d0123(m, state.d), }; for _ in 0..drounds { x = round(x); x = undiagonalize(round(diagonalize(x))); } - let mut pos = state.pos64(m); - let d0: Mach::u32x4 = m.unpack(state.d); - pos = pos.wrapping_add(1); - let d1 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); - pos = pos.wrapping_add(1); - let d2 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); - pos = pos.wrapping_add(1); - let d3 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); - pos = pos.wrapping_add(1); - let d4 = d0.insert((pos >> 32) as u32, 1).insert(pos as u32, 0); - - let (a, b, c, d) = ( - x.a.to_lanes(), - x.b.to_lanes(), - x.c.to_lanes(), - x.d.to_lanes(), - ); + let kk = Mach::u32x4x4::from_lanes([k, k, k, k]); let sb = m.unpack(state.b); + let sb = Mach::u32x4x4::from_lanes([sb, sb, sb, sb]); let sc = m.unpack(state.c); - let sd = [m.unpack(state.d), d1, d2, d3]; - state.d = d4.into(); - let mut words = out.chunks_exact_mut(16); - for ((((&a, &b), &c), &d), &sd) in a.iter().zip(&b).zip(&c).zip(&d).zip(&sd) { - (a + k).write_le(words.next().unwrap()); - (b + sb).write_le(words.next().unwrap()); - (c + sc).write_le(words.next().unwrap()); - (d + sd).write_le(words.next().unwrap()); - } + let sc = Mach::u32x4x4::from_lanes([sc, sc, sc, sc]); + let sd = d0123(m, state.d); + let results = Mach::u32x4x4::transpose4(x.a + kk, x.b + sb, x.c + sc, x.d + sd); + out[0..16].copy_from_slice(&results.0.to_scalars()); + out[16..32].copy_from_slice(&results.1.to_scalars()); + out[32..48].copy_from_slice(&results.2.to_scalars()); + out[48..64].copy_from_slice(&results.3.to_scalars()); + state.d = add_pos(m, sd.to_lanes()[0], 4).into(); } dispatch!(m, Mach, { - fn refill_wide(state: &mut ChaCha, drounds: u32, out: &mut [u8; BUFSZ]) { + fn refill_wide(state: &mut ChaCha, drounds: u32, out: &mut [u32; BUFSZ]) { refill_wide_impl(m, state, drounds, out); } }); @@ -205,6 +229,17 @@ dispatch_light128!(m, Mach, { } }); +dispatch_light128!(m, Mach, { + fn get_seed(state: &ChaCha) -> [u8; 32] { + let b: Mach::u32x4 = m.unpack(state.b); + let c: Mach::u32x4 = m.unpack(state.c); + let mut key = [0u8; 32]; + b.write_le(&mut key[..16]); + c.write_le(&mut key[16..]); + key + } +}); + fn read_u32le(xs: &[u8]) -> u32 { assert_eq!(xs.len(), 4); u32::from(xs[0]) | (u32::from(xs[1]) << 8) | (u32::from(xs[2]) << 16) | (u32::from(xs[3]) << 24) diff --git a/rand_chacha/src/lib.rs b/rand_chacha/src/lib.rs index 24125b45e10..24ddd601d27 100644 --- a/rand_chacha/src/lib.rs +++ b/rand_chacha/src/lib.rs @@ -6,13 +6,86 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The ChaCha random number generator. +//! The ChaCha random number generators. +//! +//! These are native Rust implementations of RNGs derived from the +//! [ChaCha stream ciphers] by D J Bernstein. +//! +//! ## Generators +//! +//! This crate provides 8-, 12- and 20-round variants of generators via a "core" +//! implementation (of [`BlockRngCore`]), each with an associated "RNG" type +//! (implementing [`RngCore`]). +//! +//! These generators are all deterministic and portable (see [Reproducibility] +//! in the book), with testing against reference vectors. +//! +//! ## Cryptographic (secure) usage +//! +//! Where secure unpredictable generators are required, it is suggested to use +//! [`ChaCha12Rng`] or [`ChaCha20Rng`] and to seed via +//! [`SeedableRng::from_os_rng`]. +//! +//! See also the [Security] chapter in the rand book. The crate is provided +//! "as is", without any form of guarantee, and without a security audit. +//! +//! ## Seeding (construction) +//! +//! Generators implement the [`SeedableRng`] trait. Any method may be used, +//! but note that `seed_from_u64` is not suitable for usage where security is +//! important. Some suggestions: +//! +//! 1. With a fresh seed, **direct from the OS** (implies a syscall): +//! ``` +//! # use {rand_core::SeedableRng, rand_chacha::ChaCha12Rng}; +//! let rng = ChaCha12Rng::from_os_rng(); +//! # let _: ChaCha12Rng = rng; +//! ``` +//! 2. **From a master generator.** This could be [`rand::rng`] +//! (effectively a fresh seed without the need for a syscall on each usage) +//! or a deterministic generator such as [`ChaCha20Rng`]. +//! Beware that should a weak master generator be used, correlations may be +//! detectable between the outputs of its child generators. +//! ```ignore +//! let rng = ChaCha12Rng::from_rng(&mut rand::rng()); +//! ``` +//! +//! See also [Seeding RNGs] in the book. +//! +//! ## Generation +//! +//! Generators implement [`RngCore`], whose methods may be used directly to +//! generate unbounded integer or byte values. +//! ``` +//! use rand_core::{SeedableRng, RngCore}; +//! use rand_chacha::ChaCha12Rng; +//! +//! let mut rng = ChaCha12Rng::from_seed(Default::default()); +//! let x = rng.next_u64(); +//! assert_eq!(x, 0x53f955076a9af49b); +//! ``` +//! +//! It is often more convenient to use the [`rand::Rng`] trait, which provides +//! further functionality. See also the [Random Values] chapter in the book. +//! +//! [ChaCha stream ciphers]: https://cr.yp.to/chacha.html +//! [Reproducibility]: https://rust-random.github.io/book/crate-reprod.html +//! [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +//! [Security]: https://rust-random.github.io/book/guide-rngs.html#security +//! [Random Values]: https://rust-random.github.io/book/guide-values.html +//! [`BlockRngCore`]: rand_core::block::BlockRngCore +//! [`RngCore`]: rand_core::RngCore +//! [`SeedableRng`]: rand_core::SeedableRng +//! [`SeedableRng::from_os_rng`]: rand_core::SeedableRng::from_os_rng +//! [`rand::rng`]: https://docs.rs/rand/latest/rand/fn.rng.html +//! [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html #![doc( html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", html_favicon_url = "https://www.rust-lang.org/favicon.ico", html_root_url = "https://rust-random.github.io/rand/" )] +#![forbid(unsafe_code)] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] diff --git a/rand_core/CHANGELOG.md b/rand_core/CHANGELOG.md index 63a8bb8f694..3b3064db71b 100644 --- a/rand_core/CHANGELOG.md +++ b/rand_core/CHANGELOG.md @@ -4,6 +4,45 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `getrandom` v0.3.0 (#1558) +- Use `zerocopy` to replace some `unsafe` code (#1349, #1393, #1446, #1502) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### API changes +- Allow `rand_core::impls::fill_via_u*_chunks` to mutate source (#1182) +- Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) +- Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) +- Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) +- Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) +- Add bounds `Clone` and `AsRef` to associated type `SeedableRng::Seed` (#1491) + +## [0.6.4] - 2022-09-15 +- Fix unsoundness in `::next_u32` (#1160) +- Reduce use of `unsafe` and improve gen_bytes performance (#1180) +- Add `CryptoRngCore` trait (#1187, #1230) + +## [0.6.3] - 2021-06-15 +### Changed +- Improved bound for `serde` impls on `BlockRng` (#1130) +- Minor doc additions (#1118) + +## [0.6.2] - 2021-02-12 +### Fixed +- Fixed assertions in `le::read_u32_into` and `le::read_u64_into` which could + have allowed buffers not to be fully populated (#1096) + +## [0.6.1] - 2021-01-03 +### Fixed +- Avoid panic when using `RngCore::seed_from_u64` with a seed which is not a + multiple of four (#1082) +### Other +- Enable all stable features in the playground (#1081) + ## [0.6.0] - 2020-12-08 ### Breaking changes - Bump MSRV to 1.36, various code improvements (#1011) diff --git a/rand_core/Cargo.toml b/rand_core/Cargo.toml index a53a1215f46..d1d9e66d7fa 100644 --- a/rand_core/Cargo.toml +++ b/rand_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_core" -version = "0.6.0" +version = "0.9.0" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,19 +12,24 @@ Core random number generator traits and tools for implementation. """ keywords = ["random", "rng"] categories = ["algorithms", "no-std"] -edition = "2018" +edition = "2021" +rust-version = "1.63" + +[package.metadata.docs.rs] +# To build locally: +# RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --no-deps --open +all-features = true +rustdoc-args = ["--generate-link-to-definition"] + +[package.metadata.playground] +all-features = true [features] -std = ["alloc", "getrandom", "getrandom/std"] # use std library; should be default but for above bug -alloc = [] # enables Vec and Box support without std -serde1 = ["serde"] # enables serde for BlockRng wrapper +std = ["getrandom?/std"] +os_rng = ["dep:getrandom"] +serde = ["dep:serde"] # enables serde for BlockRng wrapper [dependencies] serde = { version = "1", features = ["derive"], optional = true } -getrandom = { version = "0.2", optional = true } - -[package.metadata.docs.rs] -# To build locally: -# RUSTDOCFLAGS="--cfg doc_cfg" cargo +nightly doc --all-features --no-deps --open -all-features = true -rustdoc-args = ["--cfg", "doc_cfg"] +getrandom = { version = "0.3.0", optional = true } +zerocopy = { version = "0.8.0", default-features = false } diff --git a/rand_core/LICENSE-APACHE b/rand_core/LICENSE-APACHE index 17d74680f8c..455787c2334 100644 --- a/rand_core/LICENSE-APACHE +++ b/rand_core/LICENSE-APACHE @@ -185,17 +185,3 @@ APPENDIX: How to apply the Apache License to your work. file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rand_core/README.md b/rand_core/README.md index d32dd6853d0..b95287c4e70 100644 --- a/rand_core/README.md +++ b/rand_core/README.md @@ -1,11 +1,10 @@ # rand_core -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_core.svg)](https://crates.io/crates/rand_core) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_core) [![API](https://docs.rs/rand_core/badge.svg)](https://docs.rs/rand_core) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Core traits and error types of the [rand] library, plus tools for implementing RNGs. @@ -43,34 +42,9 @@ The traits and error types are also available via `rand`. The current version is: ``` -rand_core = "0.6.0" +rand_core = "=0.9.0-beta.1" ``` -Rand libs have inter-dependencies and make use of the -[semver trick](https://github.com/dtolnay/semver-trick/) in order to make traits -compatible across crate versions. (This is especially important for `RngCore` -and `SeedableRng`.) A few crate releases are thus compatibility shims, -depending on the *next* lib version (e.g. `rand_core` versions `0.2.2` and -`0.3.1`). This means, for example, that `rand_core_0_4_0::SeedableRng` and -`rand_core_0_3_0::SeedableRng` are distinct, incompatible traits, which can -cause build errors. Usually, running `cargo update` is enough to fix any issues. - -## Crate Features - -`rand_core` supports `no_std` and `alloc`-only configurations, as well as full -`std` functionality. The differences between `no_std` and full `std` are small, -comprising `RngCore` support for `Box` types where `R: RngCore`, -`std::io::Read` support for types supporting `RngCore`, and -extensions to the `Error` type's functionality. - -The `std` feature is *not enabled by default*. This is primarily to avoid build -problems where one crate implicitly requires `rand_core` with `std` support and -another crate requires `rand` *without* `std` support. However, the `rand` crate -continues to enable `std` support by default, both for itself and `rand_core`. - -The `serde1` feature can be used to derive `Serialize` and `Deserialize` for RNG -implementations that use the `BlockRng` or `BlockRng64` wrappers. - # License diff --git a/rand_core/src/block.rs b/rand_core/src/block.rs index 005d071fbb6..aa2252e6da2 100644 --- a/rand_core/src/block.rs +++ b/rand_core/src/block.rs @@ -43,7 +43,7 @@ //! } //! } //! -//! // optionally, also implement CryptoRng for MyRngCore +//! // optionally, also implement CryptoBlockRng for MyRngCore //! //! // Final RNG. //! let mut rng = BlockRng::::seed_from_u64(0); @@ -54,10 +54,9 @@ //! [`fill_bytes`]: RngCore::fill_bytes use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks}; -use crate::{CryptoRng, Error, RngCore, SeedableRng}; -use core::convert::AsRef; +use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore}; use core::fmt; -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// A trait for RNGs which do not generate random numbers individually, but in @@ -77,6 +76,12 @@ pub trait BlockRngCore { fn generate(&mut self, results: &mut Self::Results); } +/// A marker trait used to indicate that an [`RngCore`] implementation is +/// supposed to be cryptographically secure. +/// +/// See [`CryptoRng`] docs for more information. +pub trait CryptoBlockRng: BlockRngCore {} + /// A wrapper type implementing [`RngCore`] for some type implementing /// [`BlockRngCore`] with `u32` array buffer; i.e. this can be used to implement /// a full RNG from just a `generate` function. @@ -92,16 +97,15 @@ pub trait BlockRngCore { /// `BlockRng` has heavily optimized implementations of the [`RngCore`] methods /// reading values from the results buffer, as well as /// calling [`BlockRngCore::generate`] directly on the output array when -/// [`fill_bytes`] / [`try_fill_bytes`] is called on a large array. These methods -/// also handle the bookkeeping of when to generate a new batch of values. +/// [`fill_bytes`] is called on a large array. These methods also handle +/// the bookkeeping of when to generate a new batch of values. /// -/// No whole generated `u32` values are thown away and all values are consumed +/// No whole generated `u32` values are thrown away and all values are consumed /// in-order. [`next_u32`] simply takes the next available `u32` value. /// [`next_u64`] is implemented by combining two `u32` values, least -/// significant first. [`fill_bytes`] and [`try_fill_bytes`] consume a whole -/// number of `u32` values, converting each `u32` to a byte slice in -/// little-endian order. If the requested byte length is not a multiple of 4, -/// some bytes will be discarded. +/// significant first. [`fill_bytes`] consume a whole number of `u32` values, +/// converting each `u32` to a byte slice in little-endian order. If the requested byte +/// length is not a multiple of 4, some bytes will be discarded. /// /// See also [`BlockRng64`] which uses `u64` array buffers. Currently there is /// no direct support for other buffer types. @@ -111,10 +115,15 @@ pub trait BlockRngCore { /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct BlockRng { +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde", + serde( + bound = "for<'x> R: Serialize + Deserialize<'x>, for<'x> R::Results: Serialize + Deserialize<'x>" + ) +)] +pub struct BlockRng { results: R::Results, index: usize, /// The *core* part of the RNG, implementing the `generate` function. @@ -172,10 +181,7 @@ impl BlockRng { } } -impl> RngCore for BlockRng -where - ::Results: AsRef<[u32]> + AsMut<[u32]>, -{ +impl> RngCore for BlockRng { #[inline] fn next_u32(&mut self) -> u32 { if self.index >= self.results.as_ref().len() { @@ -219,19 +225,15 @@ where if self.index >= self.results.as_ref().len() { self.generate_and_set(0); } - let (consumed_u32, filled_u8) = - fill_via_u32_chunks(&self.results.as_ref()[self.index..], &mut dest[read_len..]); + let (consumed_u32, filled_u8) = fill_via_u32_chunks( + &mut self.results.as_mut()[self.index..], + &mut dest[read_len..], + ); self.index += consumed_u32; read_len += filled_u8; } } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } impl SeedableRng for BlockRng { @@ -248,11 +250,18 @@ impl SeedableRng for BlockRng { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: &mut impl RngCore) -> Self { + Self::new(R::from_rng(rng)) + } + + #[inline(always)] + fn try_from_rng(rng: &mut S) -> Result { + R::try_from_rng(rng).map(Self::new) } } +impl> CryptoRng for BlockRng {} + /// A wrapper type implementing [`RngCore`] for some type implementing /// [`BlockRngCore`] with `u64` array buffer; i.e. this can be used to implement /// a full RNG from just a `generate` function. @@ -267,16 +276,14 @@ impl SeedableRng for BlockRng { /// then the other half is then consumed, however both [`next_u64`] and /// [`fill_bytes`] discard the rest of any half-consumed `u64`s when called. /// -/// [`fill_bytes`] and [`try_fill_bytes`] consume a whole number of `u64` -/// values. If the requested length is not a multiple of 8, some bytes will be -/// discarded. +/// [`fill_bytes`] consumes a whole number of `u64` values. If the requested length +/// is not a multiple of 8, some bytes will be discarded. /// /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 /// [`fill_bytes`]: RngCore::fill_bytes -/// [`try_fill_bytes`]: RngCore::try_fill_bytes #[derive(Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct BlockRng64 { results: R::Results, index: usize, @@ -340,33 +347,24 @@ impl BlockRng64 { } } -impl> RngCore for BlockRng64 -where - ::Results: AsRef<[u64]> + AsMut<[u64]>, -{ +impl> RngCore for BlockRng64 { #[inline] fn next_u32(&mut self) -> u32 { - let mut index = self.index * 2 - self.half_used as usize; - if index >= self.results.as_ref().len() * 2 { + let mut index = self.index - self.half_used as usize; + if index >= self.results.as_ref().len() { self.core.generate(&mut self.results); self.index = 0; + index = 0; // `self.half_used` is by definition `false` self.half_used = false; - index = 0; } + let shift = 32 * (self.half_used as usize); + self.half_used = !self.half_used; self.index += self.half_used as usize; - // Index as if this is a u32 slice. - unsafe { - let results = &*(self.results.as_ref() as *const [u64] as *const [u32]); - if cfg!(target_endian = "little") { - *results.get_unchecked(index) - } else { - *results.get_unchecked(index ^ 1) - } - } + (self.results.as_ref()[index] >> shift) as u32 } #[inline] @@ -387,13 +385,13 @@ where let mut read_len = 0; self.half_used = false; while read_len < dest.len() { - if self.index as usize >= self.results.as_ref().len() { + if self.index >= self.results.as_ref().len() { self.core.generate(&mut self.results); self.index = 0; } let (consumed_u64, filled_u8) = fill_via_u64_chunks( - &self.results.as_ref()[self.index as usize..], + &mut self.results.as_mut()[self.index..], &mut dest[read_len..], ); @@ -401,12 +399,6 @@ where read_len += filled_u8; } } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } impl SeedableRng for BlockRng64 { @@ -423,9 +415,124 @@ impl SeedableRng for BlockRng64 { } #[inline(always)] - fn from_rng(rng: S) -> Result { - Ok(Self::new(R::from_rng(rng)?)) + fn from_rng(rng: &mut impl RngCore) -> Self { + Self::new(R::from_rng(rng)) + } + + #[inline(always)] + fn try_from_rng(rng: &mut S) -> Result { + R::try_from_rng(rng).map(Self::new) } } -impl CryptoRng for BlockRng {} +impl> CryptoRng for BlockRng64 {} + +#[cfg(test)] +mod test { + use crate::block::{BlockRng, BlockRng64, BlockRngCore}; + use crate::{RngCore, SeedableRng}; + + #[derive(Debug, Clone)] + struct DummyRng { + counter: u32, + } + + impl BlockRngCore for DummyRng { + type Item = u32; + type Results = [u32; 16]; + + fn generate(&mut self, results: &mut Self::Results) { + for r in results { + *r = self.counter; + self.counter = self.counter.wrapping_add(3511615421); + } + } + } + + impl SeedableRng for DummyRng { + type Seed = [u8; 4]; + + fn from_seed(seed: Self::Seed) -> Self { + DummyRng { + counter: u32::from_le_bytes(seed), + } + } + } + + #[test] + fn blockrng_next_u32_vs_next_u64() { + let mut rng1 = BlockRng::::from_seed([1, 2, 3, 4]); + let mut rng2 = rng1.clone(); + let mut rng3 = rng1.clone(); + + let mut a = [0; 16]; + a[..4].copy_from_slice(&rng1.next_u32().to_le_bytes()); + a[4..12].copy_from_slice(&rng1.next_u64().to_le_bytes()); + a[12..].copy_from_slice(&rng1.next_u32().to_le_bytes()); + + let mut b = [0; 16]; + b[..4].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[4..8].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[8..].copy_from_slice(&rng2.next_u64().to_le_bytes()); + assert_eq!(a, b); + + let mut c = [0; 16]; + c[..8].copy_from_slice(&rng3.next_u64().to_le_bytes()); + c[8..12].copy_from_slice(&rng3.next_u32().to_le_bytes()); + c[12..].copy_from_slice(&rng3.next_u32().to_le_bytes()); + assert_eq!(a, c); + } + + #[derive(Debug, Clone)] + struct DummyRng64 { + counter: u64, + } + + impl BlockRngCore for DummyRng64 { + type Item = u64; + type Results = [u64; 8]; + + fn generate(&mut self, results: &mut Self::Results) { + for r in results { + *r = self.counter; + self.counter = self.counter.wrapping_add(2781463553396133981); + } + } + } + + impl SeedableRng for DummyRng64 { + type Seed = [u8; 8]; + + fn from_seed(seed: Self::Seed) -> Self { + DummyRng64 { + counter: u64::from_le_bytes(seed), + } + } + } + + #[test] + fn blockrng64_next_u32_vs_next_u64() { + let mut rng1 = BlockRng64::::from_seed([1, 2, 3, 4, 5, 6, 7, 8]); + let mut rng2 = rng1.clone(); + let mut rng3 = rng1.clone(); + + let mut a = [0; 16]; + a[..4].copy_from_slice(&rng1.next_u32().to_le_bytes()); + a[4..12].copy_from_slice(&rng1.next_u64().to_le_bytes()); + a[12..].copy_from_slice(&rng1.next_u32().to_le_bytes()); + + let mut b = [0; 16]; + b[..4].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[4..8].copy_from_slice(&rng2.next_u32().to_le_bytes()); + b[8..].copy_from_slice(&rng2.next_u64().to_le_bytes()); + assert_ne!(a, b); + assert_eq!(&a[..4], &b[..4]); + assert_eq!(&a[4..12], &b[8..]); + + let mut c = [0; 16]; + c[..8].copy_from_slice(&rng3.next_u64().to_le_bytes()); + c[8..12].copy_from_slice(&rng3.next_u32().to_le_bytes()); + c[12..].copy_from_slice(&rng3.next_u32().to_le_bytes()); + assert_eq!(b, c); + } +} diff --git a/rand_core/src/error.rs b/rand_core/src/error.rs deleted file mode 100644 index a64c430da8b..00000000000 --- a/rand_core/src/error.rs +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Error types - -use core::fmt; -use core::num::NonZeroU32; - -#[cfg(feature = "std")] use std::boxed::Box; - -/// Error type of random number generators -/// -/// In order to be compatible with `std` and `no_std`, this type has two -/// possible implementations: with `std` a boxed `Error` trait object is stored, -/// while with `no_std` we merely store an error code. -pub struct Error { - #[cfg(feature = "std")] - inner: Box, - #[cfg(not(feature = "std"))] - code: NonZeroU32, -} - -impl Error { - /// Codes at or above this point can be used by users to define their own - /// custom errors. - /// - /// This has a fixed value of `(1 << 31) + (1 << 30) = 0xC000_0000`, - /// therefore the number of values available for custom codes is `1 << 30`. - /// - /// This is identical to [`getrandom::Error::CUSTOM_START`](https://docs.rs/getrandom/latest/getrandom/struct.Error.html#associatedconstant.CUSTOM_START). - pub const CUSTOM_START: u32 = (1 << 31) + (1 << 30); - /// Codes below this point represent OS Errors (i.e. positive i32 values). - /// Codes at or above this point, but below [`Error::CUSTOM_START`] are - /// reserved for use by the `rand` and `getrandom` crates. - /// - /// This is identical to [`getrandom::Error::INTERNAL_START`](https://docs.rs/getrandom/latest/getrandom/struct.Error.html#associatedconstant.INTERNAL_START). - pub const INTERNAL_START: u32 = 1 << 31; - - /// Construct from any type supporting `std::error::Error` - /// - /// Available only when configured with `std`. - /// - /// See also `From`, which is available with and without `std`. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn new(err: E) -> Self - where - E: Into>, - { - Error { inner: err.into() } - } - - /// Reference the inner error (`std` only) - /// - /// When configured with `std`, this is a trivial operation and never - /// panics. Without `std`, this method is simply unavailable. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn inner(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { - &*self.inner - } - - /// Unwrap the inner error (`std` only) - /// - /// When configured with `std`, this is a trivial operation and never - /// panics. Without `std`, this method is simply unavailable. - #[cfg(feature = "std")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] - #[inline] - pub fn take_inner(self) -> Box { - self.inner - } - - /// Extract the raw OS error code (if this error came from the OS) - /// - /// This method is identical to `std::io::Error::raw_os_error()`, except - /// that it works in `no_std` contexts. If this method returns `None`, the - /// error value can still be formatted via the `Diplay` implementation. - #[inline] - pub fn raw_os_error(&self) -> Option { - #[cfg(feature = "std")] - { - if let Some(e) = self.inner.downcast_ref::() { - return e.raw_os_error(); - } - } - match self.code() { - Some(code) if u32::from(code) < Self::INTERNAL_START => Some(u32::from(code) as i32), - _ => None, - } - } - - /// Retrieve the error code, if any. - /// - /// If this `Error` was constructed via `From`, then this method - /// will return this `NonZeroU32` code (for `no_std` this is always the - /// case). Otherwise, this method will return `None`. - #[inline] - pub fn code(&self) -> Option { - #[cfg(feature = "std")] - { - self.inner.downcast_ref::().map(|c| c.0) - } - #[cfg(not(feature = "std"))] - { - Some(self.code) - } - } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "std")] - { - write!(f, "Error {{ inner: {:?} }}", self.inner) - } - #[cfg(all(feature = "getrandom", not(feature = "std")))] - { - getrandom::Error::from(self.code).fmt(f) - } - #[cfg(not(feature = "getrandom"))] - { - write!(f, "Error {{ code: {} }}", self.code) - } - } -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - #[cfg(feature = "std")] - { - write!(f, "{}", self.inner) - } - #[cfg(all(feature = "getrandom", not(feature = "std")))] - { - getrandom::Error::from(self.code).fmt(f) - } - #[cfg(not(feature = "getrandom"))] - { - write!(f, "error code {}", self.code) - } - } -} - -impl From for Error { - #[inline] - fn from(code: NonZeroU32) -> Self { - #[cfg(feature = "std")] - { - Error { - inner: Box::new(ErrorCode(code)), - } - } - #[cfg(not(feature = "std"))] - { - Error { code } - } - } -} - -#[cfg(feature = "getrandom")] -impl From for Error { - #[inline] - fn from(error: getrandom::Error) -> Self { - #[cfg(feature = "std")] - { - Error { - inner: Box::new(error), - } - } - #[cfg(not(feature = "std"))] - { - Error { code: error.code() } - } - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error { - #[inline] - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.inner.source() - } -} - -#[cfg(feature = "std")] -impl From for std::io::Error { - #[inline] - fn from(error: Error) -> Self { - if let Some(code) = error.raw_os_error() { - std::io::Error::from_raw_os_error(code) - } else { - std::io::Error::new(std::io::ErrorKind::Other, error) - } - } -} - -#[cfg(feature = "std")] -#[derive(Debug, Copy, Clone)] -struct ErrorCode(NonZeroU32); - -#[cfg(feature = "std")] -impl fmt::Display for ErrorCode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "error code {}", self.0) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for ErrorCode {} - -#[cfg(test)] -mod test { - #[cfg(feature = "getrandom")] - #[test] - fn test_error_codes() { - // Make sure the values are the same as in `getrandom`. - assert_eq!(super::Error::CUSTOM_START, getrandom::Error::CUSTOM_START); - assert_eq!(super::Error::INTERNAL_START, getrandom::Error::INTERNAL_START); - } -} diff --git a/rand_core/src/impls.rs b/rand_core/src/impls.rs index 2588a72ea3f..584a4c16f10 100644 --- a/rand_core/src/impls.rs +++ b/rand_core/src/impls.rs @@ -19,6 +19,7 @@ use crate::RngCore; use core::cmp::min; +use zerocopy::{Immutable, IntoBytes}; /// Implement `next_u64` via `next_u32`, little-endian order. pub fn next_u64_via_u32(rng: &mut R) -> u64 { @@ -52,36 +53,42 @@ pub fn fill_bytes_via_next(rng: &mut R, dest: &mut [u8]) { } } -macro_rules! fill_via_chunks { - ($src:expr, $dst:expr, $ty:ty) => {{ - const SIZE: usize = core::mem::size_of::<$ty>(); - let chunk_size_u8 = min($src.len() * SIZE, $dst.len()); - let chunk_size = (chunk_size_u8 + SIZE - 1) / SIZE; - - // The following can be replaced with safe code, but unfortunately it's - // ca. 8% slower. - if cfg!(target_endian = "little") { - unsafe { - core::ptr::copy_nonoverlapping( - $src.as_ptr() as *const u8, - $dst.as_mut_ptr(), - chunk_size_u8); - } - } else { - for (&n, chunk) in $src.iter().zip($dst.chunks_mut(SIZE)) { - let tmp = n.to_le(); - let src_ptr = &tmp as *const $ty as *const u8; - unsafe { - core::ptr::copy_nonoverlapping( - src_ptr, - chunk.as_mut_ptr(), - chunk.len()); - } - } +trait Observable: IntoBytes + Immutable + Copy { + fn to_le(self) -> Self; +} +impl Observable for u32 { + fn to_le(self) -> Self { + self.to_le() + } +} +impl Observable for u64 { + fn to_le(self) -> Self { + self.to_le() + } +} + +/// Fill dest from src +/// +/// Returns `(n, byte_len)`. `src[..n]` is consumed (and possibly mutated), +/// `dest[..byte_len]` is filled. `src[n..]` and `dest[byte_len..]` are left +/// unaltered. +fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usize) { + let size = core::mem::size_of::(); + let byte_len = min(core::mem::size_of_val(src), dest.len()); + let num_chunks = (byte_len + size - 1) / size; + + // Byte-swap for portability of results. This must happen before copying + // since the size of dest is not guaranteed to be a multiple of T or to be + // sufficiently aligned. + if cfg!(target_endian = "big") { + for x in &mut src[..num_chunks] { + *x = x.to_le(); } + } + + dest[..byte_len].copy_from_slice(&<[T]>::as_bytes(&src[..num_chunks])[..byte_len]); - (chunk_size, chunk_size_u8) - }}; + (num_chunks, byte_len) } /// Implement `fill_bytes` by reading chunks from the output buffer of a block @@ -89,6 +96,9 @@ macro_rules! fill_via_chunks { /// /// The return values are `(consumed_u32, filled_u8)`. /// +/// On big-endian systems, endianness of `src[..consumed_u32]` values is +/// swapped. No other adjustments to `src` are made. +/// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. /// `consumed_u32` is the number of words consumed from `src`, which is the same @@ -114,22 +124,26 @@ macro_rules! fill_via_chunks { /// } /// } /// ``` -pub fn fill_via_u32_chunks(src: &[u32], dest: &mut [u8]) -> (usize, usize) { - fill_via_chunks!(src, dest, u32) +pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { + fill_via_chunks(src, dest) } /// Implement `fill_bytes` by reading chunks from the output buffer of a block /// based RNG. /// /// The return values are `(consumed_u64, filled_u8)`. +/// +/// On big-endian systems, endianness of `src[..consumed_u64]` values is +/// swapped. No other adjustments to `src` are made. +/// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. /// `consumed_u64` is the number of words consumed from `src`, which is the same /// as `filled_u8 / 8` rounded up. /// /// See `fill_via_u32_chunks` for an example. -pub fn fill_via_u64_chunks(src: &[u64], dest: &mut [u8]) -> (usize, usize) { - fill_via_chunks!(src, dest, u64) +pub fn fill_via_u64_chunks(src: &mut [u64], dest: &mut [u8]) -> (usize, usize) { + fill_via_chunks(src, dest) } /// Implement `next_u32` via `fill_bytes`, little-endian order. @@ -152,33 +166,41 @@ mod test { #[test] fn test_fill_via_u32_chunks() { - let src = [1, 2, 3]; + let src_orig = [1, 2, 3]; + + let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u32_chunks(&src, &mut dst), (3, 11)); + assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 11)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 13]; - assert_eq!(fill_via_u32_chunks(&src, &mut dst), (3, 12)); + assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 12)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u32_chunks(&src, &mut dst), (2, 5)); + assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (2, 5)); assert_eq!(dst, [1, 0, 0, 0, 2]); } #[test] fn test_fill_via_u64_chunks() { - let src = [1, 2]; + let src_orig = [1, 2]; + + let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u64_chunks(&src, &mut dst), (2, 11)); + assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 11)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 17]; - assert_eq!(fill_via_u64_chunks(&src, &mut dst), (2, 16)); + assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 16)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]); + let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u64_chunks(&src, &mut dst), (1, 5)); + assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (1, 5)); assert_eq!(dst, [1, 0, 0, 0, 0]); } } diff --git a/rand_core/src/le.rs b/rand_core/src/le.rs index fa338928403..cee84c2f327 100644 --- a/rand_core/src/le.rs +++ b/rand_core/src/le.rs @@ -11,21 +11,29 @@ //! Little-Endian order has been chosen for internal usage; this makes some //! useful functions available. -use core::convert::TryInto; - /// Reads unsigned 32 bit integers from `src` into `dst`. +/// +/// # Panics +/// +/// If `dst` has insufficient space (`4*dst.len() < src.len()`). #[inline] +#[track_caller] pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { - assert!(4 * src.len() >= dst.len()); + assert!(src.len() >= 4 * dst.len()); for (out, chunk) in dst.iter_mut().zip(src.chunks_exact(4)) { *out = u32::from_le_bytes(chunk.try_into().unwrap()); } } /// Reads unsigned 64 bit integers from `src` into `dst`. +/// +/// # Panics +/// +/// If `dst` has insufficient space (`8*dst.len() < src.len()`). #[inline] +#[track_caller] pub fn read_u64_into(src: &[u8], dst: &mut [u64]) { - assert!(8 * src.len() >= dst.len()); + assert!(src.len() >= 8 * dst.len()); for (out, chunk) in dst.iter_mut().zip(src.chunks_exact(8)) { *out = u64::from_le_bytes(chunk.try_into().unwrap()); } diff --git a/rand_core/src/lib.rs b/rand_core/src/lib.rs index ff553a335ae..9faff9c752f 100644 --- a/rand_core/src/lib.rs +++ b/rand_core/src/lib.rs @@ -19,9 +19,6 @@ //! [`SeedableRng`] is an extension trait for construction from fixed seeds and //! other random number generators. //! -//! [`Error`] is provided for error-handling. It is safe to use in `no_std` -//! environments. -//! //! The [`impls`] and [`le`] sub-modules include a few small functions to assist //! implementation of [`RngCore`]. //! @@ -35,32 +32,28 @@ #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![no_std] -use core::convert::AsMut; -use core::default::Default; - -#[cfg(feature = "std")] extern crate std; -#[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "alloc")] use alloc::boxed::Box; - -pub use error::Error; -#[cfg(feature = "getrandom")] pub use os::OsRng; +#[cfg(feature = "std")] +extern crate std; +use core::{fmt, ops::DerefMut}; pub mod block; -mod error; pub mod impls; pub mod le; -#[cfg(feature = "getrandom")] mod os; +#[cfg(feature = "os_rng")] +mod os; +#[cfg(feature = "os_rng")] +pub use os::{OsError, OsRng}; -/// The core of a random number generator. +/// Implementation-level interface for RNGs /// /// This trait encapsulates the low-level functionality common to all /// generators, and is the "back end", to be implemented by generators. -/// End users should normally use the `Rng` trait from the [`rand`] crate, +/// End users should normally use the [`rand::Rng`] trait /// which is automatically implemented for every type implementing `RngCore`. /// /// Three different methods for generating random data are provided since the @@ -71,20 +64,24 @@ pub mod le; /// [`next_u32`] and [`next_u64`] methods, implementations may discard some /// random bits for efficiency. /// -/// The [`try_fill_bytes`] method is a variant of [`fill_bytes`] allowing error -/// handling; it is not deemed sufficiently useful to add equivalents for -/// [`next_u32`] or [`next_u64`] since the latter methods are almost always used -/// with algorithmic generators (PRNGs), which are normally infallible. +/// Implementers should produce bits uniformly. Pathological RNGs (e.g. always +/// returning the same value, or never setting certain bits) can break rejection +/// sampling used by random distributions, and also break other RNGs when +/// seeding them via [`SeedableRng::from_rng`]. /// /// Algorithmic generators implementing [`SeedableRng`] should normally have /// *portable, reproducible* output, i.e. fix Endianness when converting values /// to avoid platform differences, and avoid making any changes which affect /// output (except by communicating that the release has breaking changes). /// -/// Typically implementators will implement only one of the methods available +/// Typically an RNG will implement only one of the methods available /// in this trait directly, then use the helper functions from the /// [`impls`] module to implement the other methods. /// +/// Note that implementors of [`RngCore`] also automatically implement +/// the [`TryRngCore`] trait with the `Error` associated type being +/// equal to [`Infallible`]. +/// /// It is recommended that implementations also implement: /// /// - `Debug` with a custom implementation which *does not* print any internal @@ -105,7 +102,7 @@ pub mod le; /// /// ``` /// #![allow(dead_code)] -/// use rand_core::{RngCore, Error, impls}; +/// use rand_core::{RngCore, impls}; /// /// struct CountingRng(u64); /// @@ -119,21 +116,17 @@ pub mod le; /// self.0 /// } /// -/// fn fill_bytes(&mut self, dest: &mut [u8]) { -/// impls::fill_bytes_via_next(self, dest) -/// } -/// -/// fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { -/// Ok(self.fill_bytes(dest)) +/// fn fill_bytes(&mut self, dst: &mut [u8]) { +/// impls::fill_bytes_via_next(self, dst) /// } /// } /// ``` /// -/// [`rand`]: https://docs.rs/rand -/// [`try_fill_bytes`]: RngCore::try_fill_bytes +/// [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html /// [`fill_bytes`]: RngCore::fill_bytes /// [`next_u32`]: RngCore::next_u32 /// [`next_u64`]: RngCore::next_u64 +/// [`Infallible`]: core::convert::Infallible pub trait RngCore { /// Return the next random `u32`. /// @@ -153,34 +146,37 @@ pub trait RngCore { /// /// RNGs must implement at least one method from this trait directly. In /// the case this method is not implemented directly, it can be implemented - /// via [`impls::fill_bytes_via_next`] or - /// via [`RngCore::try_fill_bytes`]; if this generator can - /// fail the implementation must choose how best to handle errors here - /// (e.g. panic with a descriptive message or log a warning and retry a few - /// times). + /// via [`impls::fill_bytes_via_next`]. /// /// This method should guarantee that `dest` is entirely filled /// with new data, and may panic if this is impossible /// (e.g. reading past the end of a file that is being used as the /// source of randomness). - fn fill_bytes(&mut self, dest: &mut [u8]); + fn fill_bytes(&mut self, dst: &mut [u8]); +} - /// Fill `dest` entirely with random data. - /// - /// This is the only method which allows an RNG to report errors while - /// generating random data thus making this the primary method implemented - /// by external (true) RNGs (e.g. `OsRng`) which can fail. It may be used - /// directly to generate keys and to seed (infallible) PRNGs. - /// - /// Other than error handling, this method is identical to [`RngCore::fill_bytes`]; - /// thus this may be implemented using `Ok(self.fill_bytes(dest))` or - /// `fill_bytes` may be implemented with - /// `self.try_fill_bytes(dest).unwrap()` or more specific error handling. - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error>; +impl RngCore for T +where + T::Target: RngCore, +{ + #[inline] + fn next_u32(&mut self) -> u32 { + self.deref_mut().next_u32() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.deref_mut().next_u64() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.deref_mut().fill_bytes(dst); + } } -/// A marker trait used to indicate that an [`RngCore`] or [`BlockRngCore`] -/// implementation is supposed to be cryptographically secure. +/// A marker trait used to indicate that an [`RngCore`] implementation is +/// supposed to be cryptographically secure. /// /// *Cryptographically secure generators*, also known as *CSPRNGs*, should /// satisfy an additional properties over other generators: given the first @@ -191,7 +187,7 @@ pub trait RngCore { /// Some generators may satisfy an additional property, however this is not /// required by this trait: if the CSPRNG's state is revealed, it should not be /// computationally-feasible to reconstruct output prior to this. Some other -/// generators allow backwards-computation and are consided *reversible*. +/// generators allow backwards-computation and are considered *reversible*. /// /// Note that this trait is provided for guidance only and cannot guarantee /// suitability for cryptographic applications. In general it should only be @@ -200,8 +196,110 @@ pub trait RngCore { /// Note also that use of a `CryptoRng` does not protect against other /// weaknesses such as seeding from a weak entropy source or leaking state. /// +/// Note that implementors of [`CryptoRng`] also automatically implement +/// the [`TryCryptoRng`] trait. +/// /// [`BlockRngCore`]: block::BlockRngCore -pub trait CryptoRng {} +/// [`Infallible`]: core::convert::Infallible +pub trait CryptoRng: RngCore {} + +impl CryptoRng for T where T::Target: CryptoRng {} + +/// A potentially fallible variant of [`RngCore`] +/// +/// This trait is a generalization of [`RngCore`] to support potentially- +/// fallible IO-based generators such as [`OsRng`]. +/// +/// All implementations of [`RngCore`] automatically support this `TryRngCore` +/// trait, using [`Infallible`][core::convert::Infallible] as the associated +/// `Error` type. +/// +/// An implementation of this trait may be made compatible with code requiring +/// an [`RngCore`] through [`TryRngCore::unwrap_err`]. The resulting RNG will +/// panic in case the underlying fallible RNG yields an error. +pub trait TryRngCore { + /// The type returned in the event of a RNG error. + type Error: fmt::Debug + fmt::Display; + + /// Return the next random `u32`. + fn try_next_u32(&mut self) -> Result; + /// Return the next random `u64`. + fn try_next_u64(&mut self) -> Result; + /// Fill `dest` entirely with random data. + fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error>; + + /// Wrap RNG with the [`UnwrapErr`] wrapper. + fn unwrap_err(self) -> UnwrapErr + where + Self: Sized, + { + UnwrapErr(self) + } + + /// Convert an [`RngCore`] to a [`RngReadAdapter`]. + #[cfg(feature = "std")] + fn read_adapter(&mut self) -> RngReadAdapter<'_, Self> + where + Self: Sized, + { + RngReadAdapter { inner: self } + } +} + +// Note that, unfortunately, this blanket impl prevents us from implementing +// `TryRngCore` for types which can be dereferenced to `TryRngCore`, i.e. `TryRngCore` +// will not be automatically implemented for `&mut R`, `Box`, etc. +impl TryRngCore for R { + type Error = core::convert::Infallible; + + #[inline] + fn try_next_u32(&mut self) -> Result { + Ok(self.next_u32()) + } + + #[inline] + fn try_next_u64(&mut self) -> Result { + Ok(self.next_u64()) + } + + #[inline] + fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> { + self.fill_bytes(dst); + Ok(()) + } +} + +/// A marker trait used to indicate that a [`TryRngCore`] implementation is +/// supposed to be cryptographically secure. +/// +/// See [`CryptoRng`] docs for more information about cryptographically secure generators. +pub trait TryCryptoRng: TryRngCore {} + +impl TryCryptoRng for R {} + +/// Wrapper around [`TryRngCore`] implementation which implements [`RngCore`] +/// by panicking on potential errors. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Hash)] +pub struct UnwrapErr(pub R); + +impl RngCore for UnwrapErr { + #[inline] + fn next_u32(&mut self) -> u32 { + self.0.try_next_u32().unwrap() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.0.try_next_u64().unwrap() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.try_fill_bytes(dst).unwrap() + } +} + +impl CryptoRng for UnwrapErr {} /// A random number generator that can be explicitly seeded. /// @@ -210,7 +308,7 @@ pub trait CryptoRng {} /// /// [`rand`]: https://docs.rs/rand pub trait SeedableRng: Sized { - /// Seed type, which is restricted to types mutably-dereferencable as `u8` + /// Seed type, which is restricted to types mutably-dereferenceable as `u8` /// arrays (we recommend `[u8; N]` for some `N`). /// /// It is recommended to seed PRNGs with a seed of at least circa 100 bits, @@ -222,17 +320,17 @@ pub trait SeedableRng: Sized { /// /// # Implementing `SeedableRng` for RNGs with large seeds /// - /// Note that the required traits `core::default::Default` and - /// `core::convert::AsMut` are not implemented for large arrays - /// `[u8; N]` with `N` > 32. To be able to implement the traits required by - /// `SeedableRng` for RNGs with such large seeds, the newtype pattern can be - /// used: + /// Note that [`Default`] is not implemented for large arrays `[u8; N]` with + /// `N` > 32. To be able to implement the traits required by `SeedableRng` + /// for RNGs with such large seeds, the newtype pattern can be used: /// /// ``` /// use rand_core::SeedableRng; /// /// const N: usize = 64; + /// #[derive(Clone)] /// pub struct MyRngSeed(pub [u8; N]); + /// # #[allow(dead_code)] /// pub struct MyRng(MyRngSeed); /// /// impl Default for MyRngSeed { @@ -241,6 +339,12 @@ pub trait SeedableRng: Sized { /// } /// } /// + /// impl AsRef<[u8]> for MyRngSeed { + /// fn as_ref(&self) -> &[u8] { + /// &self.0 + /// } + /// } + /// /// impl AsMut<[u8]> for MyRngSeed { /// fn as_mut(&mut self) -> &mut [u8] { /// &mut self.0 @@ -255,7 +359,7 @@ pub trait SeedableRng: Sized { /// } /// } /// ``` - type Seed: Sized + Default + AsMut<[u8]>; + type Seed: Clone + Default + AsRef<[u8]> + AsMut<[u8]>; /// Create a new PRNG using the given seed. /// @@ -300,26 +404,36 @@ pub trait SeedableRng: Sized { /// considered a value-breaking change. fn seed_from_u64(mut state: u64) -> Self { // We use PCG32 to generate a u32 sequence, and copy to the seed - const MUL: u64 = 6364136223846793005; - const INC: u64 = 11634580027462260723; + fn pcg32(state: &mut u64) -> [u8; 4] { + const MUL: u64 = 6364136223846793005; + const INC: u64 = 11634580027462260723; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(4) { // We advance the state first (to get away from the input value, // in case it has low Hamming Weight). - state = state.wrapping_mul(MUL).wrapping_add(INC); + *state = state.wrapping_mul(MUL).wrapping_add(INC); + let state = *state; // Use PCG output function with to_le to generate x: let xorshifted = (((state >> 18) ^ state) >> 27) as u32; let rot = (state >> 59) as u32; let x = xorshifted.rotate_right(rot); - chunk.copy_from_slice(&x.to_le_bytes()); + x.to_le_bytes() + } + + let mut seed = Self::Seed::default(); + let mut iter = seed.as_mut().chunks_exact_mut(4); + for chunk in &mut iter { + chunk.copy_from_slice(&pcg32(&mut state)); + } + let rem = iter.into_remainder(); + if !rem.is_empty() { + rem.copy_from_slice(&pcg32(&mut state)[..rem.len()]); } Self::from_seed(seed) } - /// Create a new PRNG seeded from another `Rng`. + /// Create a new PRNG seeded from an infallible `Rng`. /// /// This may be useful when needing to rapidly seed many PRNGs from a master /// PRNG, and to allow forking of PRNGs. It may be considered deterministic. @@ -343,7 +457,16 @@ pub trait SeedableRng: Sized { /// (in prior versions this was not required). /// /// [`rand`]: https://docs.rs/rand - fn from_rng(mut rng: R) -> Result { + fn from_rng(rng: &mut impl RngCore) -> Self { + let mut seed = Self::Seed::default(); + rng.fill_bytes(seed.as_mut()); + Self::from_seed(seed) + } + + /// Create a new PRNG seeded from a potentially fallible `Rng`. + /// + /// See [`from_rng`][SeedableRng::from_rng] docs for more information. + fn try_from_rng(rng: &mut R) -> Result { let mut seed = Self::Seed::default(); rng.try_fill_bytes(seed.as_mut())?; Ok(Self::from_seed(seed)) @@ -354,91 +477,77 @@ pub trait SeedableRng: Sized { /// This method is the recommended way to construct non-deterministic PRNGs /// since it is convenient and secure. /// + /// Note that this method may panic on (extremely unlikely) [`getrandom`] errors. + /// If it's not desirable, use the [`try_from_os_rng`] method instead. + /// /// In case the overhead of using [`getrandom`] to seed *many* PRNGs is an /// issue, one may prefer to seed from a local PRNG, e.g. - /// `from_rng(thread_rng()).unwrap()`. + /// `from_rng(rand::rng()).unwrap()`. /// /// # Panics /// /// If [`getrandom`] is unable to provide secure entropy this method will panic. /// /// [`getrandom`]: https://docs.rs/getrandom - #[cfg(feature = "getrandom")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] - fn from_entropy() -> Self { - let mut seed = Self::Seed::default(); - if let Err(err) = getrandom::getrandom(seed.as_mut()) { - panic!("from_entropy failed: {}", err); + /// [`try_from_os_rng`]: SeedableRng::try_from_os_rng + #[cfg(feature = "os_rng")] + fn from_os_rng() -> Self { + match Self::try_from_os_rng() { + Ok(res) => res, + Err(err) => panic!("from_os_rng failed: {}", err), } - Self::from_seed(seed) - } -} - -// Implement `RngCore` for references to an `RngCore`. -// Force inlining all functions, so that it is up to the `RngCore` -// implementation and the optimizer to decide on inlining. -impl<'a, R: RngCore + ?Sized> RngCore for &'a mut R { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - (**self).next_u32() - } - - #[inline(always)] - fn next_u64(&mut self) -> u64 { - (**self).next_u64() } - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - (**self).fill_bytes(dest) - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) + /// Creates a new instance of the RNG seeded via [`getrandom`] without unwrapping + /// potential [`getrandom`] errors. + /// + /// In case the overhead of using [`getrandom`] to seed *many* PRNGs is an + /// issue, one may prefer to seed from a local PRNG, e.g. + /// `from_rng(&mut rand::rng()).unwrap()`. + /// + /// [`getrandom`]: https://docs.rs/getrandom + #[cfg(feature = "os_rng")] + fn try_from_os_rng() -> Result { + let mut seed = Self::Seed::default(); + getrandom::fill(seed.as_mut())?; + let res = Self::from_seed(seed); + Ok(res) } } -// Implement `RngCore` for boxed references to an `RngCore`. -// Force inlining all functions, so that it is up to the `RngCore` -// implementation and the optimizer to decide on inlining. -#[cfg(feature = "alloc")] -impl RngCore for Box { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - (**self).next_u32() - } - - #[inline(always)] - fn next_u64(&mut self) -> u64 { - (**self).next_u64() - } - - #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - (**self).fill_bytes(dest) - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - (**self).try_fill_bytes(dest) - } +/// Adapter that enables reading through a [`io::Read`](std::io::Read) from a [`RngCore`]. +/// +/// # Examples +/// +/// ```no_run +/// # use std::{io, io::Read}; +/// # use std::fs::File; +/// # use rand_core::{OsRng, TryRngCore}; +/// +/// io::copy(&mut OsRng.read_adapter().take(100), &mut File::create("/tmp/random.bytes").unwrap()).unwrap(); +/// ``` +#[cfg(feature = "std")] +pub struct RngReadAdapter<'a, R: TryRngCore + ?Sized> { + inner: &'a mut R, } #[cfg(feature = "std")] -impl std::io::Read for dyn RngCore { +impl std::io::Read for RngReadAdapter<'_, R> { + #[inline] fn read(&mut self, buf: &mut [u8]) -> Result { - self.try_fill_bytes(buf)?; + self.inner.try_fill_bytes(buf).map_err(|err| { + std::io::Error::new(std::io::ErrorKind::Other, std::format!("RNG error: {err}")) + })?; Ok(buf.len()) } } -// Implement `CryptoRng` for references to an `CryptoRng`. -impl<'a, R: CryptoRng + ?Sized> CryptoRng for &'a mut R {} - -// Implement `CryptoRng` for boxed references to an `CryptoRng`. -#[cfg(feature = "alloc")] -impl CryptoRng for Box {} +#[cfg(feature = "std")] +impl std::fmt::Debug for RngReadAdapter<'_, R> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ReadAdapter").finish() + } +} #[cfg(test)] mod test { @@ -470,7 +579,7 @@ mod test { // This is the binomial distribution B(64, 0.5), so chance of // weight < 20 is binocdf(19, 64, 0.5) = 7.8e-4, and same for // weight > 44. - assert!(weight >= 20 && weight <= 44); + assert!((20..=44).contains(&weight)); for (i2, r2) in results.iter().enumerate() { if i1 == i2 { diff --git a/rand_core/src/os.rs b/rand_core/src/os.rs index 6cd1b9cf5de..49111632d9f 100644 --- a/rand_core/src/os.rs +++ b/rand_core/src/os.rs @@ -8,19 +8,18 @@ //! Interface to the random number generator of the operating system. -use crate::{impls, CryptoRng, Error, RngCore}; -use getrandom::getrandom; +use crate::{TryCryptoRng, TryRngCore}; -/// A random number generator that retrieves randomness from the -/// operating system. +/// An interface over the operating-system's random data source /// -/// This is a zero-sized struct. It can be freely constructed with `OsRng`. +/// This is a zero-sized struct. It can be freely constructed with just `OsRng`. /// /// The implementation is provided by the [getrandom] crate. Refer to /// [getrandom] documentation for details. /// -/// This struct is only available when specifying the crate feature `getrandom` -/// or `std`. When using the `rand` lib, it is also available as `rand::rngs::OsRng`. +/// This struct is available as `rand_core::OsRng` and as `rand::rngs::OsRng`. +/// In both cases, this requires the crate feature `os_rng` or `std` +/// (enabled by default in `rand` but not in `rand_core`). /// /// # Blocking and error handling /// @@ -31,55 +30,86 @@ use getrandom::getrandom; /// /// After the first successful call, it is highly unlikely that failures or /// significant delays will occur (although performance should be expected to -/// be much slower than a user-space PRNG). +/// be much slower than a user-space +/// [PRNG](https://rust-random.github.io/book/guide-gen.html#pseudo-random-number-generators)). /// /// # Usage example /// ``` -/// use rand_core::{RngCore, OsRng}; +/// use rand_core::{TryRngCore, OsRng}; /// /// let mut key = [0u8; 16]; -/// OsRng.fill_bytes(&mut key); -/// let random_u64 = OsRng.next_u64(); +/// OsRng.try_fill_bytes(&mut key).unwrap(); +/// let random_u64 = OsRng.try_next_u64().unwrap(); /// ``` /// /// [getrandom]: https://crates.io/crates/getrandom -#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] #[derive(Clone, Copy, Debug, Default)] pub struct OsRng; -impl CryptoRng for OsRng {} +/// Error type of [`OsRng`] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OsError(getrandom::Error); -impl RngCore for OsRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) +impl core::fmt::Display for OsError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.fmt(f) } +} - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) +// NOTE: this can use core::error::Error from rustc 1.81.0 +#[cfg(feature = "std")] +impl std::error::Error for OsError { + #[inline] + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + std::error::Error::source(&self.0) } +} - fn fill_bytes(&mut self, dest: &mut [u8]) { - if let Err(e) = self.try_fill_bytes(dest) { - panic!("Error: {}", e); - } +impl OsError { + /// Extract the raw OS error code (if this error came from the OS) + /// + /// This method is identical to [`std::io::Error::raw_os_error()`][1], except + /// that it works in `no_std` contexts. If this method returns `None`, the + /// error value can still be formatted via the `Display` implementation. + /// + /// [1]: https://doc.rust-lang.org/std/io/struct.Error.html#method.raw_os_error + #[inline] + pub fn raw_os_error(self) -> Option { + self.0.raw_os_error() } +} - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - getrandom(dest)?; - Ok(()) +impl TryRngCore for OsRng { + type Error = OsError; + + #[inline] + fn try_next_u32(&mut self) -> Result { + getrandom::u32().map_err(OsError) + } + + #[inline] + fn try_next_u64(&mut self) -> Result { + getrandom::u64().map_err(OsError) + } + + #[inline] + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + getrandom::fill(dest).map_err(OsError) } } +impl TryCryptoRng for OsRng {} + #[test] fn test_os_rng() { - let x = OsRng.next_u64(); - let y = OsRng.next_u64(); + let x = OsRng.try_next_u64().unwrap(); + let y = OsRng.try_next_u64().unwrap(); assert!(x != 0); assert!(x != y); } #[test] fn test_construction() { - let mut rng = OsRng::default(); - assert!(rng.next_u64() != 0); + assert!(OsRng.try_next_u64().unwrap() != 0); } diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index 63cd1b9d761..81fa3a3c4bc 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -4,6 +4,69 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] - 2025-01-27 + +### Dependencies and features +- Bump the MSRV to 1.61.0 (#1207, #1246, #1269, #1341, #1416); note that 1.60.0 may work for dependents when using `--ignore-rust-version` +- Update to `rand` v0.9.0 (#1558) +- Rename feature `serde1` to `serde` (#1477) + +### API changes +- Make distributions comparable with `PartialEq` (#1218) +- `Dirichlet` now uses `const` generics, which means that its size is required at compile time (#1292) +- The `Dirichlet::new_with_size` constructor was removed (#1292) +- Add `WeightedIndexTree` (#1372, #1444) +- Add `PertBuilder` to allow specification of `mean` or `mode` (#1452) +- Rename `Zeta`'s parameter `a` to `s` (#1466) +- Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480) +- Remove support for usage of `isize` as a `WeightedAliasIndex` weight (#1487) +- Change parameter type of `Zipf::new`: `n` is now floating-point (#1518) + +### API changes: renames +- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548) +- Rename trait `DistString` -> `SampleString` (#1548) +- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548) +- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548) +- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548) +- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548) + +### Testing +- Add Kolmogorov Smirnov tests for distributions (#1494, #1504, #1525, #1530) + +### Fixes +- Fix Knuth's method so `Poisson` doesn't return -1.0 for small lambda (#1284) +- Fix `Poisson` distribution instantiation so it return an error if lambda is infinite (#1291) +- Fix Dirichlet sample for small alpha values to avoid NaN samples (#1209) +- Fix infinite loop in `Binomial` distribution (#1325) +- Fix `Pert` distribution where `mode` is close to `(min + max) / 2` (#1452) +- Fix panic in Binomial (#1484) +- Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498) +- Fix bug in `Hypergeometric`, this is a Value-breaking change (#1510) + +### Other changes +- Remove unused fields from `Gamma`, `NormalInverseGaussian` and `Zipf` distributions (#1184) + This breaks serialization compatibility with older versions. +- Add plots for `rand_distr` distributions to documentation (#1434) +- Move some of the computations in Binomial from `sample` to `new` (#1484) + +## [0.4.3] - 2021-12-30 +- Fix `no_std` build (#1208) + +## [0.4.2] - 2021-09-18 +- New `Zeta` and `Zipf` distributions (#1136) +- New `SkewNormal` distribution (#1149) +- New `Gumbel` and `Frechet` distributions (#1168, #1171) + +## [0.4.1] - 2021-06-15 +- Empirically test PDF of normal distribution (#1121) +- Correctly document `no_std` support (#1100) +- Add `std_math` feature to prefer `std` over `libm` for floating point math (#1100) +- Add mean and std_dev accessors to Normal (#1114) +- Make sure all distributions and their error types implement `Error`, `Display`, `Clone`, + `Copy`, `PartialEq` and `Eq` as appropriate (#1126) +- Port benchmarks to use Criterion crate (#1116) +- Support serde for distributions (#1141) + ## [0.4.0] - 2020-12-18 - Bump `rand` to v0.8.0 - New `Geometric`, `StandardGeometric` and `Hypergeometric` distributions (#1062) diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index 027d7753adf..dd55673777c 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_distr" -version = "0.4.0" +version = "0.5.0" authors = ["The Rand Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -11,22 +11,38 @@ description = """ Sampling from random number distributions """ keywords = ["random", "rng", "distribution", "probability"] -categories = ["algorithms"] -edition = "2018" -include = ["src/", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] +categories = ["algorithms", "no-std"] +edition = "2021" +rust-version = "1.63" +include = ["/src", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] -[dependencies] -rand = { path = "..", version = "0.8.0", default-features = false } -num-traits = { version = "0.2", default-features = false, features = ["libm"] } +[package.metadata.docs.rs] +features = ["serde"] +rustdoc-args = ["--generate-link-to-definition"] [features] default = ["std"] std = ["alloc", "rand/std"] alloc = ["rand/alloc"] +# Use std's floating-point arithmetic instead of libm. +# Note that any other crate depending on `num-traits`'s `std` +# feature (default-enabled) will have the same effect. +std_math = ["num-traits/std"] + +serde = ["dep:serde", "dep:serde_with", "rand/serde"] + +[dependencies] +rand = { path = "..", version = "0.9.0", default-features = false } +num-traits = { version = "0.2", default-features = false, features = ["libm"] } +serde = { version = "1.0.103", features = ["derive"], optional = true } +serde_with = { version = ">= 3.0, <= 3.11", optional = true } + [dev-dependencies] -rand_pcg = { version = "0.3.0", path = "../rand_pcg" } +rand_pcg = { version = "0.9.0", path = "../rand_pcg" } # For inline examples -rand = { path = "..", version = "0.8.0", default-features = false, features = ["std_rng", "std"] } +rand = { path = "..", version = "0.9.0", features = ["small_rng"] } # Histogram implementation for testing uniformity -average = "0.10.3" +average = { version = "0.15", features = [ "std" ] } +# Special functions for testing distributions +special = "0.11.0" diff --git a/rand_distr/LICENSE-APACHE b/rand_distr/LICENSE-APACHE index 17d74680f8c..455787c2334 100644 --- a/rand_distr/LICENSE-APACHE +++ b/rand_distr/LICENSE-APACHE @@ -185,17 +185,3 @@ APPENDIX: How to apply the Apache License to your work. file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rand_distr/README.md b/rand_distr/README.md index 29b5fe0853e..193d54123d1 100644 --- a/rand_distr/README.md +++ b/rand_distr/README.md @@ -1,27 +1,43 @@ # rand_distr -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_distr.svg)](https://crates.io/crates/rand_distr) -[[![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) +[![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_distr) [![API](https://docs.rs/rand_distr/badge.svg)](https://docs.rs/rand_distr) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) -Implements a full suite of random number distributions sampling routines. +Implements a full suite of random number distribution sampling routines. -This crate is a super-set of the [rand::distributions] module, including support -for sampling from Beta, Binomial, Cauchy, ChiSquared, Dirichlet, exponential, -Fisher F, Gamma, Log-normal, Normal, Pareto, Poisson, StudentT, Triangular and -Weibull distributions, as well as sampling points from the unit circle and unit -sphere surface. +This crate is a superset of the [rand::distr] module, including support +for sampling from Beta, Binomial, Cauchy, ChiSquared, Dirichlet, Exponential, +FisherF, Gamma, Geometric, Hypergeometric, InverseGaussian, LogNormal, Normal, +Pareto, PERT, Poisson, StudentT, Triangular and Weibull distributions. Sampling +from the unit ball, unit circle, unit disc and unit sphere surfaces is also +supported. It is worth mentioning the [statrs] crate which provides similar functionality along with various support functions, including PDF and CDF computation. In -contrast, this `rand_distr` crate focusses on sampling from distributions. +contrast, this `rand_distr` crate focuses on sampling from distributions. -Unlike most Rand crates, `rand_distr` does not currently support `no_std`. +## Portability and libm -Links: +The floating point functions from `num_traits` and `libm` are used to support +`no_std` environments and ensure reproducibility. If the floating point +functions from `std` are preferred, which may provide better accuracy and +performance but may produce different random values, the `std_math` feature +can be enabled. (Note that any other crate depending on `num-traits` with the +`std` feature (default-enabled) will have the same effect.) + +## Crate features + +- `std` (enabled by default): `rand_distr` implements the `Error` trait for + its error types. Implies `alloc` and `rand/std`. +- `alloc` (enabled by default): required for some distributions when not using + `std` (in particular, `Dirichlet` and `WeightedAliasIndex`). +- `std_math`: see above on portability and libm +- `serde`: implement (de)seriaialization using `serde` + +## Links - [API documentation (master)](https://rust-random.github.io/rand/rand_distr) - [API documentation (docs.rs)](https://docs.rs/rand_distr) @@ -30,7 +46,7 @@ Links: [statrs]: https://github.com/boxtown/statrs -[rand::distributions]: https://rust-random.github.io/rand/rand/distributions/index.html +[rand::distr]: https://rust-random.github.io/rand/rand/distr/index.html ## License diff --git a/rand_distr/benches/distributions.rs b/rand_distr/benches/distributions.rs deleted file mode 100644 index 31e7e2fc716..00000000000 --- a/rand_distr/benches/distributions.rs +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![feature(custom_inner_attributes)] -#![feature(test)] - -// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable -#![rustfmt::skip] - -extern crate test; - -const RAND_BENCH_N: u64 = 1000; - -use std::mem::size_of; -use test::Bencher; - -use rand::prelude::*; -use rand_distr::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -macro_rules! distr_int { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0 as $ty; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_float { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum += x; - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0u32; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x as u32); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - -macro_rules! distr_arr { - ($fnn:ident, $ty:ty, $distr:expr) => { - #[bench] - fn $fnn(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = $distr; - - b.iter(|| { - let mut accum = 0u32; - for _ in 0..RAND_BENCH_N { - let x: $ty = distr.sample(&mut rng); - accum = accum.wrapping_add(x[0] as u32); - } - accum - }); - b.bytes = size_of::<$ty>() as u64 * RAND_BENCH_N; - } - }; -} - - -// distributions -distr_float!(distr_exp, f64, Exp::new(1.23 * 4.56).unwrap()); -distr_float!(distr_exp1_specialized, f64, Exp1); -distr_float!(distr_exp1_general, f64, Exp::new(1.).unwrap()); -distr_float!(distr_normal, f64, Normal::new(-1.23, 4.56).unwrap()); -distr_float!(distr_standardnormal_specialized, f64, StandardNormal); -distr_float!(distr_standardnormal_general, f64, Normal::new(0., 1.).unwrap()); -distr_float!(distr_log_normal, f64, LogNormal::new(-1.23, 4.56).unwrap()); -distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0).unwrap()); -distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0).unwrap()); -distr_float!(distr_beta_small_param, f64, Beta::new(0.1, 0.1).unwrap()); -distr_float!(distr_beta_large_param_similar, f64, Beta::new(101., 95.).unwrap()); -distr_float!(distr_beta_large_param_different, f64, Beta::new(10., 1000.).unwrap()); -distr_float!(distr_beta_mixed_param, f64, Beta::new(0.5, 100.).unwrap()); -distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9).unwrap()); -distr_float!(distr_triangular, f64, Triangular::new(0., 1., 0.9).unwrap()); -distr_int!(distr_binomial, u64, Binomial::new(20, 0.7).unwrap()); -distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30).unwrap()); -distr_float!(distr_poisson, f64, Poisson::new(4.0).unwrap()); -distr!(distr_bernoulli, bool, Bernoulli::new(0.18).unwrap()); -distr_arr!(distr_circle, [f64; 2], UnitCircle); -distr_arr!(distr_sphere, [f64; 3], UnitSphere); - -// Weighted -distr_int!(distr_weighted_i8, usize, WeightedIndex::new(&[1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_u32, usize, WeightedIndex::new(&[1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_f64, usize, WeightedIndex::new(&[1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); -distr_int!(distr_weighted_large_set, usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); - -distr_int!(distr_weighted_alias_method_i8, usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_alias_method_u32, usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); -distr_int!(distr_weighted_alias_method_f64, usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); -distr_int!(distr_weighted_alias_method_large_set, usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap()); - -distr_int!(distr_geometric, u64, Geometric::new(0.5).unwrap()); -distr_int!(distr_standard_geometric, u64, StandardGeometric); - -#[bench] -fn dist_iter(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_entropy(); - let distr = Normal::new(-2.71828, 3.14159).unwrap(); - let mut iter = distr.sample_iter(&mut rng); - - b.iter(|| { - let mut accum = 0.0; - for _ in 0..RAND_BENCH_N { - accum += iter.next().unwrap(); - } - accum - }); - b.bytes = size_of::() as u64 * RAND_BENCH_N; -} - -macro_rules! sample_binomial { - ($name:ident, $n:expr, $p:expr) => { - #[bench] - fn $name(b: &mut Bencher) { - let mut rng = Pcg64Mcg::from_rng(&mut thread_rng()).unwrap(); - let (n, p) = ($n, $p); - b.iter(|| { - let d = Binomial::new(n, p).unwrap(); - rng.sample(d) - }) - } - }; -} - -sample_binomial!(misc_binomial_1, 1, 0.9); -sample_binomial!(misc_binomial_10, 10, 0.9); -sample_binomial!(misc_binomial_100, 100, 0.99); -sample_binomial!(misc_binomial_1000, 1000, 0.01); -sample_binomial!(misc_binomial_1e12, 1000_000_000_000, 0.2); diff --git a/rand_distr/src/beta.rs b/rand_distr/src/beta.rs new file mode 100644 index 00000000000..4dc297cfd50 --- /dev/null +++ b/rand_distr/src/beta.rs @@ -0,0 +1,298 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Beta distribution. + +use crate::{Distribution, Open01}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The algorithm used for sampling the Beta distribution. +/// +/// Reference: +/// +/// R. C. H. Cheng (1978). +/// Generating beta variates with nonintegral shape parameters. +/// Communications of the ACM 21, 317-322. +/// https://doi.org/10.1145/359460.359482 +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum BetaAlgorithm { + BB(BB), + BC(BC), +} + +/// Algorithm BB for `min(alpha, beta) > 1`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +struct BB { + alpha: N, + beta: N, + gamma: N, +} + +/// Algorithm BC for `min(alpha, beta) <= 1`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +struct BC { + alpha: N, + beta: N, + kappa1: N, + kappa2: N, +} + +/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`. +/// +/// The Beta distribution is a continuous probability distribution +/// defined on the interval `[0, 1]`. It is the conjugate prior for the +/// parameter `p` of the [`Binomial`][crate::Binomial] distribution. +/// +/// It has two shape parameters `α` (alpha) and `β` (beta) which control +/// the shape of the distribution. Both `a` and `β` must be greater than zero. +/// The distribution is symmetric when `α = β`. +/// +/// # Plot +/// +/// The plot shows the Beta distribution with various combinations +/// of `α` and `β`. +/// +/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{Distribution, Beta}; +/// +/// let beta = Beta::new(2.0, 5.0).unwrap(); +/// let v = beta.sample(&mut rand::rng()); +/// println!("{} is from a Beta(2, 5) distribution", v); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Beta +where + F: Float, + Open01: Distribution, +{ + a: F, + b: F, + switched_params: bool, + algorithm: BetaAlgorithm, +} + +/// Error type returned from [`Beta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Error { + /// `alpha <= 0` or `nan`. + AlphaTooSmall, + /// `beta <= 0` or `nan`. + BetaTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::AlphaTooSmall => "alpha is not positive in beta distribution", + Error::BetaTooSmall => "beta is not positive in beta distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Beta +where + F: Float, + Open01: Distribution, +{ + /// Construct an object representing the `Beta(alpha, beta)` + /// distribution. + pub fn new(alpha: F, beta: F) -> Result, Error> { + if !(alpha > F::zero()) { + return Err(Error::AlphaTooSmall); + } + if !(beta > F::zero()) { + return Err(Error::BetaTooSmall); + } + // From now on, we use the notation from the reference, + // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. + let (a0, b0) = (alpha, beta); + let (a, b, switched_params) = if a0 < b0 { + (a0, b0, false) + } else { + (b0, a0, true) + }; + if a > F::one() { + // Algorithm BB + let alpha = a + b; + + let two = F::from(2.).unwrap(); + let beta_numer = alpha - two; + let beta_denom = two * a * b - alpha; + let beta = (beta_numer / beta_denom).sqrt(); + + let gamma = a + F::one() / beta; + + Ok(Beta { + a, + b, + switched_params, + algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }), + }) + } else { + // Algorithm BC + // + // Here `a` is the maximum instead of the minimum. + let (a, b, switched_params) = (b, a, !switched_params); + let alpha = a + b; + let beta = F::one() / b; + let delta = F::one() + a - b; + let kappa1 = delta + * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b) + / (a * beta - F::from(14. / 18.).unwrap()); + let kappa2 = F::from(0.25).unwrap() + + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b; + + Ok(Beta { + a, + b, + switched_params, + algorithm: BetaAlgorithm::BC(BC { + alpha, + beta, + kappa1, + kappa2, + }), + }) + } + } +} + +impl Distribution for Beta +where + F: Float, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let mut w; + match self.algorithm { + BetaAlgorithm::BB(algo) => { + loop { + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + let z = u1 * u1 * u2; + let r = algo.gamma * v - F::from(4.).unwrap().ln(); + let s = self.a + r - w; + // 2. + if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z { + break; + } + // 3. + let t = z.ln(); + if s >= t { + break; + } + // 4. + if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { + break; + } + } + } + BetaAlgorithm::BC(algo) => { + loop { + let z; + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + if u1 < F::from(0.5).unwrap() { + // 2. + let y = u1 * u2; + z = u1 * y; + if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { + continue; + } + } else { + // 3. + z = u1 * u1 * u2; + if z <= F::from(0.25).unwrap() { + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + break; + } + // 4. + if z >= algo.kappa2 { + continue; + } + } + // 5. + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) + - F::from(4.).unwrap().ln() + < z.ln()) + { + break; + }; + } + } + }; + // 5. for BB, 6. for BC + if !self.switched_params { + if w == F::infinity() { + // Assuming `b` is finite, for large `w`: + return F::one(); + } + w / (self.b + w) + } else { + self.b / (self.b + w) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_beta() { + let beta = Beta::new(1.0, 2.0).unwrap(); + let mut rng = crate::test::rng(201); + for _ in 0..1000 { + beta.sample(&mut rng); + } + } + + #[test] + #[should_panic] + fn test_beta_invalid_dof() { + Beta::new(0., 0.).unwrap(); + } + + #[test] + fn test_beta_small_param() { + let beta = Beta::::new(1e-3, 1e-3).unwrap(); + let mut rng = crate::test::rng(206); + for i in 0..1000 { + assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); + } + } + + #[test] + fn beta_distributions_can_be_compared() { + assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 476ae64f559..d6dfceae777 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -7,37 +7,79 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The binomial distribution. +//! The binomial distribution `Binomial(n, p)`. use crate::{Distribution, Uniform}; -use rand::Rng; -use core::fmt; use core::cmp::Ordering; +use core::fmt; +#[allow(unused_imports)] +use num_traits::Float; +use rand::Rng; -/// The binomial distribution `Binomial(n, p)`. +/// The [binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) `Binomial(n, p)`. +/// +/// The binomial distribution is a discrete probability distribution +/// which describes the probability of seeing `k` successes in `n` +/// independent trials, each of which has success probability `p`. +/// +/// # Density function /// -/// This distribution has density function: /// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. /// +/// # Plot +/// +/// The following plot of the binomial distribution illustrates the +/// probability of `k` successes out of `n = 10` trials with `p = 0.2` +/// and `p = 0.6` for `0 <= k <= n`. +/// +/// ![Binomial distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/binomial.svg) +/// /// # Example /// /// ``` /// use rand_distr::{Binomial, Distribution}; /// /// let bin = Binomial::new(20, 0.3).unwrap(); -/// let v = bin.sample(&mut rand::thread_rng()); +/// let v = bin.sample(&mut rand::rng()); /// println!("{} is from a binomial distribution", v); /// ``` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Binomial { - /// Number of trials. + method: Method, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +enum Method { + Binv(Binv, bool), + Btpe(Btpe, bool), + Poisson(crate::poisson::KnuthMethod), + Constant(u64), +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Binv { + r: f64, + s: f64, + a: f64, + n: u64, +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct Btpe { n: u64, - /// Probability of success. p: f64, + m: i64, + p1: f64, } -/// Error type returned from `Binomial::new`. +/// Error type returned from [`Binomial::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] +// Marked non_exhaustive to allow a new error code in the solution to #1378. +#[non_exhaustive] pub enum Error { /// `p < 0` or `nan`. ProbabilityTooSmall, @@ -67,34 +109,22 @@ impl Binomial { if !(p <= 1.0) { return Err(Error::ProbabilityTooLarge); } - Ok(Binomial { n, p }) - } -} - -/// Convert a `f64` to an `i64`, panicing on overflow. -// In the future (Rust 1.34), this might be replaced with `TryFrom`. -fn f64_to_i64(x: f64) -> i64 { - assert!(x < (core::i64::MAX as f64)); - x as i64 -} -impl Distribution for Binomial { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - // Handle these values directly. - if self.p == 0.0 { - return 0; - } else if self.p == 1.0 { - return self.n; + if p == 0.0 { + return Ok(Binomial { + method: Method::Constant(0), + }); } - // The binomial distribution is symmetrical with respect to p -> 1-p, - // k -> n-k switch p so that it is less than 0.5 - this allows for lower - // expected values we will just invert the result at the end - let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p }; + if p == 1.0 { + return Ok(Binomial { + method: Method::Constant(n), + }); + } - let result; - let q = 1. - p; + // The binomial distribution is symmetrical with respect to p -> 1-p + let flipped = p > 0.5; + let p = if flipped { 1.0 - p } else { p }; // For small n * min(p, 1 - p), the BINV algorithm based on the inverse // transformation of the binomial distribution is efficient. Otherwise, @@ -104,195 +134,257 @@ impl Distribution for Binomial { // random variate generation. Commun. ACM 31, 2 (February 1988), // 216-222. http://dx.doi.org/10.1145/42372.42381 - // Threshold for prefering the BINV algorithm. The paper suggests 10, + // Threshold for preferring the BINV algorithm. The paper suggests 10, // Ranlib uses 30, and GSL uses 14. const BINV_THRESHOLD: f64 = 10.; - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (core::i32::MAX as u64) { - // Use the BINV algorithm. - let s = p / q; - let a = ((self.n + 1) as f64) * s; - let mut r = q.powi(self.n as i32); - let mut u: f64 = rng.gen(); - let mut x = 0; - while u > r as f64 { - u -= r; - x += 1; - r *= a / (x as f64) - s; + let np = n as f64 * p; + let method = if np < BINV_THRESHOLD { + let q = 1.0 - p; + if q == 1.0 { + // p is so small that this is extremely close to a Poisson distribution. + // The flipped case cannot occur here. + Method::Poisson(crate::poisson::KnuthMethod::new(np)) + } else { + let s = p / q; + Method::Binv( + Binv { + r: q.powf(n as f64), + s, + a: (n as f64 + 1.0) * s, + n, + }, + flipped, + ) } - result = x; } else { - // Use the BTPE algorithm. - - // Threshold for using the squeeze algorithm. This can be freely - // chosen based on performance. Ranlib and GSL use 20. - const SQUEEZE_THRESHOLD: i64 = 20; - - // Step 0: Calculate constants as functions of `n` and `p`. - let n = self.n as f64; - let np = n * p; + let q = 1.0 - p; let npq = np * q; + let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; let f_m = np + p; let m = f64_to_i64(f_m); - // radius of triangle region, since height=1 also area of region - let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; - // tip of triangle - let x_m = (m as f64) + 0.5; - // left edge of triangle - let x_l = x_m - p1; - // right edge of triangle - let x_r = x_m + p1; - let c = 0.134 + 20.5 / (15.3 + (m as f64)); - // p1 + area of parallelogram region - let p2 = p1 * (1. + 2. * c); - - fn lambda(a: f64) -> f64 { - a * (1. + 0.5 * a) + Method::Btpe(Btpe { n, p, m, p1 }, flipped) + }; + Ok(Binomial { method }) + } +} + +/// Convert a `f64` to an `i64`, panicking on overflow. +fn f64_to_i64(x: f64) -> i64 { + assert!(x < (i64::MAX as f64)); + x as i64 +} + +fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 { + // Same value as in GSL. + // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. + // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. + // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. + const BINV_MAX_X: u64 = 110; + + let sample = 'outer: loop { + let mut r = binv.r; + let mut u: f64 = rng.random(); + let mut x = 0; + + while u > r { + u -= r; + x += 1; + if x > BINV_MAX_X { + continue 'outer; } + r *= binv.a / (x as f64) - binv.s; + } + break x; + }; - let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p)); - let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // p1 + area of left tail - let p3 = p2 + c / lambda_l; - // p1 + area of right tail - let p4 = p3 + c / lambda_r; - - // return value - let mut y: i64; - - let gen_u = Uniform::new(0., p4); - let gen_v = Uniform::new(0., 1.); - - loop { - // Step 1: Generate `u` for selecting the region. If region 1 is - // selected, generate a triangularly distributed variate. - let u = gen_u.sample(rng); - let mut v = gen_v.sample(rng); - if !(u > p1) { - y = f64_to_i64(x_m - p1 * v + u); - break; - } + if flipped { + binv.n - sample + } else { + sample + } +} - if !(u > p2) { - // Step 2: Region 2, parallelograms. Check if region 2 is - // used. If so, generate `y`. - let x = x_l + (u - p1) / c; - v = v * c + 1.0 - (x - x_m).abs() / p1; - if v > 1. { - continue; - } else { - y = f64_to_i64(x); - } - } else if !(u > p3) { - // Step 3: Region 3, left exponential tail. - y = f64_to_i64(x_l + v.ln() / lambda_l); - if y < 0 { - continue; - } else { - v *= (u - p2) * lambda_l; - } - } else { - // Step 4: Region 4, right exponential tail. - y = f64_to_i64(x_r - v.ln() / lambda_r); - if y > 0 && (y as u64) > self.n { - continue; - } else { - v *= (u - p3) * lambda_r; - } - } +#[allow(clippy::many_single_char_names)] // Same names as in the reference. +fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 { + // Threshold for using the squeeze algorithm. This can be freely + // chosen based on performance. Ranlib and GSL use 20. + const SQUEEZE_THRESHOLD: i64 = 20; + + // Step 0: Calculate constants as functions of `n` and `p`. + let n = btpe.n as f64; + let np = n * btpe.p; + let q = 1. - btpe.p; + let npq = np * q; + let f_m = np + btpe.p; + let m = btpe.m; + // radius of triangle region, since height=1 also area of region + let p1 = btpe.p1; + // tip of triangle + let x_m = (m as f64) + 0.5; + // left edge of triangle + let x_l = x_m - p1; + // right edge of triangle + let x_r = x_m + p1; + let c = 0.134 + 20.5 / (15.3 + (m as f64)); + // p1 + area of parallelogram region + let p2 = p1 * (1. + 2. * c); + + fn lambda(a: f64) -> f64 { + a * (1. + 0.5 * a) + } - // Step 5: Acceptance/rejection comparison. - - // Step 5.0: Test for appropriate method of evaluating f(y). - let k = (y - m).abs(); - if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { - // Step 5.1: Evaluate f(y) via the recursive relationship. Start the - // search from the mode. - let s = p / q; - let a = s * (n + 1.); - let mut f = 1.0; - match m.cmp(&y) { - Ordering::Less => { - let mut i = m; - loop { - i += 1; - f *= a / (i as f64) - s; - if i == y { - break; - } - } - }, - Ordering::Greater => { - let mut i = y; - loop { - i += 1; - f /= a / (i as f64) - s; - if i == m { - break; - } - } - }, - Ordering::Equal => {}, - } - if v > f { - continue; - } else { - break; - } - } + let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p)); + let lambda_r = lambda((x_r - f_m) / (x_r * q)); - // Step 5.2: Squeezing. Check the value of ln(v) againts upper and - // lower bound of ln(f(y)). - let k = k as f64; - let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); - let t = -0.5 * k * k / npq; - let alpha = v.ln(); - if alpha < t - rho { - break; - } - if alpha > t + rho { - continue; - } + let p3 = p2 + c / lambda_l; - // Step 5.3: Final acceptance/rejection test. - let x1 = (y + 1) as f64; - let f1 = (m + 1) as f64; - let z = (f64_to_i64(n) + 1 - m) as f64; - let w = (f64_to_i64(n) - y + 1) as f64; + let p4 = p3 + c / lambda_r; - fn stirling(a: f64) -> f64 { - let a2 = a * a; - (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. - } + // return value + let mut y: i64; - if alpha - > x_m * (f1 / x1).ln() - + (n - (m as f64) + 0.5) * (z / w).ln() - + ((y - m) as f64) * (w * p / (x1 * q)).ln() - // We use the signs from the GSL implementation, which are - // different than the ones in the reference. According to - // the GSL authors, the new signs were verified to be - // correct by one of the original designers of the - // algorithm. - + stirling(f1) - + stirling(z) - - stirling(x1) - - stirling(w) - { - continue; - } + let gen_u = Uniform::new(0., p4).unwrap(); + let gen_v = Uniform::new(0., 1.).unwrap(); + + loop { + // Step 1: Generate `u` for selecting the region. If region 1 is + // selected, generate a triangularly distributed variate. + let u = gen_u.sample(rng); + let mut v = gen_v.sample(rng); + if !(u > p1) { + y = f64_to_i64(x_m - p1 * v + u); + break; + } + + if !(u > p2) { + // Step 2: Region 2, parallelograms. Check if region 2 is + // used. If so, generate `y`. + let x = x_l + (u - p1) / c; + v = v * c + 1.0 - (x - x_m).abs() / p1; + if v > 1. { + continue; + } else { + y = f64_to_i64(x); + } + } else if !(u > p3) { + // Step 3: Region 3, left exponential tail. + y = f64_to_i64(x_l + v.ln() / lambda_l); + if y < 0 { + continue; + } else { + v *= (u - p2) * lambda_l; + } + } else { + // Step 4: Region 4, right exponential tail. + y = f64_to_i64(x_r - v.ln() / lambda_r); + if y > 0 && (y as u64) > btpe.n { + continue; + } else { + v *= (u - p3) * lambda_r; + } + } + // Step 5: Acceptance/rejection comparison. + + // Step 5.0: Test for appropriate method of evaluating f(y). + let k = (y - m).abs(); + if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { + // Step 5.1: Evaluate f(y) via the recursive relationship. Start the + // search from the mode. + let s = btpe.p / q; + let a = s * (n + 1.); + let mut f = 1.0; + match m.cmp(&y) { + Ordering::Less => { + let mut i = m; + loop { + i += 1; + f *= a / (i as f64) - s; + if i == y { + break; + } + } + } + Ordering::Greater => { + let mut i = y; + loop { + i += 1; + f /= a / (i as f64) - s; + if i == m { + break; + } + } + } + Ordering::Equal => {} + } + if v > f { + continue; + } else { break; } - assert!(y >= 0); - result = y as u64; } - // Invert the result for p < 0.5. - if p != self.p { - self.n - result - } else { - result + // Step 5.2: Squeezing. Check the value of ln(v) against upper and + // lower bound of ln(f(y)). + let k = k as f64; + let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); + let t = -0.5 * k * k / npq; + let alpha = v.ln(); + if alpha < t - rho { + break; + } + if alpha > t + rho { + continue; + } + + // Step 5.3: Final acceptance/rejection test. + let x1 = (y + 1) as f64; + let f1 = (m + 1) as f64; + let z = (f64_to_i64(n) + 1 - m) as f64; + let w = (f64_to_i64(n) - y + 1) as f64; + + fn stirling(a: f64) -> f64 { + let a2 = a * a; + (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. + } + + if alpha + > x_m * (f1 / x1).ln() + + (n - (m as f64) + 0.5) * (z / w).ln() + + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln() + // We use the signs from the GSL implementation, which are + // different than the ones in the reference. According to + // the GSL authors, the new signs were verified to be + // correct by one of the original designers of the + // algorithm. + + stirling(f1) + + stirling(z) + - stirling(x1) + - stirling(w) + { + continue; + } + + break; + } + assert!(y >= 0); + let y = y as u64; + + if flipped { + btpe.n - y + } else { + y + } +} + +impl Distribution for Binomial { + fn sample(&self, rng: &mut R) -> u64 { + match self.method { + Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng), + Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng), + Method::Poisson(poisson) => poisson.sample(rng) as u64, + Method::Constant(c) => c, } } } @@ -315,7 +407,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean - expected_mean).abs() < expected_mean / 50.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; @@ -330,6 +422,8 @@ mod test { test_binomial_mean_and_variance(40, 0.5, &mut rng); test_binomial_mean_and_variance(20, 0.7, &mut rng); test_binomial_mean_and_variance(20, 0.5, &mut rng); + test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng); + test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng); } #[test] @@ -344,4 +438,20 @@ mod test { fn test_binomial_invalid_lambda_neg() { Binomial::new(20, -10.0).unwrap(); } + + #[test] + fn binomial_distributions_can_be_compared() { + assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0)); + } + + #[test] + fn binomial_avoid_infinite_loop() { + let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap(); + let mut sum: u64 = 0; + let mut rng = crate::test::rng(742); + for _ in 0..100_000 { + sum = sum.wrapping_add(dist.sample(&mut rng)); + } + assert_ne!(sum, 0); + } } diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 952b7f29dc9..8f0faad3863 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -7,20 +7,37 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Cauchy distribution. +//! The Cauchy distribution `Cauchy(x₀, γ)`. +use crate::{Distribution, StandardUniform}; +use core::fmt; use num_traits::{Float, FloatConst}; -use crate::{Distribution, Standard}; use rand::Rng; -use core::fmt; -/// The Cauchy distribution `Cauchy(median, scale)`. +/// The [Cauchy distribution](https://en.wikipedia.org/wiki/Cauchy_distribution) `Cauchy(x₀, γ)`. /// -/// This distribution has a density function: -/// `f(x) = 1 / (pi * scale * (1 + ((x - median) / scale)^2))` +/// The Cauchy distribution is a continuous probability distribution with +/// parameters `x₀` (median) and `γ` (scale). +/// It describes the distribution of the ratio of two independent +/// normally distributed random variables with means `x₀` and scales `γ`. +/// In other words, if `X` and `Y` are independent normally distributed +/// random variables with means `x₀` and scales `γ`, respectively, then +/// `X / Y` is `Cauchy(x₀, γ)` distributed. /// -/// Note that at least for `f32`, results are not fully portable due to minor -/// differences in the target system's *tan* implementation, `tanf`. +/// # Density function +/// +/// `f(x) = 1 / (π * γ * (1 + ((x - x₀) / γ)²))` +/// +/// # Plot +/// +/// The plot illustrates the Cauchy distribution with various values of `x₀` and `γ`. +/// Note how the median parameter `x₀` shifts the distribution along the x-axis, +/// and how the scale `γ` changes the density around the median. +/// +/// The standard Cauchy distribution is the special case with `x₀ = 0` and `γ = 1`, +/// which corresponds to the ratio of two [`StandardNormal`](crate::StandardNormal) distributions. +/// +/// ![Cauchy distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/cauchy.svg) /// /// # Example /// @@ -28,18 +45,26 @@ use core::fmt; /// use rand_distr::{Cauchy, Distribution}; /// /// let cau = Cauchy::new(2.0, 5.0).unwrap(); -/// let v = cau.sample(&mut rand::thread_rng()); +/// let v = cau.sample(&mut rand::rng()); /// println!("{} is from a Cauchy(2, 5) distribution", v); /// ``` -#[derive(Clone, Copy, Debug)] +/// +/// # Notes +/// +/// Note that at least for `f32`, results are not fully portable due to minor +/// differences in the target system's *tan* implementation, `tanf`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { median: F, scale: F, } -/// Error type returned from `Cauchy::new`. +/// Error type returned from [`Cauchy::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. @@ -58,7 +83,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { /// Construct a new `Cauchy` with the given shape parameters /// `median` the peak location and `scale` the scale factor. @@ -71,11 +98,13 @@ where F: Float + FloatConst, Standard: Distribution } impl Distribution for Cauchy -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { fn sample(&self, rng: &mut R) -> F { // sample from [0, 1) - let x = Standard.sample(rng); + let x = StandardUniform.sample(rng); // get standard cauchy random number // note that π/2 is not exactly representable, even if x=0.5 the result is finite let comp_dev = (F::PI() * x).tan(); @@ -88,8 +117,8 @@ where F: Float + FloatConst, Standard: Distribution mod test { use super::*; - fn median(mut numbers: &mut [f64]) -> f64 { - sort(&mut numbers); + fn median(numbers: &mut [f64]) -> f64 { + sort(numbers); let mid = numbers.len() / 2; numbers[mid] } @@ -106,9 +135,9 @@ mod test { let mut rng = crate::test::rng(123); let mut numbers: [f64; 1000] = [0.0; 1000]; let mut sum = 0.0; - for i in 0..1000 { - numbers[i] = cauchy.sample(&mut rng); - sum += numbers[i]; + for number in &mut numbers[..] { + *number = cauchy.sample(&mut rng); + sum += *number; } let median = median(&mut numbers); #[cfg(feature = "std")] @@ -135,23 +164,28 @@ mod test { #[test] fn value_stability() { - fn gen_samples(m: F, s: F, buf: &mut [F]) - where Standard: Distribution { + fn gen_samples(m: F, s: F, buf: &mut [F]) + where + StandardUniform: Distribution, + { let distr = Cauchy::new(m, s).unwrap(); let mut rng = crate::test::rng(353); for x in buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } } let mut buf = [0.0; 4]; gen_samples(100f64, 10.0, &mut buf); - assert_eq!(&buf, &[ - 77.93369152808678, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925 - ]); + assert_eq!( + &buf, + &[ + 77.93369152808678, + 90.1606912098641, + 125.31516221323625, + 86.10217834773925 + ] + ); // Unfortunately this test is not fully portable due to reliance on the // system's implementation of tanf (see doc on Cauchy struct). @@ -162,4 +196,9 @@ mod test { assert_almost_eq!(*a, *b, 1e-5); } } + + #[test] + fn cauchy_distributions_can_be_compared() { + assert_eq!(Cauchy::new(1.0, 2.0), Cauchy::new(1.0, 2.0)); + } } diff --git a/rand_distr/src/chi_squared.rs b/rand_distr/src/chi_squared.rs new file mode 100644 index 00000000000..409a78bb311 --- /dev/null +++ b/rand_distr/src/chi_squared.rs @@ -0,0 +1,179 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Chi-squared distribution. + +use self::ChiSquaredRepr::*; + +use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`. +/// +/// The chi-squared distribution is a continuous probability +/// distribution with parameter `k > 0` degrees of freedom. +/// +/// For `k > 0` integral, this distribution is the sum of the squares +/// of `k` independent standard normal random variables. For other +/// `k`, this uses the equivalent characterisation +/// `χ²(k) = Gamma(k/2, 2)`. +/// +/// # Plot +/// +/// The plot shows the chi-squared distribution with various degrees +/// of freedom. +/// +/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{ChiSquared, Distribution}; +/// +/// let chi = ChiSquared::new(11.0).unwrap(); +/// let v = chi.sample(&mut rand::rng()); +/// println!("{} is from a χ²(11) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + repr: ChiSquaredRepr, +} + +/// Error type returned from [`ChiSquared::new`] and [`StudentT::new`](crate::StudentT::new). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Error { + /// `0.5 * k <= 0` or `nan`. + DoFTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::DoFTooSmall => { + "degrees-of-freedom k is not positive in chi-squared distribution" + } + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum ChiSquaredRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, + // e.g. when alpha = 1/2 as it would be for this case, so special- + // casing and using the definition of N(0,1)^2 is faster. + DoFExactlyOne, + DoFAnythingElse(Gamma), +} + +impl ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new chi-squared distribution with degrees-of-freedom + /// `k`. + pub fn new(k: F) -> Result, Error> { + let repr = if k == F::one() { + DoFExactlyOne + } else { + if !(F::from(0.5).unwrap() * k > F::zero()) { + return Err(Error::DoFTooSmall); + } + DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) + }; + Ok(ChiSquared { repr }) + } +} +impl Distribution for ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + match self.repr { + DoFExactlyOne => { + // k == 1 => N(0,1)^2 + let norm: F = rng.sample(StandardNormal); + norm * norm + } + DoFAnythingElse(ref g) => g.sample(rng), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_chi_squared_one() { + let chi = ChiSquared::new(1.0).unwrap(); + let mut rng = crate::test::rng(201); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + fn test_chi_squared_small() { + let chi = ChiSquared::new(0.5).unwrap(); + let mut rng = crate::test::rng(202); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + fn test_chi_squared_large() { + let chi = ChiSquared::new(30.0).unwrap(); + let mut rng = crate::test::rng(203); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + #[should_panic] + fn test_chi_squared_invalid_dof() { + ChiSquared::new(-1.0).unwrap(); + } + + #[test] + fn gamma_distributions_can_be_compared() { + assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); + } + + #[test] + fn chi_squared_distributions_can_be_compared() { + assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0)); + } +} diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 391e428d351..ac17fa2e298 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -7,19 +7,202 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The dirichlet distribution. +//! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. + #![cfg(feature = "alloc")] -use num_traits::Float; -use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; -use rand::Rng; +use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; use core::fmt; +use num_traits::{Float, NumCast}; +use rand::Rng; +#[cfg(feature = "serde")] +use serde_with::serde_as; + use alloc::{boxed::Box, vec, vec::Vec}; -/// The Dirichlet distribution `Dirichlet(alpha)`. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", serde_as)] +struct DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + samplers: [Gamma; N], +} + +/// Error type returned from [`DirchletFromGamma::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum DirichletFromGammaError { + /// Gamma::new(a, 1) failed. + GammmaNewFailed, + + /// gamma_dists.try_into() failed (in theory, this should not happen). + GammaArrayCreationFailed, +} + +impl DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Construct a new `DirichletFromGamma` with the given parameters `alpha`. + /// + /// This function is part of a private implementation detail. + /// It assumes that the input is correct, so no validation of alpha is done. + #[inline] + fn new(alpha: [F; N]) -> Result, DirichletFromGammaError> { + let mut gamma_dists = Vec::new(); + for a in alpha { + let dist = + Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; + gamma_dists.push(dist); + } + Ok(DirichletFromGamma { + samplers: gamma_dists + .try_into() + .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?, + }) + } +} + +impl Distribution<[F; N]> for DirichletFromGamma +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; + let mut sum = F::zero(); + + for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { + *s = g.sample(rng); + sum = sum + *s; + } + let invacc = F::one() / sum; + for s in samples.iter_mut() { + *s = *s * invacc; + } + samples + } +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + samplers: Box<[Beta]>, +} + +/// Error type returned from [`DirchletFromBeta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum DirichletFromBetaError { + /// Beta::new(a, b) failed. + BetaNewFailed, +} + +impl DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Construct a new `DirichletFromBeta` with the given parameters `alpha`. + /// + /// This function is part of a private implementation detail. + /// It assumes that the input is correct, so no validation of alpha is done. + #[inline] + fn new(alpha: [F; N]) -> Result, DirichletFromBetaError> { + // `alpha_rev_csum` is the reverse of the cumulative sum of the + // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then + // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`. + // Note that instances of DirichletFromBeta will always have N >= 2, + // so the subtractions of 1, 2 and 3 from N in the following are safe. + let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1]; + for k in 0..(N - 2) { + alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k]; + } + + // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example + // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples + // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`. + // Then pass each tuple to `Beta::new()` to create the `Beta` + // instances. + let mut beta_dists = Vec::new(); + for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) { + let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?; + beta_dists.push(dist); + } + Ok(DirichletFromBeta { + samplers: beta_dists.into_boxed_slice(), + }) + } +} + +impl Distribution<[F; N]> for DirichletFromBeta +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> [F; N] { + let mut samples = [F::zero(); N]; + let mut acc = F::one(); + + for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { + let beta_sample = beta.sample(rng); + *s = acc * beta_sample; + acc = acc * (F::one() - beta_sample); + } + samples[N - 1] = acc; + samples + } +} + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", serde_as)] +enum DirichletRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Dirichlet distribution that generates samples using the Gamma distribution. + FromGamma(DirichletFromGamma), + + /// Dirichlet distribution that generates samples using the Beta distribution. + FromBeta(DirichletFromBeta), +} + +/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`. /// /// The Dirichlet distribution is a family of continuous multivariate -/// probability distributions parameterized by a vector alpha of positive reals. -/// It is a multivariate generalization of the beta distribution. +/// probability distributions parameterized by a vector of positive +/// real numbers `α₁, α₂, ..., αₖ`, where `k` is the number of dimensions +/// of the distribution. The distribution is supported on the `k-1`-dimensional +/// simplex, which is the set of points `x = [x₁, x₂, ..., xₖ]` such that +/// `0 ≤ xᵢ ≤ 1` and `∑ xᵢ = 1`. +/// It is a multivariate generalization of the [`Beta`](crate::Beta) distribution. +/// The distribution is symmetric when all `αᵢ` are equal. +/// +/// # Plot +/// +/// The following plot illustrates the 2-dimensional simplices for various +/// 3-dimensional Dirichlet distributions. +/// +/// ![Dirichlet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/dirichlet.png) /// /// # Example /// @@ -27,31 +210,38 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// use rand::prelude::*; /// use rand_distr::Dirichlet; /// -/// let dirichlet = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); -/// let samples = dirichlet.sample(&mut rand::thread_rng()); +/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); +/// let samples = dirichlet.sample(&mut rand::rng()); /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Clone, Debug)] -pub struct Dirichlet +#[cfg_attr(feature = "serde", serde_as)] +#[derive(Clone, Debug, PartialEq)] +pub struct Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - /// Concentration parameters (alpha) - alpha: Box<[F]>, + repr: DirichletRepr, } -/// Error type returned from `Dirchlet::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +/// Error type returned from [`Dirichlet::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `alpha.len() < 2`. AlphaTooShort, /// `alpha <= 0.0` or `nan`. AlphaTooSmall, + /// `alpha` is subnormal. + /// Variate generation methods are not reliable with subnormal inputs. + AlphaSubnormal, + /// `alpha` is infinite. + AlphaInfinite, + /// Failed to create required Gamma distribution(s). + FailedToCreateGamma, + /// Failed to create required Beta distribition(s). + FailedToCreateBeta, /// `size < 2`. SizeTooSmall, } @@ -63,6 +253,14 @@ impl fmt::Display for Error { "less than 2 dimensions in Dirichlet distribution" } Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution", + Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution", + Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution", + Error::FailedToCreateGamma => { + "failed to create required Gamma distribution for Dirichlet distribution" + } + Error::FailedToCreateBeta => { + "failed to create required Beta distribition for Dirichlet distribution" + } }) } } @@ -70,7 +268,7 @@ impl fmt::Display for Error { #[cfg(feature = "std")] impl std::error::Error for Error {} -impl Dirichlet +impl Dirichlet where F: Float, StandardNormal: Distribution, @@ -79,60 +277,56 @@ where { /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. /// - /// Requires `alpha.len() >= 2`. + /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive, + /// finite and not subnormal. #[inline] - pub fn new(alpha: &[F]) -> Result, Error> { - if alpha.len() < 2 { + pub fn new(alpha: [F; N]) -> Result, Error> { + if N < 2 { return Err(Error::AlphaTooShort); } for &ai in alpha.iter() { if !(ai > F::zero()) { + // This also catches nan. return Err(Error::AlphaTooSmall); } + if ai.is_infinite() { + return Err(Error::AlphaInfinite); + } + if !ai.is_normal() { + return Err(Error::AlphaSubnormal); + } } - Ok(Dirichlet { alpha: alpha.to_vec().into_boxed_slice() }) - } - - /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. - /// - /// Requires `size >= 2`. - #[inline] - pub fn new_with_size(alpha: F, size: usize) -> Result, Error> { - if !(alpha > F::zero()) { - return Err(Error::AlphaTooSmall); - } - if size < 2 { - return Err(Error::SizeTooSmall); + if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) { + // Use the Beta method when all the alphas are less than 0.1 This + // threshold provides a reasonable compromise between using the faster + // Gamma method for as wide a range as possible while ensuring that + // the probability of generating nans is negligibly small. + let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?; + Ok(Dirichlet { + repr: DirichletRepr::FromBeta(dist), + }) + } else { + let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?; + Ok(Dirichlet { + repr: DirichletRepr::FromGamma(dist), + }) } - Ok(Dirichlet { - alpha: vec![alpha; size].into_boxed_slice(), - }) } } -impl Distribution> for Dirichlet +impl Distribution<[F; N]> for Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - fn sample(&self, rng: &mut R) -> Vec { - let n = self.alpha.len(); - let mut samples = vec![F::zero(); n]; - let mut sum = F::zero(); - - for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) { - let g = Gamma::new(a, F::one()).unwrap(); - *s = g.sample(rng); - sum = sum + (*s); - } - let invacc = F::one() / sum; - for s in samples.iter_mut() { - *s = (*s)*invacc; + fn sample(&self, rng: &mut R) -> [F; N] { + match &self.repr { + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), } - samples } } @@ -142,43 +336,111 @@ mod test { #[test] fn test_dirichlet() { - let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap(); + let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); let mut rng = crate::test::rng(221); let samples = d.sample(&mut rng); - let _: Vec = samples - .into_iter() - .map(|x| { - assert!(x > 0.0); - x - }) - .collect(); + assert!(samples.into_iter().all(|x: f64| x > 0.0)); } #[test] - fn test_dirichlet_with_param() { - let alpha = 0.5f64; - let size = 2; - let d = Dirichlet::new_with_size(alpha, size).unwrap(); - let mut rng = crate::test::rng(221); - let samples = d.sample(&mut rng); - let _: Vec = samples - .into_iter() - .map(|x| { - assert!(x > 0.0); - x - }) - .collect(); + #[should_panic] + fn test_dirichlet_invalid_length() { + Dirichlet::new([0.5]).unwrap(); } #[test] #[should_panic] - fn test_dirichlet_invalid_length() { - Dirichlet::new_with_size(0.5f64, 1).unwrap(); + fn test_dirichlet_alpha_zero() { + Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_negative() { + Dirichlet::new([0.1, -1.5, 0.3]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_nan() { + Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap(); + } + + #[test] + #[should_panic] + fn test_dirichlet_alpha_subnormal() { + Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap(); } #[test] #[should_panic] - fn test_dirichlet_invalid_alpha() { - Dirichlet::new_with_size(0.0f64, 2).unwrap(); + fn test_dirichlet_alpha_inf() { + Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap(); + } + + #[test] + fn dirichlet_distributions_can_be_compared() { + assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); + } + + /// Check that the means of the components of n samples from + /// the Dirichlet distribution agree with the expected means + /// with a relative tolerance of rtol. + /// + /// This is a crude statistical test, but it will catch egregious + /// mistakes. It will also also fail if any samples contain nan. + fn check_dirichlet_means(alpha: [f64; N], n: i32, rtol: f64, seed: u64) { + let d = Dirichlet::new(alpha).unwrap(); + let mut rng = crate::test::rng(seed); + let mut sums = [0.0; N]; + for _ in 0..n { + let samples = d.sample(&mut rng); + for i in 0..N { + sums[i] += samples[i]; + } + } + let sample_mean = sums.map(|x| x / n as f64); + let alpha_sum: f64 = alpha.iter().sum(); + let expected_mean = alpha.map(|x| x / alpha_sum); + for i in 0..N { + assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); + } + } + + #[test] + fn test_dirichlet_means() { + // Check the means of 20000 samples for several different alphas. + let n = 20000; + let rtol = 2e-2; + let seed = 1317624576693539401; + check_dirichlet_means([0.5, 0.25], n, rtol, seed); + check_dirichlet_means([123.0, 75.0], n, rtol, seed); + check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed); + check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed); + } + + #[test] + fn test_dirichlet_means_very_small_alpha() { + // With values of alpha that are all 0.001, check that the means of the + // components of 10000 samples are within 1% of the expected means. + // With the sampling method based on gamma variates, this test would + // fail, with about 10% of the samples containing nan. + let alpha = [0.001; 3]; + let n = 10000; + let rtol = 1e-2; + let seed = 1317624576693539401; + check_dirichlet_means(alpha, n, rtol, seed); + } + + #[test] + fn test_dirichlet_means_small_alpha() { + // With values of alpha that are all less than 0.1, check that the + // means of the components of 150000 samples are within 0.1% of the + // expected means. + let alpha = [0.05, 0.025, 0.075, 0.05]; + let n = 150000; + let rtol = 1e-3; + let seed = 1317624576693539401; + check_dirichlet_means(alpha, n, rtol, seed); } } diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index fb9818974ad..6d61015a8c1 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -7,38 +7,49 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The exponential distribution. +//! The exponential distribution `Exp(λ)`. use crate::utils::ziggurat; -use num_traits::Float; use crate::{ziggurat_tables, Distribution}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the exponential distribution, -/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or -/// sampling with `-rng.gen::().ln()`, but faster. +/// The standard exponential distribution `Exp(1)`. /// -/// See `Exp` for the general exponential distribution. +/// This is equivalent to `Exp::new(1.0)` or sampling with +/// `-rng.gen::().ln()`, but faster. /// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. The exact -/// description in the paper was adjusted to use tables for the exponential -/// distribution rather than normal. +/// See [`Exp`](crate::Exp) for the general exponential distribution. /// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford +/// # Plot +/// +/// The following plot illustrates the exponential distribution with `λ = 1`. +/// +/// ![Exponential distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/exponential_exp1.svg) /// /// # Example +/// /// ``` /// use rand::prelude::*; /// use rand_distr::Exp1; /// -/// let val: f64 = thread_rng().sample(Exp1); +/// let val: f64 = rand::rng().sample(Exp1); /// println!("{}", val); /// ``` +/// +/// # Notes +/// +/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. The exact +/// description in the paper was adjusted to use tables for the exponential +/// distribution rather than normal. +/// +/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to +/// Generate Normal Random Samples*]( +/// https://www.doornik.com/research/ziggurat.pdf). +/// Nuffield College, Oxford #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Exp1; impl Distribution for Exp1 { @@ -60,7 +71,7 @@ impl Distribution for Exp1 { } #[inline] fn zero_case(rng: &mut R, _u: f64) -> f64 { - ziggurat_tables::ZIG_EXP_R - rng.gen::().ln() + ziggurat_tables::ZIG_EXP_R - rng.random::().ln() } ziggurat( @@ -74,12 +85,30 @@ impl Distribution for Exp1 { } } -/// The exponential distribution `Exp(lambda)`. +/// The [exponential distribution](https://en.wikipedia.org/wiki/Exponential_distribution) `Exp(λ)`. +/// +/// The exponential distribution is a continuous probability distribution +/// with rate parameter `λ` (`lambda`). It describes the time between events +/// in a [`Poisson`](crate::Poisson) process, i.e. a process in which +/// events occur continuously and independently at a constant average rate. +/// +/// See [`Exp1`](crate::Exp1) for an optimised implementation for `λ = 1`. /// -/// This distribution has density function: `f(x) = lambda * exp(-lambda * x)` -/// for `x > 0`, when `lambda > 0`. For `lambda = 0`, all samples yield infinity. +/// # Density function /// -/// Note that [`Exp1`](crate::Exp1) is an optimised implementation for `lambda = 1`. +/// `f(x) = λ * exp(-λ * x)` for `x > 0`, when `λ > 0`. +/// +/// For `λ = 0`, all samples yield infinity (because a Poisson process +/// with rate 0 has no events). +/// +/// # Plot +/// +/// The following plot illustrates the exponential distribution with +/// various values of `λ`. +/// The `λ` parameter controls the rate of decay as `x` approaches infinity, +/// and the mean of the distribution is `1/λ`. +/// +/// ![Exponential distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/exponential.svg) /// /// # Example /// @@ -87,18 +116,21 @@ impl Distribution for Exp1 { /// use rand_distr::{Exp, Distribution}; /// /// let exp = Exp::new(2.0).unwrap(); -/// let v = exp.sample(&mut rand::thread_rng()); +/// let v = exp.sample(&mut rand::rng()); /// println!("{} is from a Exp(2) distribution", v); /// ``` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { /// `lambda` stored as `1/lambda`, since this is what we scale by. lambda_inverse: F, } -/// Error type returned from `Exp::new`. +/// Error type returned from [`Exp::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `lambda < 0` or `nan`. @@ -117,16 +149,18 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { /// Construct a new `Exp` with the given shape parameter /// `lambda`. - /// + /// /// # Remarks - /// + /// /// For custom types `N` implementing the [`Float`] trait, /// the case `lambda = 0` is handled as follows: each sample corresponds - /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types + /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types /// yield infinity, since `1 / 0 = infinity`. #[inline] pub fn new(lambda: F) -> Result, Error> { @@ -140,7 +174,9 @@ where F: Float, Exp1: Distribution } impl Distribution for Exp -where F: Float, Exp1: Distribution +where + F: Float, + Exp1: Distribution, { fn sample(&self, rng: &mut R) -> F { rng.sample(Exp1) * self.lambda_inverse @@ -175,4 +211,9 @@ mod test { fn test_exp_invalid_lambda_nan() { Exp::new(f64::nan()).unwrap(); } + + #[test] + fn exponential_distributions_can_be_compared() { + assert_eq!(Exp::new(1.0), Exp::new(1.0)); + } } diff --git a/rand_distr/src/fisher_f.rs b/rand_distr/src/fisher_f.rs new file mode 100644 index 00000000000..9c2c13cf64f --- /dev/null +++ b/rand_distr/src/fisher_f.rs @@ -0,0 +1,131 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Fisher F-distribution. + +use crate::{ChiSquared, Distribution, Exp1, Open01, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`. +/// +/// This distribution is equivalent to the ratio of two normalised +/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / +/// (χ²(n)/n)`. +/// +/// # Plot +/// +/// The plot shows the F-distribution with various values of `m` and `n`. +/// +/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{FisherF, Distribution}; +/// +/// let f = FisherF::new(2.0, 32.0).unwrap(); +/// let v = f.sample(&mut rand::rng()); +/// println!("{} is from an F(2, 32) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + numer: ChiSquared, + denom: ChiSquared, + // denom_dof / numer_dof so that this can just be a straight + // multiplication, rather than a division. + dof_ratio: F, +} + +/// Error type returned from [`FisherF::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Error { + /// `m <= 0` or `nan`. + MTooSmall, + /// `n <= 0` or `nan`. + NTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::MTooSmall => "m is not positive in Fisher F distribution", + Error::NTooSmall => "n is not positive in Fisher F distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new `FisherF` distribution, with the given parameter. + pub fn new(m: F, n: F) -> Result, Error> { + let zero = F::zero(); + if !(m > zero) { + return Err(Error::MTooSmall); + } + if !(n > zero) { + return Err(Error::NTooSmall); + } + + Ok(FisherF { + numer: ChiSquared::new(m).unwrap(), + denom: ChiSquared::new(n).unwrap(), + dof_ratio: n / m, + }) + } +} +impl Distribution for FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_f() { + let f = FisherF::new(2.0, 32.0).unwrap(); + let mut rng = crate::test::rng(204); + for _ in 0..1000 { + f.sample(&mut rng); + } + } + + #[test] + fn fisher_f_distributions_can_be_compared() { + assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/frechet.rs b/rand_distr/src/frechet.rs new file mode 100644 index 00000000000..feecd603fb5 --- /dev/null +++ b/rand_distr/src/frechet.rs @@ -0,0 +1,205 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Fréchet distribution `Fréchet(μ, σ, α)`. + +use crate::{Distribution, OpenClosed01}; +use core::fmt; +use num_traits::Float; +use rand::Rng; + +/// The [Fréchet distribution](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution) `Fréchet(α, μ, σ)`. +/// +/// The Fréchet distribution is a continuous probability distribution +/// with location parameter `μ` (`mu`), scale parameter `σ` (`sigma`), +/// and shape parameter `α` (`alpha`). It describes the distribution +/// of the maximum (or minimum) of a number of random variables. +/// It is also known as the Type II extreme value distribution. +/// +/// # Density function +/// +/// `f(x) = [(x - μ) / σ]^(-1 - α) exp[-(x - μ) / σ]^(-α) α / σ` +/// +/// # Plot +/// +/// The plot shows the Fréchet distribution with various values of `μ`, `σ`, and `α`. +/// Note how the location parameter `μ` shifts the distribution along the x-axis, +/// the scale parameter `σ` stretches or compresses the distribution along the x-axis, +/// and the shape parameter `α` changes the tail behavior. +/// +/// ![Fréchet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/frechet.svg) +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand_distr::Frechet; +/// +/// let val: f64 = rand::rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap()); +/// println!("{}", val); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Frechet +where + F: Float, + OpenClosed01: Distribution, +{ + location: F, + scale: F, + shape: F, +} + +/// Error type returned from [`Frechet::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// location is infinite or NaN + LocationNotFinite, + /// scale is not finite positive number + ScaleNotPositive, + /// shape is not finite positive number + ShapeNotPositive, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::LocationNotFinite => "location is not finite in Frechet distribution", + Error::ScaleNotPositive => "scale is not positive and finite in Frechet distribution", + Error::ShapeNotPositive => "shape is not positive and finite in Frechet distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Frechet +where + F: Float, + OpenClosed01: Distribution, +{ + /// Construct a new `Frechet` distribution with given `location`, `scale`, and `shape`. + pub fn new(location: F, scale: F, shape: F) -> Result, Error> { + if scale <= F::zero() || scale.is_infinite() || scale.is_nan() { + return Err(Error::ScaleNotPositive); + } + if shape <= F::zero() || shape.is_infinite() || shape.is_nan() { + return Err(Error::ShapeNotPositive); + } + if location.is_infinite() || location.is_nan() { + return Err(Error::LocationNotFinite); + } + Ok(Frechet { + location, + scale, + shape, + }) + } +} + +impl Distribution for Frechet +where + F: Float, + OpenClosed01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let x: F = rng.sample(OpenClosed01); + self.location + self.scale * (-x.ln()).powf(-self.shape.recip()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic] + fn test_zero_scale() { + Frechet::new(0.0, 0.0, 1.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_infinite_scale() { + Frechet::new(0.0, f64::INFINITY, 1.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_nan_scale() { + Frechet::new(0.0, f64::NAN, 1.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_zero_shape() { + Frechet::new(0.0, 1.0, 0.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_infinite_shape() { + Frechet::new(0.0, 1.0, f64::INFINITY).unwrap(); + } + + #[test] + #[should_panic] + fn test_nan_shape() { + Frechet::new(0.0, 1.0, f64::NAN).unwrap(); + } + + #[test] + #[should_panic] + fn test_infinite_location() { + Frechet::new(f64::INFINITY, 1.0, 1.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_nan_location() { + Frechet::new(f64::NAN, 1.0, 1.0).unwrap(); + } + + #[test] + fn test_sample_against_cdf() { + fn quantile_function(x: f64) -> f64 { + (-x.ln()).recip() + } + let location = 0.0; + let scale = 1.0; + let shape = 1.0; + let iterations = 100_000; + let increment = 1.0 / iterations as f64; + let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; + let mut quantiles = [0.0; 9]; + for (i, p) in probabilities.iter().enumerate() { + quantiles[i] = quantile_function(*p); + } + let mut proportions = [0.0; 9]; + let d = Frechet::new(location, scale, shape).unwrap(); + let mut rng = crate::test::rng(1); + for _ in 0..iterations { + let replicate = d.sample(&mut rng); + for (i, q) in quantiles.iter().enumerate() { + if replicate < *q { + proportions[i] += increment; + } + } + } + assert!(proportions + .iter() + .zip(&probabilities) + .all(|(p_hat, p)| (p_hat - p).abs() < 0.003)) + } + + #[test] + fn frechet_distributions_can_be_compared() { + assert_eq!(Frechet::new(1.0, 2.0, 3.0), Frechet::new(1.0, 2.0, 3.0)); + } +} diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 5e98dbdfcfc..0fc6b756df3 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -7,32 +7,39 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Gamma and derived distributions. +//! The Gamma distribution. -use self::ChiSquaredRepr::*; use self::GammaRepr::*; -use crate::normal::StandardNormal; +use crate::{Distribution, Exp, Exp1, Open01, StandardNormal}; +use core::fmt; use num_traits::Float; -use crate::{Distribution, Exp, Exp1, Open01}; use rand::Rng; -use core::fmt; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -/// The Gamma distribution `Gamma(shape, scale)` distribution. +/// The [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution) `Gamma(k, θ)`. /// -/// The density function of this distribution is +/// The Gamma distribution is a continuous probability distribution +/// with shape parameter `k > 0` (number of events) and +/// scale parameter `θ > 0` (mean waiting time between events). +/// It describes the time until `k` events occur in a Poisson +/// process with rate `1/θ`. It is the generalization of the +/// [`Exponential`](crate::Exp) distribution. /// -/// ```text -/// f(x) = x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k) -/// ``` +/// # Density function /// -/// where `Γ` is the Gamma function, `k` is the shape and `θ` is the -/// scale and both `k` and `θ` are strictly positive. +/// `f(x) = x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k)` for `x > 0`, +/// where `Γ` is the [gamma function](https://en.wikipedia.org/wiki/Gamma_function). /// -/// The algorithm used is that described by Marsaglia & Tsang 2000[^1], -/// falling back to directly sampling from an Exponential for `shape -/// == 1`, and using the boosting technique described in that paper for -/// `shape < 1`. +/// # Plot +/// +/// The following plot illustrates the Gamma distribution with +/// various values of `k` and `θ`. +/// Curves with `θ = 1` are more saturated, while corresponding +/// curves with `θ = 2` have a lighter color. +/// +/// ![Gamma distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gamma.svg) /// /// # Example /// @@ -40,15 +47,23 @@ use core::fmt; /// use rand_distr::{Distribution, Gamma}; /// /// let gamma = Gamma::new(2.0, 5.0).unwrap(); -/// let v = gamma.sample(&mut rand::thread_rng()); +/// let v = gamma.sample(&mut rand::rng()); /// println!("{} is from a Gamma(2, 5) distribution", v); /// ``` /// +/// # Notes +/// +/// The algorithm used is that described by Marsaglia & Tsang 2000[^1], +/// falling back to directly sampling from an Exponential for `shape +/// == 1`, and using the boosting technique described in that paper for +/// `shape < 1`. +/// /// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for /// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3 /// (September 2000), 363-372. /// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Gamma where F: Float, @@ -59,7 +74,7 @@ where repr: GammaRepr, } -/// Error type returned from `Gamma::new`. +/// Error type returned from [`Gamma::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `shape <= 0` or `nan`. @@ -83,7 +98,8 @@ impl fmt::Display for Error { #[cfg(feature = "std")] impl std::error::Error for Error {} -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] enum GammaRepr where F: Float, @@ -110,7 +126,8 @@ where /// /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] struct GammaSmallShape where F: Float, @@ -125,7 +142,8 @@ where /// /// See `Gamma` for sampling from a Gamma distribution with general /// shape parameters. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] struct GammaLargeShape where F: Float, @@ -252,538 +270,12 @@ where } } -/// The chi-squared distribution `χ²(k)`, where `k` is the degrees of -/// freedom. -/// -/// For `k > 0` integral, this distribution is the sum of the squares -/// of `k` independent standard normal random variables. For other -/// `k`, this uses the equivalent characterisation -/// `χ²(k) = Gamma(k/2, 2)`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{ChiSquared, Distribution}; -/// -/// let chi = ChiSquared::new(11.0).unwrap(); -/// let v = chi.sample(&mut rand::thread_rng()); -/// println!("{} is from a χ²(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug)] -pub struct ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: ChiSquaredRepr, -} - -/// Error type returned from `ChiSquared::new` and `StudentT::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ChiSquaredError { - /// `0.5 * k <= 0` or `nan`. - DoFTooSmall, -} - -impl fmt::Display for ChiSquaredError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ChiSquaredError::DoFTooSmall => { - "degrees-of-freedom k is not positive in chi-squared distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for ChiSquaredError {} - -#[derive(Clone, Copy, Debug)] -enum ChiSquaredRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, - // e.g. when alpha = 1/2 as it would be for this case, so special- - // casing and using the definition of N(0,1)^2 is faster. - DoFExactlyOne, - DoFAnythingElse(Gamma), -} - -impl ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new chi-squared distribution with degrees-of-freedom - /// `k`. - pub fn new(k: F) -> Result, ChiSquaredError> { - let repr = if k == F::one() { - DoFExactlyOne - } else { - if !(F::from(0.5).unwrap() * k > F::zero()) { - return Err(ChiSquaredError::DoFTooSmall); - } - DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) - }; - Ok(ChiSquared { repr }) - } -} -impl Distribution for ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - DoFExactlyOne => { - // k == 1 => N(0,1)^2 - let norm: F = rng.sample(StandardNormal); - norm * norm - } - DoFAnythingElse(ref g) => g.sample(rng), - } - } -} - -/// The Fisher F distribution `F(m, n)`. -/// -/// This distribution is equivalent to the ratio of two normalised -/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / -/// (χ²(n)/n)`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{FisherF, Distribution}; -/// -/// let f = FisherF::new(2.0, 32.0).unwrap(); -/// let v = f.sample(&mut rand::thread_rng()); -/// println!("{} is from an F(2, 32) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug)] -pub struct FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - numer: ChiSquared, - denom: ChiSquared, - // denom_dof / numer_dof so that this can just be a straight - // multiplication, rather than a division. - dof_ratio: F, -} - -/// Error type returned from `FisherF::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum FisherFError { - /// `m <= 0` or `nan`. - MTooSmall, - /// `n <= 0` or `nan`. - NTooSmall, -} - -impl fmt::Display for FisherFError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - FisherFError::MTooSmall => "m is not positive in Fisher F distribution", - FisherFError::NTooSmall => "n is not positive in Fisher F distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for FisherFError {} - -impl FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: F, n: F) -> Result, FisherFError> { - let zero = F::zero(); - if !(m > zero) { - return Err(FisherFError::MTooSmall); - } - if !(n > zero) { - return Err(FisherFError::NTooSmall); - } - - Ok(FisherF { - numer: ChiSquared::new(m).unwrap(), - denom: ChiSquared::new(n).unwrap(), - dof_ratio: n / m, - }) - } -} -impl Distribution for FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio - } -} - -/// The Student t distribution, `t(nu)`, where `nu` is the degrees of -/// freedom. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{StudentT, Distribution}; -/// -/// let t = StudentT::new(11.0).unwrap(); -/// let v = t.sample(&mut rand::thread_rng()); -/// println!("{} is from a t(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug)] -pub struct StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - chi: ChiSquared, - dof: F, -} - -impl StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new Student t distribution with `n` degrees of - /// freedom. - pub fn new(n: F) -> Result, ChiSquaredError> { - Ok(StudentT { - chi: ChiSquared::new(n)?, - dof: n, - }) - } -} -impl Distribution for StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let norm: F = rng.sample(StandardNormal); - norm * (self.dof / self.chi.sample(rng)).sqrt() - } -} - -/// The algorithm used for sampling the Beta distribution. -/// -/// Reference: -/// -/// R. C. H. Cheng (1978). -/// Generating beta variates with nonintegral shape parameters. -/// Communications of the ACM 21, 317-322. -/// https://doi.org/10.1145/359460.359482 -#[derive(Clone, Copy, Debug)] -enum BetaAlgorithm { - BB(BB), - BC(BC), -} - -/// Algorithm BB for `min(alpha, beta) > 1`. -#[derive(Clone, Copy, Debug)] -struct BB { - alpha: N, - beta: N, - gamma: N, -} - -/// Algorithm BC for `min(alpha, beta) <= 1`. -#[derive(Clone, Copy, Debug)] -struct BC { - alpha: N, - beta: N, - delta: N, - kappa1: N, - kappa2: N, -} - -/// The Beta distribution with shape parameters `alpha` and `beta`. -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Beta}; -/// -/// let beta = Beta::new(2.0, 5.0).unwrap(); -/// let v = beta.sample(&mut rand::thread_rng()); -/// println!("{} is from a Beta(2, 5) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug)] -pub struct Beta -where - F: Float, - Open01: Distribution, -{ - a: F, b: F, switched_params: bool, - algorithm: BetaAlgorithm, -} - -/// Error type returned from `Beta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum BetaError { - /// `alpha <= 0` or `nan`. - AlphaTooSmall, - /// `beta <= 0` or `nan`. - BetaTooSmall, -} - -impl fmt::Display for BetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - BetaError::AlphaTooSmall => "alpha is not positive in beta distribution", - BetaError::BetaTooSmall => "beta is not positive in beta distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for BetaError {} - -impl Beta -where - F: Float, - Open01: Distribution, -{ - /// Construct an object representing the `Beta(alpha, beta)` - /// distribution. - pub fn new(alpha: F, beta: F) -> Result, BetaError> { - if !(alpha > F::zero()) { - return Err(BetaError::AlphaTooSmall); - } - if !(beta > F::zero()) { - return Err(BetaError::BetaTooSmall); - } - // From now on, we use the notation from the reference, - // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. - let (a0, b0) = (alpha, beta); - let (a, b, switched_params) = if a0 < b0 { - (a0, b0, false) - } else { - (b0, a0, true) - }; - if a > F::one() { - // Algorithm BB - let alpha = a + b; - let beta = ((alpha - F::from(2.).unwrap()) - / (F::from(2.).unwrap()*a*b - alpha)).sqrt(); - let gamma = a + F::one() / beta; - - Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BB(BB { - alpha, beta, gamma, - }) - }) - } else { - // Algorithm BC - // - // Here `a` is the maximum instead of the minimum. - let (a, b, switched_params) = (b, a, !switched_params); - let alpha = a + b; - let beta = F::one() / b; - let delta = F::one() + a - b; - let kappa1 = delta - * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b) - / (a*beta - F::from(14. / 18.).unwrap()); - let kappa2 = F::from(0.25).unwrap() - + (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b; - - Ok(Beta { - a, b, switched_params, - algorithm: BetaAlgorithm::BC(BC { - alpha, beta, delta, kappa1, kappa2, - }) - }) - } - } -} - -impl Distribution for Beta -where - F: Float, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let mut w; - match self.algorithm { - BetaAlgorithm::BB(algo) => { - loop { - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - let z = u1*u1 * u2; - let r = algo.gamma * v - F::from(4.).unwrap().ln(); - let s = self.a + r - w; - // 2. - if s + F::one() + F::from(5.).unwrap().ln() - >= F::from(5.).unwrap() * z { - break; - } - // 3. - let t = z.ln(); - if s >= t { - break; - } - // 4. - if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { - break; - } - } - }, - BetaAlgorithm::BC(algo) => { - loop { - let z; - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - if u1 < F::from(0.5).unwrap() { - // 2. - let y = u1 * u2; - z = u1 * y; - if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { - continue; - } - } else { - // 3. - z = u1 * u1 * u2; - if z <= F::from(0.25).unwrap() { - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - break; - } - // 4. - if z >= algo.kappa2 { - continue; - } - } - // 5. - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) - - F::from(4.).unwrap().ln() < z.ln()) { - break; - }; - } - }, - }; - // 5. for BB, 6. for BC - if !self.switched_params { - if w == F::infinity() { - // Assuming `b` is finite, for large `w`: - return F::one(); - } - w / (self.b + w) - } else { - self.b / (self.b + w) - } - } -} - #[cfg(test)] mod test { use super::*; #[test] - fn test_chi_squared_one() { - let chi = ChiSquared::new(1.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_small() { - let chi = ChiSquared::new(0.5).unwrap(); - let mut rng = crate::test::rng(202); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_large() { - let chi = ChiSquared::new(30.0).unwrap(); - let mut rng = crate::test::rng(203); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - #[should_panic] - fn test_chi_squared_invalid_dof() { - ChiSquared::new(-1.0).unwrap(); - } - - #[test] - fn test_f() { - let f = FisherF::new(2.0, 32.0).unwrap(); - let mut rng = crate::test::rng(204); - for _ in 0..1000 { - f.sample(&mut rng); - } - } - - #[test] - fn test_t() { - let t = StudentT::new(11.0).unwrap(); - let mut rng = crate::test::rng(205); - for _ in 0..1000 { - t.sample(&mut rng); - } - } - - #[test] - fn test_beta() { - let beta = Beta::new(1.0, 2.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - beta.sample(&mut rng); - } - } - - #[test] - #[should_panic] - fn test_beta_invalid_dof() { - Beta::new(0., 0.).unwrap(); - } - - #[test] - fn test_beta_small_param() { - let beta = Beta::::new(1e-3, 1e-3).unwrap(); - let mut rng = crate::test::rng(206); - for i in 0..1000 { - assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); - } + fn gamma_distributions_can_be_compared() { + assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); } } diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs index c51168af0c2..74d30a4459a 100644 --- a/rand_distr/src/geometric.rs +++ b/rand_distr/src/geometric.rs @@ -1,39 +1,52 @@ -//! The geometric distribution. +//! The geometric distribution `Geometric(p)`. use crate::Distribution; -use rand::Rng; use core::fmt; +#[allow(unused_imports)] +use num_traits::Float; +use rand::Rng; -/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`. -/// -/// This is the probability distribution of the number of failures before the -/// first success in a series of Bernoulli trials. It has the density function -/// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success -/// on each trial. -/// +/// The [geometric distribution](https://en.wikipedia.org/wiki/Geometric_distribution) `Geometric(p)`. +/// +/// This is the probability distribution of the number of failures +/// (bounded to `[0, u64::MAX]`) before the first success in a +/// series of [`Bernoulli`](crate::Bernoulli) trials, where the +/// probability of success on each trial is `p`. +/// /// This is the discrete analogue of the [exponential distribution](crate::Exp). -/// -/// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised +/// +/// See [`StandardGeometric`](crate::StandardGeometric) for an optimised /// implementation for `p = 0.5`. /// -/// # Example +/// # Density function +/// +/// `f(k) = (1 - p)^k p` for `k >= 0`. +/// +/// # Plot /// +/// The following plot illustrates the geometric distribution for various +/// values of `p`. Note how higher `p` values shift the distribution to +/// the left, and the mean of the distribution is `1/p`. +/// +/// ![Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/geometric.svg) +/// +/// # Example /// ``` /// use rand_distr::{Geometric, Distribution}; /// /// let geo = Geometric::new(0.25).unwrap(); -/// let v = geo.sample(&mut rand::thread_rng()); +/// let v = geo.sample(&mut rand::rng()); /// println!("{} is from a Geometric(0.25) distribution", v); /// ``` -#[derive(Copy, Clone, Debug)] -pub struct Geometric -{ +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Geometric { p: f64, pi: f64, - k: u64 + k: u64, } -/// Error type returned from `Geometric::new`. +/// Error type returned from [`Geometric::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `p < 0 || p > 1` or `nan` @@ -43,7 +56,9 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution", + Error::InvalidProbability => { + "p is NaN or outside the interval [0, 1] in geometric distribution" + } }) } } @@ -53,9 +68,9 @@ impl std::error::Error for Error {} impl Geometric { /// Construct a new `Geometric` with the given shape parameter `p` - /// (probablity of success on each trial). + /// (probability of success on each trial). pub fn new(p: f64) -> Result { - if !p.is_finite() || p < 0.0 || p > 1.0 { + if !p.is_finite() || !(0.0..=1.0).contains(&p) { Err(Error::InvalidProbability) } else if p == 0.0 || p >= 2.0 / 3.0 { Ok(Geometric { p, pi: p, k: 0 }) @@ -76,21 +91,24 @@ impl Geometric { } } -impl Distribution for Geometric -{ +impl Distribution for Geometric { fn sample(&self, rng: &mut R) -> u64 { if self.p >= 2.0 / 3.0 { // use the trivial algorithm: let mut failures = 0; loop { - let u = rng.gen::(); - if u <= self.p { break; } + let u = rng.random::(); + if u <= self.p { + break; + } failures += 1; } return failures; } - - if self.p == 0.0 { return core::u64::MAX; } + + if self.p == 0.0 { + return u64::MAX; + } let Geometric { p, pi, k } = *self; @@ -104,7 +122,7 @@ impl Distribution for Geometric // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k: let d = { let mut failures = 0; - while rng.gen::() < pi { + while rng.random::() < pi { failures += 1; } failures @@ -112,18 +130,18 @@ impl Distribution for Geometric // Use rejection sampling for the remainder M from Geo(p) % 2^k: // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M - // NOTE: The paper suggests using bitwise sampling here, which is + // NOTE: The paper suggests using bitwise sampling here, which is // currently unsupported, but should improve performance by requiring // fewer iterations on average. ~ October 28, 2020 let m = loop { - let m = rng.gen::() & ((1 << k) - 1); - let p_reject = if m <= core::i32::MAX as u64 { + let m = rng.random::() & ((1 << k) - 1); + let p_reject = if m <= i32::MAX as u64 { (1.0 - p).powi(m as i32) } else { (1.0 - p).powf(m as f64) }; - - let u = rng.gen::(); + + let u = rng.random::(); if u < p_reject { break m; } @@ -133,32 +151,43 @@ impl Distribution for Geometric } } -/// Samples integers according to the geometric distribution with success -/// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`, -/// but faster. -/// +/// The standard geometric distribution `Geometric(0.5)`. +/// +/// This is equivalent to `Geometric::new(0.5)`, but faster. +/// /// See [`Geometric`](crate::Geometric) for the general geometric distribution. -/// -/// Implemented via iterated [Rng::gen::().leading_zeros()]. -/// +/// +/// # Plot +/// +/// The following plot illustrates the standard geometric distribution. +/// +/// ![Standard Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_geometric.svg) +/// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::StandardGeometric; -/// -/// let v = StandardGeometric.sample(&mut thread_rng()); +/// +/// let v = StandardGeometric.sample(&mut rand::rng()); /// println!("{} is from a Geometric(0.5) distribution", v); /// ``` +/// +/// # Notes +/// Implemented via iterated +/// [`Rng::gen::().leading_zeros()`](Rng::gen::().leading_zeros()). #[derive(Copy, Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct StandardGeometric; impl Distribution for StandardGeometric { fn sample(&self, rng: &mut R) -> u64 { let mut result = 0; loop { - let x = rng.gen::().leading_zeros() as u64; + let x = rng.random::().leading_zeros() as u64; result += x; - if x < 64 { break; } + if x < 64 { + break; + } } result } @@ -170,9 +199,9 @@ mod test { #[test] fn test_geo_invalid_p() { - assert!(Geometric::new(core::f64::NAN).is_err()); - assert!(Geometric::new(core::f64::INFINITY).is_err()); - assert!(Geometric::new(core::f64::NEG_INFINITY).is_err()); + assert!(Geometric::new(f64::NAN).is_err()); + assert!(Geometric::new(f64::INFINITY).is_err()); + assert!(Geometric::new(f64::NEG_INFINITY).is_err()); assert!(Geometric::new(-0.5).is_err()); assert!(Geometric::new(0.0).is_ok()); @@ -192,7 +221,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 40.0); + assert!((mean - expected_mean).abs() < expected_mean / 40.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; @@ -224,10 +253,15 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean - expected_mean).abs() < expected_mean / 50.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; assert!((variance - expected_variance).abs() < expected_variance / 10.0); } -} \ No newline at end of file + + #[test] + fn geometric_distributions_can_be_compared() { + assert_eq!(Geometric::new(1.0), Geometric::new(1.0)); + } +} diff --git a/rand_distr/src/gumbel.rs b/rand_distr/src/gumbel.rs new file mode 100644 index 00000000000..f420a52df84 --- /dev/null +++ b/rand_distr/src/gumbel.rs @@ -0,0 +1,173 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Gumbel distribution `Gumbel(μ, β)`. + +use crate::{Distribution, OpenClosed01}; +use core::fmt; +use num_traits::Float; +use rand::Rng; + +/// The [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution) `Gumbel(μ, β)`. +/// +/// The Gumbel distribution is a continuous probability distribution +/// with location parameter `μ` (`mu`) and scale parameter `β` (`beta`). +/// It is used to model the distribution of the maximum (or minimum) +/// of a number of samples of various distributions. +/// +/// # Density function +/// +/// `f(x) = exp(-(z + exp(-z))) / β`, where `z = (x - μ) / β`. +/// +/// # Plot +/// +/// The following plot illustrates the Gumbel distribution with various values of `μ` and `β`. +/// Note how the location parameter `μ` shifts the distribution along the x-axis, +/// and the scale parameter `β` changes the density around `μ`. +/// Note also the asymptotic behavior of the distribution towards the right. +/// +/// ![Gumbel distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gumbel.svg) +/// +/// # Example +/// ``` +/// use rand::prelude::*; +/// use rand_distr::Gumbel; +/// +/// let val: f64 = rand::rng().sample(Gumbel::new(0.0, 1.0).unwrap()); +/// println!("{}", val); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Gumbel +where + F: Float, + OpenClosed01: Distribution, +{ + location: F, + scale: F, +} + +/// Error type returned from [`Gumbel::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// location is infinite or NaN + LocationNotFinite, + /// scale is not finite positive number + ScaleNotPositive, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::ScaleNotPositive => "scale is not positive and finite in Gumbel distribution", + Error::LocationNotFinite => "location is not finite in Gumbel distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Gumbel +where + F: Float, + OpenClosed01: Distribution, +{ + /// Construct a new `Gumbel` distribution with given `location` and `scale`. + pub fn new(location: F, scale: F) -> Result, Error> { + if scale <= F::zero() || scale.is_infinite() || scale.is_nan() { + return Err(Error::ScaleNotPositive); + } + if location.is_infinite() || location.is_nan() { + return Err(Error::LocationNotFinite); + } + Ok(Gumbel { location, scale }) + } +} + +impl Distribution for Gumbel +where + F: Float, + OpenClosed01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let x: F = rng.sample(OpenClosed01); + self.location - self.scale * (-x.ln()).ln() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic] + fn test_zero_scale() { + Gumbel::new(0.0, 0.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_infinite_scale() { + Gumbel::new(0.0, f64::INFINITY).unwrap(); + } + + #[test] + #[should_panic] + fn test_nan_scale() { + Gumbel::new(0.0, f64::NAN).unwrap(); + } + + #[test] + #[should_panic] + fn test_infinite_location() { + Gumbel::new(f64::INFINITY, 1.0).unwrap(); + } + + #[test] + #[should_panic] + fn test_nan_location() { + Gumbel::new(f64::NAN, 1.0).unwrap(); + } + + #[test] + fn test_sample_against_cdf() { + fn neg_log_log(x: f64) -> f64 { + -(-x.ln()).ln() + } + let location = 0.0; + let scale = 1.0; + let iterations = 100_000; + let increment = 1.0 / iterations as f64; + let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; + let mut quantiles = [0.0; 9]; + for (i, p) in probabilities.iter().enumerate() { + quantiles[i] = neg_log_log(*p); + } + let mut proportions = [0.0; 9]; + let d = Gumbel::new(location, scale).unwrap(); + let mut rng = crate::test::rng(1); + for _ in 0..iterations { + let replicate = d.sample(&mut rng); + for (i, q) in quantiles.iter().enumerate() { + if replicate < *q { + proportions[i] += increment; + } + } + } + assert!(proportions + .iter() + .zip(&probabilities) + .all(|(p_hat, p)| (p_hat - p).abs() < 0.003)) + } + + #[test] + fn gumbel_distributions_can_be_compared() { + assert_eq!(Gumbel::new(1.0, 2.0), Gumbel::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs index 406e97ea88a..f446357530b 100644 --- a/rand_distr/src/hypergeometric.rs +++ b/rand_distr/src/hypergeometric.rs @@ -1,14 +1,20 @@ -//! The hypergeometric distribution. +//! The hypergeometric distribution `Hypergeometric(N, K, n)`. use crate::Distribution; -use rand::Rng; -use rand::distributions::uniform::Uniform; use core::fmt; +#[allow(unused_imports)] +use num_traits::Float; +use rand::distr::uniform::Uniform; +use rand::Rng; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] enum SamplingMethod { - InverseTransform{ initial_p: f64, initial_x: i64 }, - RejectionAcceptance{ + InverseTransform { + initial_p: f64, + initial_x: i64, + }, + RejectionAcceptance { m: f64, a: f64, lambda_l: f64, @@ -17,32 +23,42 @@ enum SamplingMethod { x_r: f64, p1: f64, p2: f64, - p3: f64 + p3: f64, }, } -/// The hypergeometric distribution `Hypergeometric(N, K, n)`. -/// +/// The [hypergeometric distribution](https://en.wikipedia.org/wiki/Hypergeometric_distribution) `Hypergeometric(N, K, n)`. +/// /// This is the distribution of successes in samples of size `n` drawn without /// replacement from a population of size `N` containing `K` success states. -/// It has the density function: -/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, -/// where `binomial(a, b) = a! / (b! * (a - b)!)`. -/// -/// The [binomial distribution](crate::Binomial) is the analagous distribution +/// +/// See the [binomial distribution](crate::Binomial) for the analogous distribution /// for sampling with replacement. It is a good approximation when the population /// size is much larger than the sample size. -/// +/// +/// # Density function +/// +/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, +/// where `binomial(a, b) = a! / (b! * (a - b)!)`. +/// +/// # Plot +/// +/// The following plot of the hypergeometric distribution illustrates the probability of drawing +/// `k` successes in `n = 10` draws from a population of `N = 50` items, of which either `K = 12` +/// or `K = 35` are successes. +/// +/// ![Hypergeometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/hypergeometric.svg) +/// /// # Example -/// /// ``` /// use rand_distr::{Distribution, Hypergeometric}; /// /// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap(); -/// let v = hypergeo.sample(&mut rand::thread_rng()); +/// let v = hypergeo.sample(&mut rand::rng()); /// println!("{} is from a hypergeometric distribution", v); /// ``` -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Hypergeometric { n1: u64, n2: u64, @@ -52,7 +68,7 @@ pub struct Hypergeometric { sampling_method: SamplingMethod, } -/// Error type returned from `Hypergeometric::new`. +/// Error type returned from [`Hypergeometric::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `total_population_size` is too large, causing floating point underflow. @@ -66,13 +82,22 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution", - Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution", - Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution", + Error::PopulationTooLarge => { + "total_population_size is too large causing underflow in geometric distribution" + } + Error::ProbabilityTooLarge => { + "population_with_feature > total_population_size in geometric distribution" + } + Error::SampleSizeTooLarge => { + "sample_size > total_population_size in geometric distribution" + } }) } } +#[cfg(feature = "std")] +impl std::error::Error for Error {} + // evaluate fact(numerator.0)*fact(numerator.1) / fact(denominator.0)*fact(denominator.1) fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 { let min_top = u64::min(numerator.0, numerator.1); @@ -89,27 +114,34 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, if i <= min_top { result *= i as f64; } - + if i <= min_bottom { result /= i as f64; } - + if i <= max_top { result *= i as f64; } - + if i <= max_bottom { result /= i as f64; } } - + result } +const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi)) + fn ln_of_factorial(v: f64) -> f64 { // the paper calls for ln(v!), but also wants to pass in fractions, // so we need to use Stirling's approximation to fill in the gaps: - v * v.ln() - v + + // shift v by 3, because Stirling is bad for small values + let v_3 = v + 3.0; + let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3); + // make the correction for the shift + ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln() } impl Hypergeometric { @@ -118,7 +150,11 @@ impl Hypergeometric { /// `K = population_with_feature`, /// `n = sample_size`. #[allow(clippy::many_single_char_names)] // Same names as in the reference. - pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result { + pub fn new( + total_population_size: u64, + population_with_feature: u64, + sample_size: u64, + ) -> Result { if population_with_feature > total_population_size { return Err(Error::ProbabilityTooLarge); } @@ -143,7 +179,7 @@ impl Hypergeometric { }; // when sampling more than half the total population, take the smaller // group as sampled instead (we can then return n1-x instead). - // + // // Note: the boundary condition given in the paper is `sample_size < n / 2`; // we're deviating here, because when n is even, it doesn't matter whether // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching @@ -159,7 +195,7 @@ impl Hypergeometric { // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`, // where `M` is the mode of the distribution. // Use algorithm HIN for the remaining parameter space. - // + // // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer // generation of hypergeometric random variates. // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145 @@ -168,21 +204,30 @@ impl Hypergeometric { let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor(); let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD { let (initial_p, initial_x) = if k < n2 { - (fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0) + ( + fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), + 0, + ) } else { - (fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64) + ( + fraction_of_products_of_factorials((n1, k), (n, k - n2)), + (k - n2) as i64, + ) }; if initial_p <= 0.0 || !initial_p.is_finite() { return Err(Error::PopulationTooLarge); } - SamplingMethod::InverseTransform { initial_p, initial_x } + SamplingMethod::InverseTransform { + initial_p, + initial_x, + } } else { - let a = ln_of_factorial(m) + - ln_of_factorial(n1 as f64 - m) + - ln_of_factorial(k as f64 - m) + - ln_of_factorial((n2 - k) as f64 + m); + let a = ln_of_factorial(m) + + ln_of_factorial(n1 as f64 - m) + + ln_of_factorial(k as f64 - m) + + ln_of_factorial((n2 - k) as f64 + m); let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64; let denominator = (n - 1) as f64 * n as f64 * n as f64; @@ -191,17 +236,19 @@ impl Hypergeometric { let x_l = m - d + 0.5; let x_r = m + d + 0.5; - let k_l = f64::exp(a - - ln_of_factorial(x_l) - - ln_of_factorial(n1 as f64 - x_l) - - ln_of_factorial(k as f64 - x_l) - - ln_of_factorial((n2 - k) as f64 + x_l)); - let k_r = f64::exp(a - - ln_of_factorial(x_r - 1.0) - - ln_of_factorial(n1 as f64 - x_r + 1.0) - - ln_of_factorial(k as f64 - x_r + 1.0) - - ln_of_factorial((n2 - k) as f64 + x_r - 1.0)); - + let k_l = f64::exp( + a - ln_of_factorial(x_l) + - ln_of_factorial(n1 as f64 - x_l) + - ln_of_factorial(k as f64 - x_l) + - ln_of_factorial((n2 - k) as f64 + x_l), + ); + let k_r = f64::exp( + a - ln_of_factorial(x_r - 1.0) + - ln_of_factorial(n1 as f64 - x_r + 1.0) + - ln_of_factorial(k as f64 - x_r + 1.0) + - ln_of_factorial((n2 - k) as f64 + x_r - 1.0), + ); + let numerator = x_l * ((n2 - k) as f64 + x_l); let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0); let lambda_l = -((numerator / denominator).ln()); @@ -217,11 +264,26 @@ impl Hypergeometric { let p3 = p2 + k_r / lambda_r; SamplingMethod::RejectionAcceptance { - m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 + m, + a, + lambda_l, + lambda_r, + x_l, + x_r, + p1, + p2, + p3, } }; - Ok(Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method }) + Ok(Hypergeometric { + n1, + n2, + k, + offset_x, + sign_x, + sampling_method, + }) } } @@ -230,25 +292,47 @@ impl Distribution for Hypergeometric { fn sample(&self, rng: &mut R) -> u64 { use SamplingMethod::*; - let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self; + let Hypergeometric { + n1, + n2, + k, + sign_x, + offset_x, + sampling_method, + } = *self; let x = match sampling_method { - InverseTransform { initial_p: mut p, initial_x: mut x } => { - let mut u = rng.gen::(); - while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense + InverseTransform { + initial_p: mut p, + initial_x: mut x, + } => { + let mut u = rng.random::(); + + // the paper erroneously uses `until n < p`, which doesn't make any sense + while u > p && x < k as i64 { u -= p; - p *= ((n1 as i64 - x as i64) * (k as i64 - x as i64)) as f64; - p /= ((x as i64 + 1) * (n2 as i64 - k as i64 + 1 + x as i64)) as f64; + p *= ((n1 as i64 - x) * (k as i64 - x)) as f64; + p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64; x += 1; } x - }, - RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => { - let distr_region_select = Uniform::new(0.0, p3); + } + RejectionAcceptance { + m, + a, + lambda_l, + lambda_r, + x_l, + x_r, + p1, + p2, + p3, + } => { + let distr_region_select = Uniform::new(0.0, p3).unwrap(); loop { let (y, v) = loop { let u = distr_region_select.sample(rng); - let v = rng.gen::(); // for the accept/reject decision - + let v = rng.random::(); // for the accept/reject decision + if u <= p1 { // Region 1, central bell let y = (x_l + u).floor(); @@ -269,7 +353,7 @@ impl Distribution for Hypergeometric { } } }; - + // Step 4: Acceptance/Rejection Comparison if m < 100.0 || y <= 50.0 { // Step 4.1: evaluate f(y) via recursive relationship @@ -282,11 +366,13 @@ impl Distribution for Hypergeometric { } else { for i in (y as u64 + 1)..=(m as u64) { f *= i as f64 * (n2 - k + i) as f64; - f /= (n1 - i) as f64 * (k - i) as f64; + f /= (n1 - i + 1) as f64 * (k - i + 1) as f64; } } - - if v <= f { break y as i64; } + + if v <= f { + break y as i64; + } } else { // Step 4.2: Squeezing let y1 = y + 1.0; @@ -299,24 +385,24 @@ impl Distribution for Hypergeometric { let t = ym / yk; let e = -ym / nk; let g = yn * yk / (y1 * nk) - 1.0; - let dg = if g < 0.0 { - 1.0 + g - } else { - 1.0 - }; + let dg = if g < 0.0 { 1.0 + g } else { 1.0 }; let gu = g * (1.0 + g * (-0.5 + g / 3.0)); let gl = gu - g.powi(4) / (4.0 * dg); let xm = m + 0.5; let xn = n1 as f64 - m + 0.5; let xk = k as f64 - m + 0.5; let nm = n2 as f64 - k as f64 + xm; - let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + - xn * s * (1.0 + s * (-0.5 + s / 3.0)) + - xk * t * (1.0 + t * (-0.5 + t / 3.0)) + - nm * e * (1.0 + e * (-0.5 + e / 3.0)) + - y * gu - m * gl + 0.0034; + let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + + xn * s * (1.0 + s * (-0.5 + s / 3.0)) + + xk * t * (1.0 + t * (-0.5 + t / 3.0)) + + nm * e * (1.0 + e * (-0.5 + e / 3.0)) + + y * gu + - m * gl + + 0.0034; let av = v.ln(); - if av > ub { continue; } + if av > ub { + continue; + } let dr = if r < 0.0 { xm * r.powi(4) / (1.0 + r) } else { @@ -337,17 +423,17 @@ impl Distribution for Hypergeometric { } else { nm * e.powi(4) }; - - if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 { + + if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 { break y as i64; } - + // Step 4.3: Final Acceptance/Rejection Test - let av_critical = a - - ln_of_factorial(y) - - ln_of_factorial(n1 as f64 - y) - - ln_of_factorial(k as f64 - y) - - ln_of_factorial((n2 - k) as f64 + y); + let av_critical = a + - ln_of_factorial(y) + - ln_of_factorial(n1 as f64 - y) + - ln_of_factorial(k as f64 - y) + - ln_of_factorial((n2 - k) as f64 + y); if v.ln() <= av_critical { break y as i64; } @@ -362,6 +448,7 @@ impl Distribution for Hypergeometric { #[cfg(test)] mod test { + use super::*; #[test] @@ -372,8 +459,7 @@ mod test { assert!(Hypergeometric::new(100, 10, 5).is_ok()); } - fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) - { + fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) { let distr = Hypergeometric::new(n, k, s).unwrap(); let expected_mean = s as f64 * k as f64 / n as f64; @@ -389,7 +475,7 @@ mod test { } let mean = results.iter().sum::() / results.len() as f64; - assert!((mean as f64 - expected_mean).abs() < expected_mean / 50.0); + assert!((mean - expected_mean).abs() < expected_mean / 50.0); let variance = results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; @@ -411,4 +497,18 @@ mod test { test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng); test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng); } + + #[test] + fn hypergeometric_distributions_can_be_compared() { + assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3)); + } + + #[test] + fn stirling() { + let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + for &v in test.iter() { + let ln_fac = ln_of_factorial(v); + assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4); + } + } } diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index 7af645a23c4..354c2e05986 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -1,9 +1,12 @@ -use crate::{Distribution, Standard, StandardNormal}; +//! The inverse Gaussian distribution `IG(μ, λ)`. + +use crate::{Distribution, StandardNormal, StandardUniform}; +use core::fmt; use num_traits::Float; use rand::Rng; -/// Error type returned from `InverseGaussian::new` -#[derive(Debug, PartialEq)] +/// Error type returned from [`InverseGaussian::new`] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Error { /// `mean <= 0` or `nan`. MeanNegativeOrNull, @@ -11,13 +14,46 @@ pub enum Error { ShapeNegativeOrNull, } -/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) -#[derive(Debug)] +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::MeanNegativeOrNull => "mean <= 0 or is NaN in inverse Gaussian distribution", + Error::ShapeNegativeOrNull => "shape <= 0 or is NaN in inverse Gaussian distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) `IG(μ, λ)`. +/// +/// This is a continuous probability distribution with mean parameter `μ` (`mu`) +/// and shape parameter `λ` (`lambda`), defined for `x > 0`. +/// It is also known as the Wald distribution. +/// +/// # Plot +/// +/// The following plot shows the inverse Gaussian distribution +/// with various values of `μ` and `λ`. +/// +/// ![Inverse Gaussian distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/inverse_gaussian.svg) +/// +/// # Example +/// ``` +/// use rand_distr::{InverseGaussian, Distribution}; +/// +/// let inv_gauss = InverseGaussian::new(1.0, 2.0).unwrap(); +/// let v = inv_gauss.sample(&mut rand::rng()); +/// println!("{} is from a inverse Gaussian(1, 2) distribution", v); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct InverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { mean: F, shape: F, @@ -27,7 +63,7 @@ impl InverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { /// Construct a new `InverseGaussian` distribution with the given mean and /// shape. @@ -49,11 +85,13 @@ impl Distribution for InverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { #[allow(clippy::many_single_char_names)] fn sample(&self, rng: &mut R) -> F - where R: Rng + ?Sized { + where + R: Rng + ?Sized, + { let mu = self.mean; let l = self.shape; @@ -64,7 +102,7 @@ where let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt()); - let u: F = rng.gen(); + let u: F = rng.random(); if u <= mu / (mu + x) { return x; @@ -94,4 +132,12 @@ mod tests { assert!(InverseGaussian::new(1.0, -1.0).is_err()); assert!(InverseGaussian::new(1.0, 1.0).is_ok()); } + + #[test] + fn inverse_gaussian_distributions_can_be_compared() { + assert_eq!( + InverseGaussian::new(1.0, 2.0), + InverseGaussian::new(1.0, 2.0) + ); + } } diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 0043bd9f62e..ef1109b7d6f 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -11,6 +11,7 @@ html_favicon_url = "https://www.rust-lang.org/favicon.ico", html_root_url = "https://rust-random.github.io/rand/" )] +#![forbid(unsafe_code)] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![allow( @@ -20,21 +21,22 @@ )] #![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose #![no_std] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] //! Generating random samples from probability distributions. //! //! ## Re-exports //! -//! This crate is a super-set of the [`rand::distributions`] module. See the -//! [`rand::distributions`] module documentation for an overview of the core +//! This crate is a super-set of the [`rand::distr`] module. See the +//! [`rand::distr`] module documentation for an overview of the core //! [`Distribution`] trait and implementations. //! //! The following are re-exported: //! -//! - The [`Distribution`] trait and [`DistIter`] helper type -//! - The [`Standard`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], -//! [`Open01`], [`Bernoulli`], and [`WeightedIndex`] distributions +//! - The [`Distribution`] trait and [`Iter`] helper type +//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], +//! [`Open01`], [`Bernoulli`] distributions +//! - The [`weighted`] module //! //! ## Distributions //! @@ -43,6 +45,7 @@ //! - Related to real-valued quantities that grow linearly //! (e.g. errors, offsets): //! - [`Normal`] distribution, and [`StandardNormal`] as a primitive +//! - [`SkewNormal`] distribution //! - [`Cauchy`] distribution //! - Related to Bernoulli trials (yes/no events, with a given probability): //! - [`Binomial`] distribution @@ -56,6 +59,10 @@ //! - [`Poisson`] distribution //! - [`Exp`]onential distribution, and [`Exp1`] as a primitive //! - [`Weibull`] distribution +//! - [`Gumbel`] distribution +//! - [`Frechet`] distribution +//! - [`Zeta`] distribution +//! - [`Zipf`] distribution //! - Gamma and derived distributions: //! - [`Gamma`] distribution //! - [`ChiSquared`] distribution @@ -70,8 +77,6 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution -//! - Alternative implementation for weighted index sampling -//! - [`WeightedAliasIndex`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution @@ -86,42 +91,48 @@ extern crate std; #[allow(unused)] use rand::Rng; -pub use rand::distributions::{ - uniform, Alphanumeric, Bernoulli, BernoulliError, DistIter, Distribution, Open01, OpenClosed01, - Standard, Uniform, +pub use rand::distr::{ + uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01, + StandardUniform, Uniform, }; +pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; +pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError}; #[cfg(feature = "alloc")] pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::exponential::{Error as ExpError, Exp, Exp1}; -pub use self::gamma::{ - Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError, - Gamma, StudentT, -}; +pub use self::fisher_f::{Error as FisherFError, FisherF}; +pub use self::frechet::{Error as FrechetError, Frechet}; +pub use self::gamma::{Error as GammaError, Gamma}; pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric}; +pub use self::gumbel::{Error as GumbelError, Gumbel}; pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric}; -pub use self::inverse_gaussian::{InverseGaussian, Error as InverseGaussianError}; +pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian}; pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; -pub use self::normal_inverse_gaussian::{NormalInverseGaussian, Error as NormalInverseGaussianError}; +pub use self::normal_inverse_gaussian::{ + Error as NormalInverseGaussianError, NormalInverseGaussian, +}; pub use self::pareto::{Error as ParetoError, Pareto}; -pub use self::pert::{Pert, PertError}; +pub use self::pert::{Pert, PertBuilder, PertError}; pub use self::poisson::{Error as PoissonError, Poisson}; +pub use self::skew_normal::{Error as SkewNormalError, SkewNormal}; pub use self::triangular::{Triangular, TriangularError}; pub use self::unit_ball::UnitBall; pub use self::unit_circle::UnitCircle; pub use self::unit_disc::UnitDisc; pub use self::unit_sphere::UnitSphere; pub use self::weibull::{Error as WeibullError, Weibull}; -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub use rand::distributions::{WeightedError, WeightedIndex}; -#[cfg(feature = "alloc")] -pub use weighted_alias::WeightedAliasIndex; +pub use self::zeta::{Error as ZetaError, Zeta}; +pub use self::zipf::{Error as ZipfError, Zipf}; +pub use student_t::StudentT; pub use num_traits; +#[cfg(feature = "alloc")] +pub mod weighted; + #[cfg(test)] #[macro_use] mod test { @@ -160,34 +171,39 @@ mod test { macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); - if diff > $prec { - panic!( - "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ - (left: `{}`, right: `{}`)", - diff, $prec, $a, $b - ); - } + assert!( + diff <= $prec, + "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ + (left: `{}`, right: `{}`)", + diff, + $prec, + $a, + $b + ); }; } } -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted_alias; - +mod beta; mod binomial; mod cauchy; +mod chi_squared; mod dirichlet; mod exponential; +mod fisher_f; +mod frechet; mod gamma; mod geometric; +mod gumbel; mod hypergeometric; mod inverse_gaussian; mod normal; mod normal_inverse_gaussian; mod pareto; mod pert; -mod poisson; +pub(crate) mod poisson; +mod skew_normal; +mod student_t; mod triangular; mod unit_ball; mod unit_circle; @@ -195,5 +211,6 @@ mod unit_disc; mod unit_sphere; mod utils; mod weibull; +mod zeta; mod ziggurat_tables; - +mod zipf; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index 8c3c8f8fd2f..330c1ec2d6f 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -7,36 +7,45 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The normal and derived distributions. +//! The Normal and derived distributions. use crate::utils::ziggurat; -use num_traits::Float; use crate::{ziggurat_tables, Distribution, Open01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the normal distribution -/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to -/// `Normal::new(0.0, 1.0)` but faster. +/// The standard Normal distribution `N(0, 1)`. /// -/// See `Normal` for the general normal distribution. +/// This is equivalent to `Normal::new(0.0, 1.0)`, but faster. /// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. +/// See [`Normal`](crate::Normal) for the general Normal distribution. /// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford +/// # Plot +/// +/// The following diagram shows the standard Normal distribution. +/// +/// ![Standard Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_normal.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::StandardNormal; /// -/// let val: f64 = thread_rng().sample(StandardNormal); +/// let val: f64 = rand::rng().sample(StandardNormal); /// println!("{}", val); /// ``` +/// +/// # Notes +/// +/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. +/// +/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to +/// Generate Normal Random Samples*]( +/// https://www.doornik.com/research/ziggurat.pdf). +/// Nuffield College, Oxford #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct StandardNormal; impl Distribution for StandardNormal { @@ -91,13 +100,28 @@ impl Distribution for StandardNormal { } } -/// The normal distribution `N(mean, std_dev**2)`. +/// The [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) `N(μ, σ²)`. +/// +/// The Normal distribution, also known as the Gaussian distribution or +/// bell curve, is a continuous probability distribution with mean +/// `μ` (`mu`) and standard deviation `σ` (`sigma`). +/// It is used to model continuous data that tend to cluster around a mean. +/// The Normal distribution is symmetric and characterized by its bell-shaped curve. +/// +/// See [`StandardNormal`](crate::StandardNormal) for an +/// optimised implementation for `μ = 0` and `σ = 1`. +/// +/// # Density function +/// +/// `f(x) = (1 / sqrt(2π σ²)) * exp(-((x - μ)² / (2σ²)))` +/// +/// # Plot /// -/// This uses the ZIGNOR variant of the Ziggurat method, see [`StandardNormal`] -/// for more details. +/// The following diagram shows the Normal distribution with various values of `μ` +/// and `σ`. +/// The blue curve is the [`StandardNormal`](crate::StandardNormal) distribution, `N(0, 1)`. /// -/// Note that [`StandardNormal`] is an optimised implementation for mean 0, and -/// standard deviation 1. +/// ![Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal.svg) /// /// # Example /// @@ -106,20 +130,30 @@ impl Distribution for StandardNormal { /// /// // mean 2, standard deviation 3 /// let normal = Normal::new(2.0, 3.0).unwrap(); -/// let v = normal.sample(&mut rand::thread_rng()); +/// let v = normal.sample(&mut rand::rng()); /// println!("{} is from a N(2, 9) distribution", v) /// ``` /// -/// [`StandardNormal`]: crate::StandardNormal -#[derive(Clone, Copy, Debug)] +/// # Notes +/// +/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. +/// +/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to +/// Generate Normal Random Samples*]( +/// https://www.doornik.com/research/ziggurat.pdf). +/// Nuffield College, Oxford +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { mean: F, std_dev: F, } -/// Error type returned from `Normal::new` and `LogNormal::new`. +/// Error type returned from [`Normal::new`] and [`LogNormal::new`](crate::LogNormal::new). #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// The mean value is too small (log-normal samples must be positive) @@ -141,7 +175,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { /// Construct, from mean and standard deviation /// @@ -179,7 +215,7 @@ where F: Float, StandardNormal: Distribution /// ``` /// # use rand::prelude::*; /// # use rand_distr::{Normal, StandardNormal}; - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// let z = StandardNormal.sample(&mut rng); /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z); /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z); @@ -188,21 +224,40 @@ where F: Float, StandardNormal: Distribution pub fn from_zscore(&self, zscore: F) -> F { self.mean + self.std_dev * zscore } + + /// Returns the mean (`μ`) of the distribution. + pub fn mean(&self) -> F { + self.mean + } + + /// Returns the standard deviation (`σ`) of the distribution. + pub fn std_dev(&self) -> F { + self.std_dev + } } impl Distribution for Normal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { fn sample(&self, rng: &mut R) -> F { self.from_zscore(rng.sample(StandardNormal)) } } - -/// The log-normal distribution `ln N(mean, std_dev**2)`. +/// The [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) `ln N(μ, σ²)`. +/// +/// This is the distribution of the random variable `X = exp(Y)` where `Y` is +/// normally distributed with mean `μ` and variance `σ²`. In other words, if +/// `X` is log-normal distributed, then `ln(X)` is `N(μ, σ²)` distributed. +/// +/// # Plot /// -/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)` -/// distributed. +/// The following diagram shows the log-normal distribution with various values +/// of `μ` and `σ`. +/// +/// ![Log-normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/log_normal.svg) /// /// # Example /// @@ -211,18 +266,23 @@ where F: Float, StandardNormal: Distribution /// /// // mean 2, standard deviation 3 /// let log_normal = LogNormal::new(2.0, 3.0).unwrap(); -/// let v = log_normal.sample(&mut rand::thread_rng()); +/// let v = log_normal.sample(&mut rand::rng()); /// println!("{} is from an ln N(2, 9) distribution", v) /// ``` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { norm: Normal, } impl LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { /// Construct, from (log-space) mean and standard deviation /// @@ -281,7 +341,7 @@ where F: Float, StandardNormal: Distribution /// ``` /// # use rand::prelude::*; /// # use rand_distr::{LogNormal, StandardNormal}; - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// let z = StandardNormal.sample(&mut rng); /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z); /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z); @@ -293,7 +353,9 @@ where F: Float, StandardNormal: Distribution } impl Distribution for LogNormal -where F: Float, StandardNormal: Distribution +where + F: Float, + StandardNormal: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { @@ -334,7 +396,10 @@ mod tests { #[test] fn test_log_normal_cv() { let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap(); - assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (-core::f64::INFINITY, 0.0)); + assert_eq!( + (lnorm.norm.mean, lnorm.norm.std_dev), + (f64::NEG_INFINITY, 0.0) + ); let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap(); assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0)); @@ -354,4 +419,14 @@ mod tests { assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err()); assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err()); } + + #[test] + fn normal_distributions_can_be_compared() { + assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0)); + } + + #[test] + fn log_normal_distributions_can_be_compared() { + assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0)); + } } diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index 252a319d877..6ad2e58fe65 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -1,9 +1,10 @@ -use crate::{Distribution, InverseGaussian, Standard, StandardNormal}; +use crate::{Distribution, InverseGaussian, StandardNormal, StandardUniform}; +use core::fmt; use num_traits::Float; use rand::Rng; -/// Error type returned from `NormalInverseGaussian::new` -#[derive(Debug, PartialEq)] +/// Error type returned from [`NormalInverseGaussian::new`] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Error { /// `alpha <= 0` or `nan`. AlphaNegativeOrNull, @@ -11,15 +12,50 @@ pub enum Error { AbsoluteBetaNotLessThanAlpha, } -/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) -#[derive(Debug)] +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::AlphaNegativeOrNull => { + "alpha <= 0 or is NaN in normal inverse Gaussian distribution" + } + Error::AbsoluteBetaNotLessThanAlpha => { + "|beta| >= alpha or is NaN in normal inverse Gaussian distribution" + } + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) `NIG(α, β)`. +/// +/// This is a continuous probability distribution with two parameters, +/// `α` (`alpha`) and `β` (`beta`), defined in `(-∞, ∞)`. +/// It is also known as the normal-Wald distribution. +/// +/// # Plot +/// +/// The following plot shows the normal-inverse Gaussian distribution with various values of `α` and `β`. +/// +/// ![Normal-inverse Gaussian distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal_inverse_gaussian.svg) +/// +/// # Example +/// ``` +/// use rand_distr::{NormalInverseGaussian, Distribution}; +/// +/// let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); +/// let v = norm_inv_gauss.sample(&mut rand::rng()); +/// println!("{} is from a normal-inverse Gaussian(2, 1) distribution", v); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct NormalInverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { - alpha: F, beta: F, inverse_gaussian: InverseGaussian, } @@ -28,7 +64,7 @@ impl NormalInverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and /// beta (asymmetry) parameters. @@ -48,7 +84,6 @@ where let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap(); Ok(Self { - alpha, beta, inverse_gaussian, }) @@ -59,11 +94,13 @@ impl Distribution for NormalInverseGaussian where F: Float, StandardNormal: Distribution, - Standard: Distribution, + StandardUniform: Distribution, { fn sample(&self, rng: &mut R) -> F - where R: Rng + ?Sized { - let inv_gauss = rng.sample(&self.inverse_gaussian); + where + R: Rng + ?Sized, + { + let inv_gauss = rng.sample(self.inverse_gaussian); self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) } @@ -89,4 +126,12 @@ mod tests { assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); } + + #[test] + fn normal_inverse_gaussian_distributions_can_be_compared() { + assert_eq!( + NormalInverseGaussian::new(1.0, 2.0), + NormalInverseGaussian::new(1.0, 2.0) + ); + } } diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index 3250c86ffe9..7334ccd5f15 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -6,32 +6,47 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Pareto distribution. +//! The Pareto distribution `Pareto(xₘ, α)`. -use num_traits::Float; use crate::{Distribution, OpenClosed01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the Pareto distribution +/// The [Pareto distribution](https://en.wikipedia.org/wiki/Pareto_distribution) `Pareto(xₘ, α)`. +/// +/// The Pareto distribution is a continuous probability distribution with +/// scale parameter `xₘ` ( or `k`) and shape parameter `α`. +/// +/// # Plot +/// +/// The following plot shows the Pareto distribution with various values of +/// `xₘ` and `α`. +/// Note how the shape parameter `α` corresponds to the height of the jump +/// in density at `x = xₘ`, and to the rate of decay in the tail. +/// +/// ![Pareto distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pareto.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::Pareto; /// -/// let val: f64 = thread_rng().sample(Pareto::new(1., 2.).unwrap()); +/// let val: f64 = rand::rng().sample(Pareto::new(1., 2.).unwrap()); /// println!("{}", val); /// ``` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { scale: F, inv_neg_shape: F, } -/// Error type returned from `Pareto::new`. +/// Error type returned from [`Pareto::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. @@ -53,7 +68,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { /// Construct a new Pareto distribution with given `scale` and `shape`. /// @@ -76,7 +93,9 @@ where F: Float, OpenClosed01: Distribution } impl Distribution for Pareto -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { fn sample(&self, rng: &mut R) -> F { let u: F = OpenClosed01.sample(rng); @@ -110,7 +129,9 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: D, thresh: F, expected: &[F], + distr: D, + thresh: F, + expected: &[F], ) { let mut rng = crate::test::rng(213); for v in expected { @@ -119,14 +140,25 @@ mod tests { } } - test_samples(Pareto::new(1f32, 1.0).unwrap(), 1e-6, &[ - 1.0423688, 2.1235929, 4.132709, 1.4679428, - ]); - test_samples(Pareto::new(2.0, 0.5).unwrap(), 1e-14, &[ - 9.019295276219136, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ]); + test_samples( + Pareto::new(1f32, 1.0).unwrap(), + 1e-6, + &[1.0423688, 2.1235929, 4.132709, 1.4679428], + ); + test_samples( + Pareto::new(2.0, 0.5).unwrap(), + 1e-14, + &[ + 9.019295276219136, + 4.3097126018270595, + 6.837815045397157, + 105.8826669383772, + ], + ); + } + + #[test] + fn pareto_distributions_can_be_compared() { + assert_eq!(Pareto::new(1.0, 2.0), Pareto::new(1.0, 2.0)); } } diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index d6905e014bf..5c247a3d1e8 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -7,30 +7,38 @@ // except according to those terms. //! The PERT distribution. -use num_traits::Float; use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// The PERT distribution. +/// The [PERT distribution](https://en.wikipedia.org/wiki/PERT_distribution) `PERT(min, max, mode, shape)`. /// /// Similar to the [`Triangular`] distribution, the PERT distribution is /// parameterised by a range and a mode within that range. Unlike the /// [`Triangular`] distribution, the probability density function of the PERT /// distribution is smooth, with a configurable weighting around the mode. /// +/// # Plot +/// +/// The following plot shows the PERT distribution with `min = -1`, `max = 1`, +/// and various values of `mode` and `shape`. +/// +/// ![PERT distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pert.svg) +/// /// # Example /// /// ```rust /// use rand_distr::{Pert, Distribution}; /// -/// let d = Pert::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap(); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a PERT distribution", v); /// ``` /// /// [`Triangular`]: crate::Triangular -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Pert where F: Float, @@ -74,35 +82,75 @@ where Exp1: Distribution, Open01: Distribution, { - /// Set up the PERT distribution with defined `min`, `max` and `mode`. + /// Construct a PERT distribution with defined `min`, `max` + /// + /// # Example + /// + /// ``` + /// use rand_distr::Pert; + /// let pert_dist = Pert::new(0.0, 10.0) + /// .with_shape(3.5) + /// .with_mean(3.0) + /// .unwrap(); + /// # let _unused: Pert = pert_dist; + /// ``` + #[allow(clippy::new_ret_no_self)] + #[inline] + pub fn new(min: F, max: F) -> PertBuilder { + let shape = F::from(4.0).unwrap(); + PertBuilder { min, max, shape } + } +} + +/// Struct used to build a [`Pert`] +#[derive(Debug)] +pub struct PertBuilder { + min: F, + max: F, + shape: F, +} + +impl PertBuilder +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Set the shape parameter /// - /// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`. + /// If not specified, this defaults to 4. + #[inline] + pub fn with_shape(mut self, shape: F) -> PertBuilder { + self.shape = shape; + self + } + + /// Specify the mean #[inline] - pub fn new(min: F, max: F, mode: F) -> Result, PertError> { - Pert::new_with_shape(min, max, mode, F::from(4.).unwrap()) + pub fn with_mean(self, mean: F) -> Result, PertError> { + let two = F::from(2.0).unwrap(); + let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape; + self.with_mode(mode) } - /// Set up the PERT distribution with defined `min`, `max`, `mode` and - /// `shape`. - pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result, PertError> { - if !(max > min) { + /// Specify the mode + #[inline] + pub fn with_mode(self, mode: F) -> Result, PertError> { + if !(self.max > self.min) { return Err(PertError::RangeTooSmall); } - if !(mode >= min && max >= mode) { + if !(mode >= self.min && self.max >= mode) { return Err(PertError::ModeRange); } - if !(shape >= F::from(0.).unwrap()) { + if !(self.shape >= F::from(0.).unwrap()) { return Err(PertError::ShapeTooSmall); } + let (min, max, shape) = (self.min, self.max, self.shape); let range = max - min; - let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap()); - let v = if mu == mode { - shape * F::from(0.5).unwrap() + F::from(1.).unwrap() - } else { - (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min)) - }; - let w = v * (max - mu) / (mu - min); + let v = F::from(1.0).unwrap() + shape * (mode - min) / range; + let w = F::from(1.0).unwrap() + shape * (max - mode) / range; let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?; Ok(Pert { min, range, beta }) } @@ -127,21 +175,39 @@ mod test { #[test] fn test_pert() { - for &(min, max, mode) in &[ - (-1., 1., 0.), - (1., 2., 1.), - (5., 25., 25.), - ] { - let _distr = Pert::new(min, max, mode).unwrap(); + for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] { + let _distr = Pert::new(min, max).with_mode(mode).unwrap(); // TODO: test correctness } - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { - assert!(Pert::new(min, max, mode).is_err()); + for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { + assert!(Pert::new(min, max).with_mode(mode).is_err()); } } + + #[test] + fn distributions_can_be_compared() { + let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0); + let p1 = Pert::new(min, max).with_mode(mode).unwrap(); + let mean = (min + shape * mode + max) / (shape + 2.0); + let p2 = Pert::new(min, max).with_mean(mean).unwrap(); + assert_eq!(p1, p2); + } + + #[test] + fn mode_almost_half_range() { + assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok()); + } + + #[test] + fn almost_symmetric_about_zero() { + let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON); + assert!(distr.is_ok()); + } + + #[test] + fn almost_symmetric() { + let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON); + assert!(distr.is_ok()); + } } diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index a190256e15e..3e4421259bd 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -7,17 +7,28 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Poisson distribution. +//! The Poisson distribution `Poisson(λ)`. +use crate::{Cauchy, Distribution, StandardUniform}; +use core::fmt; use num_traits::{Float, FloatConst}; -use crate::{Cauchy, Distribution, Standard}; use rand::Rng; -use core::fmt; -/// The Poisson distribution `Poisson(lambda)`. +/// The [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) `Poisson(λ)`. +/// +/// The Poisson distribution is a discrete probability distribution with +/// rate parameter `λ` (`lambda`). It models the number of events occurring in a fixed +/// interval of time or space. +/// +/// This distribution has density function: +/// `f(k) = λ^k * exp(-λ) / k!` for `k >= 0`. +/// +/// # Plot +/// +/// The following plot shows the Poisson distribution with various values of `λ`. +/// Note how the expected number of events increases with `λ`. /// -/// This distribution has a density function: -/// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`. +/// ![Poisson distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/poisson.svg) /// /// # Example /// @@ -25,32 +36,46 @@ use core::fmt; /// use rand_distr::{Poisson, Distribution}; /// /// let poi = Poisson::new(2.0).unwrap(); -/// let v = poi.sample(&mut rand::thread_rng()); +/// let v: f64 = poi.sample(&mut rand::rng()); /// println!("{} is from a Poisson(2) distribution", v); /// ``` -#[derive(Clone, Copy, Debug)] -pub struct Poisson -where F: Float + FloatConst, Standard: Distribution -{ - lambda: F, - // precalculated values - exp_lambda: F, - log_lambda: F, - sqrt_2lambda: F, - magic_val: F, -} +/// +/// # Integer vs FP return type +/// +/// This implementation uses floating-point (FP) logic internally. +/// +/// Due to the parameter limit λ < [Self::MAX_LAMBDA], it +/// statistically impossible to sample a value larger [`u64::MAX`]. As such, it +/// is reasonable to cast generated samples to `u64` using `as`: +/// `distr.sample(&mut rng) as u64` (and memory safe since Rust 1.45). +/// Similarly, when `λ < 4.2e9` it can be safely assumed that samples are less +/// than `u32::MAX`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct Poisson(Method) +where + F: Float + FloatConst, + StandardUniform: Distribution; -/// Error type returned from `Poisson::new`. +/// Error type returned from [`Poisson::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { - /// `lambda <= 0` or `nan`. + /// `lambda <= 0` ShapeTooSmall, + /// `lambda = ∞` or `lambda = nan` + NonFinite, + /// `lambda` is too large, see [Poisson::MAX_LAMBDA] + ShapeTooLarge, } impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { Error::ShapeTooSmall => "lambda is not positive in Poisson distribution", + Error::NonFinite => "lambda is infinite or nan in Poisson distribution", + Error::ShapeTooLarge => { + "lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA" + } }) } } @@ -58,84 +83,168 @@ impl fmt::Display for Error { #[cfg(feature = "std")] impl std::error::Error for Error {} +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub(crate) struct KnuthMethod { + exp_lambda: F, +} + +impl KnuthMethod { + pub(crate) fn new(lambda: F) -> Self { + KnuthMethod { + exp_lambda: (-lambda).exp(), + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +struct RejectionMethod { + lambda: F, + log_lambda: F, + sqrt_2lambda: F, + magic_val: F, +} + +impl RejectionMethod { + pub(crate) fn new(lambda: F) -> Self { + let log_lambda = lambda.ln(); + let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt(); + let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda); + RejectionMethod { + lambda, + log_lambda, + sqrt_2lambda, + magic_val, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +enum Method { + Knuth(KnuthMethod), + Rejection(RejectionMethod), +} + impl Poisson -where F: Float + FloatConst, Standard: Distribution +where + F: Float + FloatConst, + StandardUniform: Distribution, { /// Construct a new `Poisson` with the given shape parameter /// `lambda`. + /// + /// The maximum allowed lambda is [MAX_LAMBDA](Self::MAX_LAMBDA). pub fn new(lambda: F) -> Result, Error> { + if !lambda.is_finite() { + return Err(Error::NonFinite); + } if !(lambda > F::zero()) { return Err(Error::ShapeTooSmall); } - let log_lambda = lambda.ln(); - Ok(Poisson { - lambda, - exp_lambda: (-lambda).exp(), - log_lambda, - sqrt_2lambda: (F::from(2.0).unwrap() * lambda).sqrt(), - magic_val: lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda), - }) + + // Use the Knuth method only for low expected values + let method = if lambda < F::from(12.0).unwrap() { + Method::Knuth(KnuthMethod::new(lambda)) + } else { + if lambda > F::from(Self::MAX_LAMBDA).unwrap() { + return Err(Error::ShapeTooLarge); + } + Method::Rejection(RejectionMethod::new(lambda)) + }; + + Ok(Poisson(method)) } + + /// The maximum supported value of `lambda` + /// + /// This value was selected such that + /// `MAX_LAMBDA + 1e6 * sqrt(MAX_LAMBDA) < 2^64 - 1`, + /// thus ensuring that the probability of sampling a value larger than + /// `u64::MAX` is less than 1e-1000. + /// + /// Applying this limit also solves + /// [#1312](https://github.com/rust-random/rand/issues/1312). + pub const MAX_LAMBDA: f64 = 1.844e19; } -impl Distribution for Poisson -where F: Float + FloatConst, Standard: Distribution +impl Distribution for KnuthMethod +where + F: Float + FloatConst, + StandardUniform: Distribution, { - #[inline] fn sample(&self, rng: &mut R) -> F { - // using the algorithm from Numerical Recipes in C - - // for low expected values use the Knuth method - if self.lambda < F::from(12.0).unwrap() { - let mut result = F::zero(); - let mut p = F::one(); - while p > self.exp_lambda { - p = p*rng.gen::(); - result = result + F::one(); - } - result - F::one() + let mut result = F::one(); + let mut p = rng.random::(); + while p > self.exp_lambda { + p = p * rng.random::(); + result = result + F::one(); } - // high expected values - rejection method - else { - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); - let mut result; + result - F::one() + } +} + +impl Distribution for RejectionMethod +where + F: Float + FloatConst, + StandardUniform: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + // The algorithm from Numerical Recipes in C + + // we use the Cauchy distribution as the comparison distribution + // f(x) ~ 1/(1+x^2) + let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); + let mut result; + + loop { + let mut comp_dev; loop { - let mut comp_dev; - - loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison ditribution - result = self.sqrt_2lambda * comp_dev + self.lambda; - // repeat the drawing until we are in the range of possible values - if result >= F::zero() { - break; - } - } - // now the result is a random variable greater than 0 with Cauchy distribution - // the result should be an integer value - result = result.floor(); - - // this is the ratio of the Poisson distribution to the comparison distribution - // the magic value scales the distribution function to a range of approximately 0-1 - // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 - // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = F::from(0.9).unwrap() - * (F::one() + comp_dev * comp_dev) - * (result * self.log_lambda - - crate::utils::log_gamma(F::one() + result) - - self.magic_val) - .exp(); - - // check with uniform random value - if below the threshold, we are within the target distribution - if rng.gen::() <= check { + // draw from the Cauchy distribution + comp_dev = rng.sample(cauchy); + // shift the peak of the comparison distribution + result = self.sqrt_2lambda * comp_dev + self.lambda; + // repeat the drawing until we are in the range of possible values + if result >= F::zero() { break; } } - result + // now the result is a random variable greater than 0 with Cauchy distribution + // the result should be an integer value + result = result.floor(); + + // this is the ratio of the Poisson distribution to the comparison distribution + // the magic value scales the distribution function to a range of approximately 0-1 + // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 + // this doesn't change the resulting distribution, only increases the rate of failed drawings + let check = F::from(0.9).unwrap() + * (F::one() + comp_dev * comp_dev) + * (result * self.log_lambda + - crate::utils::log_gamma(F::one() + result) + - self.magic_val) + .exp(); + + // check with uniform random value - if below the threshold, we are within the target distribution + if rng.random::() <= check { + break; + } + } + result + } +} + +impl Distribution for Poisson +where + F: Float + FloatConst, + StandardUniform: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> F { + match &self.0 { + Method::Knuth(method) => method.sample(rng), + Method::Rejection(method) => method.sample(rng), } } } @@ -145,7 +254,8 @@ mod test { use super::*; fn test_poisson_avg_gen(lambda: F, tol: F) - where Standard: Distribution + where + StandardUniform: Distribution, { let poisson = Poisson::new(lambda).unwrap(); let mut rng = crate::test::rng(123); @@ -159,10 +269,15 @@ mod test { #[test] fn test_poisson_avg() { - test_poisson_avg_gen::(10.0, 0.5); - test_poisson_avg_gen::(15.0, 0.5); - test_poisson_avg_gen::(10.0, 0.5); - test_poisson_avg_gen::(15.0, 0.5); + test_poisson_avg_gen::(10.0, 0.1); + test_poisson_avg_gen::(15.0, 0.1); + + test_poisson_avg_gen::(10.0, 0.1); + test_poisson_avg_gen::(15.0, 0.1); + + // Small lambda will use Knuth's method with exp_lambda == 1.0 + test_poisson_avg_gen::(0.00000000000000005, 0.1); + test_poisson_avg_gen::(0.00000000000000005, 0.1); } #[test] @@ -171,9 +286,20 @@ mod test { Poisson::new(0.0).unwrap(); } + #[test] + #[should_panic] + fn test_poisson_invalid_lambda_infinity() { + Poisson::new(f64::INFINITY).unwrap(); + } + #[test] #[should_panic] fn test_poisson_invalid_lambda_neg() { Poisson::new(-10.0).unwrap(); } + + #[test] + fn poisson_distributions_can_be_compared() { + assert_eq!(Poisson::new(1.0), Poisson::new(1.0)); + } } diff --git a/rand_distr/src/skew_normal.rs b/rand_distr/src/skew_normal.rs new file mode 100644 index 00000000000..1be2311a6b5 --- /dev/null +++ b/rand_distr/src/skew_normal.rs @@ -0,0 +1,272 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Skew Normal distribution `SN(ξ, ω, α)`. + +use crate::{Distribution, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; + +/// The [skew normal distribution](https://en.wikipedia.org/wiki/Skew_normal_distribution) `SN(ξ, ω, α)`. +/// +/// The skew normal distribution is a generalization of the +/// [`Normal`](crate::Normal) distribution to allow for non-zero skewness. +/// It has location parameter `ξ` (`xi`), scale parameter `ω` (`omega`), +/// and shape parameter `α` (`alpha`). +/// +/// The `ξ` and `ω` parameters correspond to the mean `μ` and standard +/// deviation `σ` of the normal distribution, respectively. +/// The `α` parameter controls the skewness. +/// +/// # Density function +/// +/// It has the density function, for `scale > 0`, +/// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)` +/// where `phi` and `Phi` are the density and distribution of a standard normal variable. +/// +/// # Plot +/// +/// The following plot shows the skew normal distribution with `location = 0`, `scale = 1` +/// (corresponding to the [`standard normal distribution`](crate::StandardNormal)), and +/// various values of `shape`. +/// +/// ![Skew normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/skew_normal.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{SkewNormal, Distribution}; +/// +/// // location 2, scale 3, shape 1 +/// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap(); +/// let v = skew_normal.sample(&mut rand::rng()); +/// println!("{} is from a SN(2, 3, 1) distribution", v) +/// ``` +/// +/// # Implementation details +/// +/// We are using the algorithm from [A Method to Simulate the Skew Normal Distribution]. +/// +/// [skew normal distribution]: https://en.wikipedia.org/wiki/Skew_normal_distribution +/// [`Normal`]: struct.Normal.html +/// [A Method to Simulate the Skew Normal Distribution]: https://dx.doi.org/10.4236/am.2014.513201 +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct SkewNormal +where + F: Float, + StandardNormal: Distribution, +{ + location: F, + scale: F, + shape: F, +} + +/// Error type returned from [`SkewNormal::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// The scale parameter is not finite or it is less or equal to zero. + ScaleTooSmall, + /// The shape parameter is not finite. + BadShape, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::ScaleTooSmall => { + "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution" + } + Error::BadShape => "shape parameter is non-finite in skew normal distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl SkewNormal +where + F: Float, + StandardNormal: Distribution, +{ + /// Construct, from location, scale and shape. + /// + /// Parameters: + /// + /// - location (unrestricted) + /// - scale (must be finite and larger than zero) + /// - shape (must be finite) + #[inline] + pub fn new(location: F, scale: F, shape: F) -> Result, Error> { + if !scale.is_finite() || !(scale > F::zero()) { + return Err(Error::ScaleTooSmall); + } + if !shape.is_finite() { + return Err(Error::BadShape); + } + Ok(SkewNormal { + location, + scale, + shape, + }) + } + + /// Returns the location of the distribution. + pub fn location(&self) -> F { + self.location + } + + /// Returns the scale of the distribution. + pub fn scale(&self) -> F { + self.scale + } + + /// Returns the shape of the distribution. + pub fn shape(&self) -> F { + self.shape + } +} + +impl Distribution for SkewNormal +where + F: Float, + StandardNormal: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let linear_map = |x: F| -> F { x * self.scale + self.location }; + let u_1: F = rng.sample(StandardNormal); + if self.shape == F::zero() { + linear_map(u_1) + } else { + let u_2 = rng.sample(StandardNormal); + let (u, v) = (u_1.max(u_2), u_1.min(u_2)); + if self.shape == -F::one() { + linear_map(v) + } else if self.shape == F::one() { + linear_map(u) + } else { + let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v) + / ((F::one() + self.shape * self.shape).sqrt() + * F::from(core::f64::consts::SQRT_2).unwrap()); + linear_map(normalized) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_samples>(distr: D, zero: F, expected: &[F]) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + #[test] + #[should_panic] + fn invalid_scale_nan() { + SkewNormal::new(0.0, f64::NAN, 0.0).unwrap(); + } + + #[test] + #[should_panic] + fn invalid_scale_zero() { + SkewNormal::new(0.0, 0.0, 0.0).unwrap(); + } + + #[test] + #[should_panic] + fn invalid_scale_negative() { + SkewNormal::new(0.0, -1.0, 0.0).unwrap(); + } + + #[test] + #[should_panic] + fn invalid_scale_infinite() { + SkewNormal::new(0.0, f64::INFINITY, 0.0).unwrap(); + } + + #[test] + #[should_panic] + fn invalid_shape_nan() { + SkewNormal::new(0.0, 1.0, f64::NAN).unwrap(); + } + + #[test] + #[should_panic] + fn invalid_shape_infinite() { + SkewNormal::new(0.0, 1.0, f64::INFINITY).unwrap(); + } + + #[test] + fn valid_location_nan() { + SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap(); + } + + #[test] + fn skew_normal_value_stability() { + test_samples( + SkewNormal::new(0.0, 1.0, 0.0).unwrap(), + 0f32, + &[-0.11844189, 0.781378, 0.06563994, -1.1932899], + ); + test_samples( + SkewNormal::new(0.0, 1.0, 0.0).unwrap(), + 0f64, + &[ + -0.11844188827977231, + 0.7813779637772346, + 0.06563993969580051, + -1.1932899004186373, + ], + ); + test_samples( + SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(), + 0f64, + &[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY], + ); + test_samples( + SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(), + 0f64, + &[ + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::NEG_INFINITY, + f64::NEG_INFINITY, + ], + ); + } + + #[test] + fn skew_normal_value_location_nan() { + let skew_normal = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap(); + let mut rng = crate::test::rng(213); + let mut buf = [0.0; 4]; + for x in &mut buf { + *x = rng.sample(skew_normal); + } + for value in buf.iter() { + assert!(value.is_nan()); + } + } + + #[test] + fn skew_normal_distributions_can_be_compared() { + assert_eq!( + SkewNormal::new(1.0, 2.0, 3.0), + SkewNormal::new(1.0, 2.0, 3.0) + ); + } +} diff --git a/rand_distr/src/student_t.rs b/rand_distr/src/student_t.rs new file mode 100644 index 00000000000..b0d7d078ae2 --- /dev/null +++ b/rand_distr/src/student_t.rs @@ -0,0 +1,107 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Student's t-distribution. + +use crate::{ChiSquared, ChiSquaredError}; +use crate::{Distribution, Exp1, Open01, StandardNormal}; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`. +/// +/// The t-distribution is a continuous probability distribution +/// parameterized by degrees of freedom `ν` (`nu`), which +/// arises when estimating the mean of a normally-distributed +/// population in situations where the sample size is small and +/// the population's standard deviation is unknown. +/// It is widely used in hypothesis testing. +/// +/// For `ν = 1`, this is equivalent to the standard +/// [`Cauchy`](crate::Cauchy) distribution, +/// and as `ν` diverges to infinity, `t(ν)` converges to +/// [`StandardNormal`](crate::StandardNormal). +/// +/// # Plot +/// +/// The plot shows the t-distribution with various degrees of freedom. +/// +/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{StudentT, Distribution}; +/// +/// let t = StudentT::new(11.0).unwrap(); +/// let v = t.sample(&mut rand::rng()); +/// println!("{} is from a t(11) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + chi: ChiSquared, + dof: F, +} + +impl StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new Student t-distribution with `ν` (nu) + /// degrees of freedom. + pub fn new(nu: F) -> Result, ChiSquaredError> { + Ok(StudentT { + chi: ChiSquared::new(nu)?, + dof: nu, + }) + } +} +impl Distribution for StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let norm: F = rng.sample(StandardNormal); + norm * (self.dof / self.chi.sample(rng)).sqrt() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_t() { + let t = StudentT::new(11.0).unwrap(); + let mut rng = crate::test::rng(205); + for _ in 0..1000 { + t.sample(&mut rng); + } + } + + #[test] + fn student_t_distributions_can_be_compared() { + assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); + } +} diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs index 6d3d4cfd03f..05a46e57ecf 100644 --- a/rand_distr/src/triangular.rs +++ b/rand_distr/src/triangular.rs @@ -7,12 +7,12 @@ // except according to those terms. //! The triangular distribution. +use crate::{Distribution, StandardUniform}; +use core::fmt; use num_traits::Float; -use crate::{Distribution, Standard}; use rand::Rng; -use core::fmt; -/// The triangular distribution. +/// The [triangular distribution](https://en.wikipedia.org/wiki/Triangular_distribution) `Triangular(min, max, mode)`. /// /// A continuous probability distribution parameterised by a range, and a mode /// (most likely value) within that range. @@ -20,20 +20,30 @@ use core::fmt; /// The probability density function is triangular. For a similar distribution /// with a smooth PDF, see the [`Pert`] distribution. /// +/// # Plot +/// +/// The following plot shows the triangular distribution with various values of +/// `min`, `max`, and `mode`. +/// +/// ![Triangular distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/triangular.svg) +/// /// # Example /// /// ```rust /// use rand_distr::{Triangular, Distribution}; /// /// let d = Triangular::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a triangular distribution", v); /// ``` /// /// [`Pert`]: crate::Pert -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Triangular -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { min: F, max: F, @@ -64,7 +74,9 @@ impl fmt::Display for TriangularError { impl std::error::Error for TriangularError {} impl Triangular -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { /// Set up the Triangular distribution with defined `min`, `max` and `mode`. #[inline] @@ -80,11 +92,13 @@ where F: Float, Standard: Distribution } impl Distribution for Triangular -where F: Float, Standard: Distribution +where + F: Float, + StandardUniform: Distribution, { #[inline] fn sample(&self, rng: &mut R) -> F { - let f: F = rng.sample(Standard); + let f: F = rng.sample(StandardUniform); let diff_mode_min = self.mode - self.min; let range = self.max - self.min; let f_range = f * range; @@ -104,7 +118,7 @@ mod test { #[test] fn test_triangular() { let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0); - assert_eq!(half_rng.gen::(), 0.5); + assert_eq!(half_rng.random::(), 0.5); for &(min, max, mode, median) in &[ (-1., 1., 0., 0.), (1., 2., 1., 2. - 0.5f64.sqrt()), @@ -120,12 +134,16 @@ mod test { assert_eq!(distr.sample(&mut half_rng), median); } - for &(min, max, mode) in &[ - (-1., 1., 2.), - (-1., 1., -2.), - (2., 1., 1.), - ] { + for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { assert!(Triangular::new(min, max, mode).is_err()); } } + + #[test] + fn triangular_distributions_can_be_compared() { + assert_eq!( + Triangular::new(1.0, 3.0, 2.0), + Triangular::new(1.0, 3.0, 2.0) + ); + } } diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs index e5585a1e677..514fc30812a 100644 --- a/rand_distr/src/unit_ball.rs +++ b/rand_distr/src/unit_ball.rs @@ -6,31 +6,43 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; -/// Samples uniformly from the unit ball (surface and interior) in three -/// dimensions. +/// Samples uniformly from the volume of the unit ball in three dimensions. /// /// Implemented via rejection sampling. /// +/// For a distribution that samples only from the surface of the unit ball, +/// see [`UnitSphere`](crate::UnitSphere). +/// +/// For a similar distribution in two dimensions, see [`UnitDisc`](crate::UnitDisc). +/// +/// # Plot +/// +/// The following plot shows the unit ball in three dimensions. +/// This distribution samples individual points from the entire volume +/// of the ball. +/// +/// ![Unit ball](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_ball.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitBall, Distribution}; /// -/// let v: [f64; 3] = UnitBall.sample(&mut rand::thread_rng()); +/// let v: [f64; 3] = UnitBall.sample(&mut rand::rng()); /// println!("{:?} is from the unit ball.", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitBall; impl Distribution<[F; 3]> for UnitBall { #[inline] fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); let mut x1; let mut x2; let mut x3; diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs index 29e5c9a5939..d25d829f5a5 100644 --- a/rand_distr/src/unit_circle.rs +++ b/rand_distr/src/unit_circle.rs @@ -6,21 +6,31 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; -/// Samples uniformly from the edge of the unit circle in two dimensions. +/// Samples uniformly from the circumference of the unit circle in two dimensions. /// /// Implemented via a method by von Neumann[^1]. /// +/// For a distribution that also samples from the interior of the unit circle, +/// see [`UnitDisc`](crate::UnitDisc). +/// +/// For a similar distribution in three dimensions, see [`UnitSphere`](crate::UnitSphere). +/// +/// # Plot +/// +/// The following plot shows the unit circle. +/// +/// ![Unit circle](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_circle.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitCircle, Distribution}; /// -/// let v: [f64; 2] = UnitCircle.sample(&mut rand::thread_rng()); +/// let v: [f64; 2] = UnitCircle.sample(&mut rand::rng()); /// println!("{:?} is from the unit circle.", v) /// ``` /// @@ -29,12 +39,13 @@ use rand::Rng; /// NBS Appl. Math. Ser., No. 12. Washington, DC: U.S. Government Printing /// Office, pp. 36-38. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitCircle; impl Distribution<[F; 2]> for UnitCircle { #[inline] fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); let mut x1; let mut x2; let mut sum; diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs index ced548b4dc0..c95fd1d6c83 100644 --- a/rand_distr/src/unit_disc.rs +++ b/rand_distr/src/unit_disc.rs @@ -6,30 +6,42 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the unit disc in two dimensions. /// /// Implemented via rejection sampling. /// +/// For a distribution that samples only from the circumference of the unit disc, +/// see [`UnitCircle`](crate::UnitCircle). +/// +/// For a similar distribution in three dimensions, see [`UnitBall`](crate::UnitBall). +/// +/// # Plot +/// +/// The following plot shows the unit disc. +/// This distribution samples individual points from the entire area of the disc. +/// +/// ![Unit disc](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_disc.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitDisc, Distribution}; /// -/// let v: [f64; 2] = UnitDisc.sample(&mut rand::thread_rng()); +/// let v: [f64; 2] = UnitDisc.sample(&mut rand::rng()); /// println!("{:?} is from the unit Disc.", v) /// ``` #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitDisc; impl Distribution<[F; 2]> for UnitDisc { #[inline] fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); let mut x1; let mut x2; loop { diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs index b167a5d5d63..1d531924efb 100644 --- a/rand_distr/src/unit_sphere.rs +++ b/rand_distr/src/unit_sphere.rs @@ -6,21 +6,33 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use num_traits::Float; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use num_traits::Float; use rand::Rng; /// Samples uniformly from the surface of the unit sphere in three dimensions. /// /// Implemented via a method by Marsaglia[^1]. /// +/// For a distribution that also samples from the interior of the sphere, +/// see [`UnitBall`](crate::UnitBall). +/// +/// For a similar distribution in two dimensions, see [`UnitCircle`](crate::UnitCircle). +/// +/// # Plot +/// +/// The following plot shows the unit sphere as a wireframe. +/// The wireframe is meant to illustrate that this distribution samples +/// from the surface of the sphere only, not from the interior. +/// +/// ![Unit sphere](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_sphere.svg) /// /// # Example /// /// ``` /// use rand_distr::{UnitSphere, Distribution}; /// -/// let v: [f64; 3] = UnitSphere.sample(&mut rand::thread_rng()); +/// let v: [f64; 3] = UnitSphere.sample(&mut rand::rng()); /// println!("{:?} is from the unit sphere surface.", v) /// ``` /// @@ -28,12 +40,13 @@ use rand::Rng; /// Sphere.*](https://doi.org/10.1214/aoms/1177692644) /// Ann. Math. Statist. 43, no. 2, 645--646. #[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct UnitSphere; impl Distribution<[F; 3]> for UnitSphere { #[inline] fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()); + let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); loop { let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); let sum = x1 * x1 + x2 * x2; @@ -41,7 +54,11 @@ impl Distribution<[F; 3]> for UnitSphere { continue; } let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); - return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum]; + return [ + x1 * factor, + x2 * factor, + F::from(1.).unwrap() - F::from(2.).unwrap() * sum, + ]; } } } diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index 878faf2072b..f0cf2a1005a 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -9,7 +9,8 @@ //! Math helper functions use crate::ziggurat_tables; -use rand::distributions::hidden_export::IntoFloat; +use num_traits::Float; +use rand::distr::hidden_export::IntoFloat; use rand::Rng; /// Calculates ln(gamma(x)) (natural logarithm of the gamma @@ -25,7 +26,7 @@ use rand::Rng; /// `Ag(z)` is an infinite series with coefficients that can be calculated /// ahead of time - we use just the first 6 terms, which is good enough /// for most purposes. -pub(crate) fn log_gamma(x: F) -> F { +pub(crate) fn log_gamma(x: F) -> F { // precalculated 6 coefficients for the first 6 terms of the series let coefficients: [F; 6] = [ F::from(76.18009172947146).unwrap(), @@ -66,17 +67,14 @@ pub(crate) fn log_gamma(x: F) -> F { /// * `pdf`: the probability density function /// * `zero_case`: manual sampling from the tail when we chose the /// bottom box (i.e. i == 0) - -// the perf improvement (25-50%) is definitely worth the extra code -// size from force-inlining. -#[inline(always)] +#[inline(always)] // Forced inlining improves the perf by 25-50% pub(crate) fn ziggurat( rng: &mut R, symmetric: bool, x_tab: ziggurat_tables::ZigTable, f_tab: ziggurat_tables::ZigTable, mut pdf: P, - mut zero_case: Z + mut zero_case: Z, ) -> f64 where P: FnMut(f64) -> f64, @@ -91,15 +89,15 @@ where let i = bits as usize & 0xff; let u = if symmetric { - // Convert to a value in the range [2,4) and substract to get [-1,1) + // Convert to a value in the range [2,4) and subtract to get [-1,1) // We can't convert to an open range directly, that would require - // substracting `3.0 - EPSILON`, which is not representable. + // subtracting `3.0 - EPSILON`, which is not representable. // It is possible with an extra step, but an open range does not - // seem neccesary for the ziggurat algorithm anyway. + // seem necessary for the ziggurat algorithm anyway. (bits >> 12).into_float_with_exponent(1) - 3.0 } else { - // Convert to a value in the range [1,2) and substract to get (0,1) - (bits >> 12).into_float_with_exponent(0) - (1.0 - core::f64::EPSILON / 2.0) + // Convert to a value in the range [1,2) and subtract to get (0,1) + (bits >> 12).into_float_with_exponent(0) - (1.0 - f64::EPSILON / 2.0) }; let x = u * x_tab[i]; @@ -113,7 +111,7 @@ where return zero_case(rng, u); } // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 - if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::() < pdf(x) { + if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.random::() < pdf(x) { return x; } } diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index 184e5e06b16..1a9faf46c22 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -6,32 +6,54 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Weibull distribution. +//! The Weibull distribution `Weibull(λ, k)` -use num_traits::Float; use crate::{Distribution, OpenClosed01}; -use rand::Rng; use core::fmt; +use num_traits::Float; +use rand::Rng; -/// Samples floating-point numbers according to the Weibull distribution +/// The [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution) `Weibull(λ, k)`. +/// +/// This is a family of continuous probability distributions with +/// scale parameter `λ` (`lambda`) and shape parameter `k`. It is used +/// to model reliability data, life data, and accelerated life testing data. +/// +/// # Density function +/// +/// `f(x; λ, k) = (k / λ) * (x / λ)^(k - 1) * exp(-(x / λ)^k)` for `x >= 0`. +/// +/// # Plot +/// +/// The following plot shows the Weibull distribution with various values of `λ` and `k`. +/// +/// ![Weibull distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/weibull.svg) /// /// # Example /// ``` /// use rand::prelude::*; /// use rand_distr::Weibull; /// -/// let val: f64 = thread_rng().sample(Weibull::new(1., 10.).unwrap()); +/// let val: f64 = rand::rng().sample(Weibull::new(1., 10.).unwrap()); /// println!("{}", val); /// ``` -#[derive(Clone, Copy, Debug)] +/// +/// # Numerics +/// +/// For small `k` like `< 0.005`, even with `f64` a significant number of samples will be so small that they underflow to `0.0` +/// or so big they overflow to `inf`. This is a limitation of the floating point representation and not specific to this implementation. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { inv_shape: F, scale: F, } -/// Error type returned from `Weibull::new`. +/// Error type returned from [`Weibull::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. @@ -53,7 +75,9 @@ impl fmt::Display for Error { impl std::error::Error for Error {} impl Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { /// Construct a new `Weibull` distribution with given `scale` and `shape`. pub fn new(scale: F, shape: F) -> Result, Error> { @@ -71,7 +95,9 @@ where F: Float, OpenClosed01: Distribution } impl Distribution for Weibull -where F: Float, OpenClosed01: Distribution +where + F: Float, + OpenClosed01: Distribution, { fn sample(&self, rng: &mut R) -> F { let x: F = rng.sample(OpenClosed01); @@ -103,8 +129,10 @@ mod tests { #[test] fn value_stability() { - fn test_samples>( - distr: D, zero: F, expected: &[F], + fn test_samples>( + distr: D, + zero: F, + expected: &[F], ) { let mut rng = crate::test::rng(213); let mut buf = [zero; 4]; @@ -114,17 +142,25 @@ mod tests { assert_eq!(buf, expected); } - test_samples(Weibull::new(1.0, 1.0).unwrap(), 0f32, &[ - 0.041495778, - 0.7531094, - 1.4189332, - 0.38386202, - ]); - test_samples(Weibull::new(2.0, 0.5).unwrap(), 0f64, &[ - 1.1343478702739669, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ]); + test_samples( + Weibull::new(1.0, 1.0).unwrap(), + 0f32, + &[0.041495778, 0.7531094, 1.4189332, 0.38386202], + ); + test_samples( + Weibull::new(2.0, 0.5).unwrap(), + 0f64, + &[ + 1.1343478702739669, + 0.29470010050655226, + 0.7556151370284702, + 7.877212340241561, + ], + ); + } + + #[test] + fn weibull_distributions_can_be_compared() { + assert_eq!(Weibull::new(1.0, 2.0), Weibull::new(1.0, 2.0)); } } diff --git a/rand_distr/src/weighted/mod.rs b/rand_distr/src/weighted/mod.rs new file mode 100644 index 00000000000..1c54e48e69c --- /dev/null +++ b/rand_distr/src/weighted/mod.rs @@ -0,0 +1,28 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! This module is a superset of [`rand::distr::weighted`]. +//! +//! Multiple implementations of weighted index sampling are provided: +//! +//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction +//! and `O(log N)` sampling over `N` weights. +//! It also supports updating weights with `O(N)` time. +//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high +//! construction time many samples are required to outperform [`WeightedIndex`]. +//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and +//! update/insertion/removal of weights with `O(log N)` time. + +mod weighted_alias; +mod weighted_tree; + +pub use rand::distr::weighted::*; +pub use weighted_alias::*; +pub use weighted_tree::*; diff --git a/rand_distr/src/weighted_alias.rs b/rand_distr/src/weighted/weighted_alias.rs similarity index 82% rename from rand_distr/src/weighted_alias.rs rename to rand_distr/src/weighted/weighted_alias.rs index 527aece7479..862f2b70b33 100644 --- a/rand_distr/src/weighted_alias.rs +++ b/rand_distr/src/weighted/weighted_alias.rs @@ -9,13 +9,15 @@ //! This module contains an implementation of alias method for sampling random //! indices with probabilities proportional to a collection of weights. -use super::WeightedError; +use super::Error; use crate::{uniform::SampleUniform, Distribution, Uniform}; +use alloc::{boxed::Box, vec, vec::Vec}; use core::fmt; use core::iter::Sum; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use rand::Rng; -use alloc::{boxed::Box, vec, vec::Vec}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling to pick a discretely selected item. /// @@ -39,13 +41,13 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// # Example /// /// ``` -/// use rand_distr::WeightedAliasIndex; +/// use rand_distr::weighted::WeightedAliasIndex; /// use rand::prelude::*; /// /// let choices = vec!['a', 'b', 'c']; /// let weights = vec![2, 1, 1]; /// let dist = WeightedAliasIndex::new(weights).unwrap(); -/// let mut rng = thread_rng(); +/// let mut rng = rand::rng(); /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' /// println!("{}", choices[dist.sample(&mut rng)]); @@ -63,7 +65,15 @@ use alloc::{boxed::Box, vec, vec::Vec}; /// [`Vec`]: Vec /// [`Uniform::sample`]: Distribution::sample /// [`Uniform::sample`]: Distribution::sample -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] pub struct WeightedAliasIndex { aliases: Box<[u32]>, no_alias_odds: Box<[W]>, @@ -74,18 +84,15 @@ pub struct WeightedAliasIndex { impl WeightedAliasIndex { /// Creates a new [`WeightedAliasIndex`]. /// - /// Returns an error if: - /// - The vector is empty. - /// - The vector is longer than `u32::MAX`. - /// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX / - /// weights.len()`. - /// - The sum of weights is zero. - pub fn new(weights: Vec) -> Result { + /// Error cases: + /// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. + /// - [`Error::InvalidWeight`] when a weight is not-a-number, + /// negative or greater than `max = W::MAX / weights.len()`. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + pub fn new(weights: Vec) -> Result { let n = weights.len(); - if n == 0 { - return Err(WeightedError::NoItem); - } else if n > ::core::u32::MAX as usize { - return Err(WeightedError::TooMany); + if n == 0 || n > u32::MAX as usize { + return Err(Error::InvalidInput); } let n = n as u32; @@ -96,7 +103,7 @@ impl WeightedAliasIndex { .iter() .all(|&w| W::ZERO <= w && w <= max_weight_size) { - return Err(WeightedError::InvalidWeight); + return Err(Error::InvalidWeight); } // The sum of weights will represent 100% of no alias odds. @@ -108,7 +115,7 @@ impl WeightedAliasIndex { weight_sum }; if weight_sum == W::ZERO { - return Err(WeightedError::AllWeightsZero); + return Err(Error::InsufficientNonZero); } // `weight_sum` would have been zero if `try_from_lossy` causes an error here. @@ -137,8 +144,8 @@ impl WeightedAliasIndex { fn new(size: u32) -> Self { Aliases { aliases: vec![0; size as usize].into_boxed_slice(), - smalls_head: ::core::u32::MAX, - bigs_head: ::core::u32::MAX, + smalls_head: u32::MAX, + bigs_head: u32::MAX, } } @@ -165,11 +172,11 @@ impl WeightedAliasIndex { } fn smalls_is_empty(&self) -> bool { - self.smalls_head == ::core::u32::MAX + self.smalls_head == u32::MAX } fn bigs_is_empty(&self) -> bool { - self.bigs_head == ::core::u32::MAX + self.bigs_head == u32::MAX } fn set_alias(&mut self, idx: u32, alias: u32) { @@ -216,8 +223,8 @@ impl WeightedAliasIndex { // Prepare distributions for sampling. Creating them beforehand improves // sampling performance. - let uniform_index = Uniform::new(0, n); - let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum); + let uniform_index = Uniform::new(0, n).unwrap(); + let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap(); Ok(Self { aliases: aliases.aliases, @@ -255,7 +262,8 @@ where } impl Clone for WeightedAliasIndex -where Uniform: Clone +where + Uniform: Clone, { fn clone(&self) -> Self { Self { @@ -267,10 +275,10 @@ where Uniform: Clone } } -/// Trait that must be implemented for weights, that are used with -/// [`WeightedAliasIndex`]. Currently no guarantees on the correctness of -/// [`WeightedAliasIndex`] are given for custom implementations of this trait. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +/// Weight bound for [`WeightedAliasIndex`] +/// +/// Currently no guarantees on the correctness of [`WeightedAliasIndex`] are +/// given for custom implementations of this trait. pub trait AliasableWeight: Sized + Copy @@ -306,7 +314,7 @@ pub trait AliasableWeight: macro_rules! impl_weight_for_float { ($T: ident) => { impl AliasableWeight for $T { - const MAX: Self = ::core::$T::MAX; + const MAX: Self = $T::MAX; const ZERO: Self = 0.0; fn try_from_u32_lossy(n: u32) -> Option { @@ -335,7 +343,7 @@ fn pairwise_sum(values: &[T]) -> T { macro_rules! impl_weight_for_int { ($T: ident) => { impl AliasableWeight for $T { - const MAX: Self = ::core::$T::MAX; + const MAX: Self = $T::MAX; const ZERO: Self = 0; fn try_from_u32_lossy(n: u32) -> Option { @@ -353,14 +361,11 @@ macro_rules! impl_weight_for_int { impl_weight_for_float!(f64); impl_weight_for_float!(f32); impl_weight_for_int!(usize); -#[cfg(not(target_os = "emscripten"))] impl_weight_for_int!(u128); impl_weight_for_int!(u64); impl_weight_for_int!(u32); impl_weight_for_int!(u16); impl_weight_for_int!(u8); -impl_weight_for_int!(isize); -#[cfg(not(target_os = "emscripten"))] impl_weight_for_int!(i128); impl_weight_for_int!(i64); impl_weight_for_int!(i32); @@ -378,35 +383,33 @@ mod test { // Floating point special cases assert_eq!( - WeightedAliasIndex::new(vec![::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(), + Error::InvalidWeight ); assert_eq!( WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(), - WeightedError::AllWeightsZero + Error::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![-::core::f32::INFINITY]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(), + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![::core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(), + Error::InvalidWeight ); } - #[cfg(not(target_os = "emscripten"))] #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weighted_index_u128() { test_weighted_index(|x: u128| x as f64); } - #[cfg(not(target_os = "emscripten"))] #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weighted_index_i128() { @@ -415,11 +418,11 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![::core::i128::MIN]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(), + Error::InvalidWeight ); } @@ -437,16 +440,18 @@ mod test { // Signed integer special cases assert_eq!( WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); assert_eq!( - WeightedAliasIndex::new(vec![::core::i8::MIN]).unwrap_err(), - WeightedError::InvalidWeight + WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(), + Error::InvalidWeight ); } fn test_weighted_index f64>(w_to_f64: F) - where WeightedAliasIndex: fmt::Debug { + where + WeightedAliasIndex: fmt::Debug, + { const NUM_WEIGHTS: u32 = 10; const ZERO_WEIGHT_INDEX: u32 = 3; const NUM_SAMPLES: u32 = 15000; @@ -457,14 +462,15 @@ mod test { let random_weight_distribution = Uniform::new_inclusive( W::ZERO, W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), - ); + ) + .unwrap(); for _ in 0..NUM_WEIGHTS { weights.push(rng.sample(&random_weight_distribution)); } weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; weights }; - let weight_sum = weights.iter().map(|w| *w).sum::(); + let weight_sum = weights.iter().copied().sum::(); let expected_counts = weights .iter() .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) @@ -485,21 +491,25 @@ mod test { assert_eq!( WeightedAliasIndex::::new(vec![]).unwrap_err(), - WeightedError::NoItem + Error::InvalidInput ); assert_eq!( WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(), - WeightedError::AllWeightsZero + Error::InsufficientNonZero ); assert_eq!( WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - WeightedError::InvalidWeight + Error::InvalidWeight ); } #[test] fn value_stability() { - fn test_samples(weights: Vec, buf: &mut [usize], expected: &[usize]) { + fn test_samples( + weights: Vec, + buf: &mut [usize], + expected: &[usize], + ) { assert_eq!(buf.len(), expected.len()); let distr = WeightedAliasIndex::new(weights).unwrap(); let mut rng = crate::test::rng(0x9c9fa0b0580a7031); @@ -510,14 +520,20 @@ mod test { } let mut buf = [0; 10]; - test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 6, 5, 7, 5, 8, 7, 6, 2, 3, 7, - ]); - test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 2, 0, 0, 0, 0, 0, 0, 0, 1, 3, - ]); - test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 1, 2, 3, 2, 1, 3, 2, 1, 1, - ]); + test_samples( + vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7], + ); + test_samples( + vec![0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3], + ); + test_samples( + vec![1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1], + ); } } diff --git a/rand_distr/src/weighted/weighted_tree.rs b/rand_distr/src/weighted/weighted_tree.rs new file mode 100644 index 00000000000..dd315aa5f8f --- /dev/null +++ b/rand_distr/src/weighted/weighted_tree.rs @@ -0,0 +1,390 @@ +// Copyright 2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! This module contains an implementation of a tree structure for sampling random +//! indices with probabilities proportional to a collection of weights. + +use core::ops::SubAssign; + +use super::{Error, Weight}; +use crate::Distribution; +use alloc::vec::Vec; +use rand::distr::uniform::{SampleBorrow, SampleUniform}; +use rand::Rng; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedTreeIndex`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which an implementation of +/// [`Weight`] exists. +/// +/// # Key differences +/// +/// The main distinction between [`WeightedTreeIndex`] and [`WeightedIndex`] +/// lies in the internal representation of weights. In [`WeightedTreeIndex`], +/// weights are structured as a tree, which is optimized for frequent updates of the weights. +/// +/// # Caution: Floating point types +/// +/// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), +/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types +/// are susceptible to numerical rounding errors. Since operations on floating point weights are +/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable +/// deviations from the expected behavior. +/// +/// Ideally, use fixed point or integer types whenever possible. +/// +/// # Performance +/// +/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. +/// +/// Time complexity for the operations of a [`WeightedTreeIndex`] are: +/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time. +/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. +/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. +/// +/// # Example +/// +/// ``` +/// use rand_distr::weighted::WeightedTreeIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 0]; +/// let mut dist = WeightedTreeIndex::new(&weights).unwrap(); +/// dist.push(1).unwrap(); +/// dist.update(1, 1).unwrap(); +/// let mut rng = rand::rng(); +/// let mut samples = [0; 3]; +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// let i = dist.sample(&mut rng); +/// samples[i] += 1; +/// } +/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); +/// ``` +/// +/// [`WeightedTreeIndex`]: WeightedTreeIndex +/// [`WeightedIndex`]: super::WeightedIndex +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] +#[derive(Clone, Default, Debug, PartialEq)] +pub struct WeightedTreeIndex< + W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign + Weight, +> { + subtotals: Vec, +} + +impl + Weight> + WeightedTreeIndex +{ + /// Creates a new [`WeightedTreeIndex`] from a slice of weights. + /// + /// Error cases: + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn new(weights: I) -> Result + where + I: IntoIterator, + I::Item: SampleBorrow, + { + let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); + for weight in subtotals.iter() { + if !(*weight >= W::ZERO) { + return Err(Error::InvalidWeight); + } + } + let n = subtotals.len(); + for i in (1..n).rev() { + let w = subtotals[i].clone(); + let parent = (i - 1) / 2; + subtotals[parent] + .checked_add_assign(&w) + .map_err(|()| Error::Overflow)?; + } + Ok(Self { subtotals }) + } + + /// Returns `true` if the tree contains no weights. + pub fn is_empty(&self) -> bool { + self.subtotals.is_empty() + } + + /// Returns the number of weights. + pub fn len(&self) -> usize { + self.subtotals.len() + } + + /// Returns `true` if we can sample. + /// + /// This is the case if the total weight of the tree is greater than zero. + pub fn is_valid(&self) -> bool { + if let Some(weight) = self.subtotals.first() { + *weight > W::ZERO + } else { + false + } + } + + /// Gets the weight at an index. + pub fn get(&self, index: usize) -> W { + let left_index = 2 * index + 1; + let right_index = 2 * index + 2; + let mut w = self.subtotals[index].clone(); + w -= self.subtotal(left_index); + w -= self.subtotal(right_index); + w + } + + /// Removes the last weight and returns it, or [`None`] if it is empty. + pub fn pop(&mut self) -> Option { + self.subtotals.pop().map(|weight| { + let mut index = self.len(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= weight.clone(); + } + weight + }) + } + + /// Appends a new weight at the end. + /// + /// Error cases: + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn push(&mut self, weight: W) -> Result<(), Error> { + if !(weight >= W::ZERO) { + return Err(Error::InvalidWeight); + } + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&weight).is_err() { + return Err(Error::Overflow); + } + } + let mut index = self.len(); + self.subtotals.push(weight.clone()); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index].checked_add_assign(&weight).unwrap(); + } + Ok(()) + } + + /// Updates the weight at an index. + /// + /// Error cases: + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::Overflow`] when the sum of all weights overflows. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> { + if !(weight >= W::ZERO) { + return Err(Error::InvalidWeight); + } + let old_weight = self.get(index); + if weight > old_weight { + let mut difference = weight; + difference -= old_weight; + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&difference).is_err() { + return Err(Error::Overflow); + } + } + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + } + } else if weight < old_weight { + let mut difference = old_weight; + difference -= weight; + self.subtotals[index] -= difference.clone(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= difference.clone(); + } + } + Ok(()) + } + + fn subtotal(&self, index: usize) -> W { + if index < self.subtotals.len() { + self.subtotals[index].clone() + } else { + W::ZERO + } + } +} + +impl + Weight> + WeightedTreeIndex +{ + /// Samples a randomly selected index from the weighted distribution. + /// + /// Returns an error if there are no elements or all weights are zero. This + /// is unlike [`Distribution::sample`], which panics in those cases. + pub fn try_sample(&self, rng: &mut R) -> Result { + let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO); + if total_weight == W::ZERO { + return Err(Error::InsufficientNonZero); + } + let mut target_weight = rng.random_range(W::ZERO..total_weight); + let mut index = 0; + loop { + // Maybe descend into the left sub tree. + let left_index = 2 * index + 1; + let left_subtotal = self.subtotal(left_index); + if target_weight < left_subtotal { + index = left_index; + continue; + } + target_weight -= left_subtotal; + + // Maybe descend into the right sub tree. + let right_index = 2 * index + 2; + let right_subtotal = self.subtotal(right_index); + if target_weight < right_subtotal { + index = right_index; + continue; + } + target_weight -= right_subtotal; + + // Otherwise we found the index with the target weight. + break; + } + assert!(target_weight >= W::ZERO); + assert!(target_weight < self.get(index)); + Ok(index) + } +} + +/// Samples a randomly selected index from the weighted distribution. +/// +/// Caution: This method panics if there are no elements or all weights are zero. However, +/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] +/// returns `true`. +impl + Weight> Distribution + for WeightedTreeIndex +{ + #[track_caller] + fn sample(&self, rng: &mut R) -> usize { + self.try_sample(rng).unwrap() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_no_item_error() { + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + #[allow(clippy::needless_borrows_for_generic_args)] + let tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!( + tree.try_sample(&mut rng).unwrap_err(), + Error::InsufficientNonZero + ); + } + + #[test] + fn test_overflow_error() { + assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow)); + let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap(); + assert_eq!(tree.push(3), Err(Error::Overflow)); + assert_eq!(tree.update(1, 4), Err(Error::Overflow)); + tree.update(1, 2).unwrap(); + } + + #[test] + fn test_all_weights_zero_error() { + let tree = WeightedTreeIndex::::new([0.0, 0.0]).unwrap(); + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + assert_eq!( + tree.try_sample(&mut rng).unwrap_err(), + Error::InsufficientNonZero + ); + } + + #[test] + fn test_invalid_weight_error() { + assert_eq!( + WeightedTreeIndex::::new([1, -1]).unwrap_err(), + Error::InvalidWeight + ); + #[allow(clippy::needless_borrows_for_generic_args)] + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight); + tree.push(1).unwrap(); + assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight); + } + + #[test] + fn test_tree_modifications() { + let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap(); + tree.push(3).unwrap(); + tree.push(5).unwrap(); + tree.update(0, 0).unwrap(); + assert_eq!(tree.pop(), Some(5)); + let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap(); + assert_eq!(tree, expected); + } + + #[test] + #[allow(clippy::needless_range_loop)] + fn test_sample_counts_match_probabilities() { + let start = 1; + let end = 3; + let samples = 20; + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let weights: Vec = (0..end).map(|_| rng.random()).collect(); + let mut tree = WeightedTreeIndex::new(weights).unwrap(); + let mut total_weight = 0.0; + let mut weights = alloc::vec![0.0; end]; + for i in 0..end { + tree.update(i, i as f64).unwrap(); + weights[i] = i as f64; + total_weight += i as f64; + } + for i in 0..start { + tree.update(i, 0.0).unwrap(); + weights[i] = 0.0; + total_weight -= i as f64; + } + let mut counts = alloc::vec![0_usize; end]; + for _ in 0..samples { + let i = tree.sample(&mut rng); + counts[i] += 1; + } + for i in 0..start { + assert_eq!(counts[i], 0); + } + for i in start..end { + let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; + assert!(diff.abs() < 0.05); + } + } +} diff --git a/rand_distr/src/zeta.rs b/rand_distr/src/zeta.rs new file mode 100644 index 00000000000..f93f167d7c3 --- /dev/null +++ b/rand_distr/src/zeta.rs @@ -0,0 +1,203 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Zeta distribution. + +use crate::{Distribution, StandardUniform}; +use core::fmt; +use num_traits::Float; +use rand::{distr::OpenClosed01, Rng}; + +/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`. +/// +/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) +/// is a discrete probability distribution with parameter `s`. +/// It is a special case of the [`Zipf`](crate::Zipf) distribution with `n = ∞`. +/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution. +/// +/// # Density function +/// +/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the +/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function). +/// +/// # Plot +/// +/// The following plot illustrates the zeta distribution for various values of `s`. +/// +/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg) +/// +/// # Example +/// ``` +/// use rand::prelude::*; +/// use rand_distr::Zeta; +/// +/// let val: f64 = rand::rng().sample(Zeta::new(1.5).unwrap()); +/// println!("{}", val); +/// ``` +/// +/// # Integer vs FP return type +/// +/// This implementation uses floating-point (FP) logic internally, which can +/// potentially generate very large samples (exceeding e.g. `u64::MAX`). +/// +/// It is *safe* to cast such results to an integer type using `as` +/// (e.g. `distr.sample(&mut rng) as u64`), since such casts are saturating +/// (e.g. `2f64.powi(64) as u64 == u64::MAX`). It is up to the user to +/// determine whether this potential loss of accuracy is acceptable +/// (this determination may depend on the distribution's parameters). +/// +/// # Notes +/// +/// The zeta distribution has no upper limit. Sampled values may be infinite. +/// In particular, a value of infinity might be returned for the following +/// reasons: +/// 1. it is the best representation in the type `F` of the actual sample. +/// 2. to prevent infinite loops for very small `s`. +/// +/// # Implementation details +/// +/// We are using the algorithm from +/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8), +/// Section 6.1, page 551. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Zeta +where + F: Float, + StandardUniform: Distribution, + OpenClosed01: Distribution, +{ + s_minus_1: F, + b: F, +} + +/// Error type returned from [`Zeta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `s <= 1` or `nan`. + STooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::STooSmall => "s <= 1 or is NaN in Zeta distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Zeta +where + F: Float, + StandardUniform: Distribution, + OpenClosed01: Distribution, +{ + /// Construct a new `Zeta` distribution with given `s` parameter. + #[inline] + pub fn new(s: F) -> Result, Error> { + if !(s > F::one()) { + return Err(Error::STooSmall); + } + let s_minus_1 = s - F::one(); + let two = F::one() + F::one(); + Ok(Zeta { + s_minus_1, + b: two.powf(s_minus_1), + }) + } +} + +impl Distribution for Zeta +where + F: Float, + StandardUniform: Distribution, + OpenClosed01: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> F { + loop { + let u = rng.sample(OpenClosed01); + let x = u.powf(-F::one() / self.s_minus_1).floor(); + debug_assert!(x >= F::one()); + if x.is_infinite() { + // For sufficiently small `s`, `x` will always be infinite, + // which is rejected, resulting in an infinite loop. We avoid + // this by always returning infinity instead. + return x; + } + + let t = (F::one() + F::one() / x).powf(self.s_minus_1); + + let v = rng.sample(StandardUniform); + if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { + return x; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_samples>(distr: D, zero: F, expected: &[F]) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + #[test] + #[should_panic] + fn zeta_invalid() { + Zeta::new(1.).unwrap(); + } + + #[test] + #[should_panic] + fn zeta_nan() { + Zeta::new(f64::NAN).unwrap(); + } + + #[test] + fn zeta_sample() { + let a = 2.0; + let d = Zeta::new(a).unwrap(); + let mut rng = crate::test::rng(1); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zeta_small_a() { + let a = 1. + 1e-15; + let d = Zeta::new(a).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zeta_value_stability() { + test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]); + test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn zeta_distributions_can_be_compared() { + assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); + } +} diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs new file mode 100644 index 00000000000..f2e80d37908 --- /dev/null +++ b/rand_distr/src/zipf.rs @@ -0,0 +1,244 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Zipf distribution. + +use crate::{Distribution, StandardUniform}; +use core::fmt; +use num_traits::Float; +use rand::Rng; + +/// The Zipf (Zipfian) distribution `Zipf(n, s)`. +/// +/// The samples follow [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law): +/// The frequency of each sample from a finite set of size `n` is inversely +/// proportional to a power of its frequency rank (with exponent `s`). +/// +/// For large `n`, this converges to the [`Zeta`](crate::Zeta) distribution. +/// +/// For `s = 0`, this becomes a [`uniform`](crate::Uniform) distribution. +/// +/// # Plot +/// +/// The following plot illustrates the Zipf distribution for `n = 10` and +/// various values of `s`. +/// +/// ![Zipf distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zipf.svg) +/// +/// # Example +/// ``` +/// use rand::prelude::*; +/// use rand_distr::Zipf; +/// +/// let val: f64 = rand::rng().sample(Zipf::new(10.0, 1.5).unwrap()); +/// println!("{}", val); +/// ``` +/// +/// # Integer vs FP return type +/// +/// This implementation uses floating-point (FP) logic internally. It may be +/// expected that the samples are no greater than `n`, thus it is reasonable to +/// cast generated samples to any integer type which can also represent `n` +/// (e.g. `distr.sample(&mut rng) as u64`). +/// +/// # Implementation details +/// +/// Implemented via [rejection sampling](https://en.wikipedia.org/wiki/Rejection_sampling), +/// due to Jason Crease[1]. +/// +/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Zipf +where + F: Float, + StandardUniform: Distribution, +{ + s: F, + t: F, + q: F, +} + +/// Error type returned from [`Zipf::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `s < 0` or `nan`. + STooSmall, + /// `n < 1`. + NTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::STooSmall => "s < 0 or is NaN in Zipf distribution", + Error::NTooSmall => "n < 1 in Zipf distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Zipf +where + F: Float, + StandardUniform: Distribution, +{ + /// Construct a new `Zipf` distribution for a set with `n` elements and a + /// frequency rank exponent `s`. + /// + /// The parameter `n` is typically integral, however we use type + ///
F: [Float]
in order to permit very large values + /// and since our implementation requires a floating-point type. + #[inline] + pub fn new(n: F, s: F) -> Result, Error> { + if !(s >= F::zero()) { + return Err(Error::STooSmall); + } + if n < F::one() { + return Err(Error::NTooSmall); + } + let q = if s != F::one() { + // Make sure to calculate the division only once. + F::one() / (F::one() - s) + } else { + // This value is never used. + F::zero() + }; + let t = if s != F::one() { + (n.powf(F::one() - s) - s) * q + } else { + F::one() + n.ln() + }; + debug_assert!(t > F::zero()); + Ok(Zipf { s, t, q }) + } + + /// Inverse cumulative density function + #[inline] + fn inv_cdf(&self, p: F) -> F { + let one = F::one(); + let pt = p * self.t; + if pt <= one { + pt + } else if self.s != one { + (pt * (one - self.s) + self.s).powf(self.q) + } else { + (pt - one).exp() + } + } +} + +impl Distribution for Zipf +where + F: Float, + StandardUniform: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> F { + let one = F::one(); + loop { + let inv_b = self.inv_cdf(rng.sample(StandardUniform)); + let x = (inv_b + one).floor(); + let mut ratio = x.powf(-self.s); + if x > one { + ratio = ratio * inv_b.powf(self.s) + }; + + let y = rng.sample(StandardUniform); + if y < ratio { + return x; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_samples>(distr: D, zero: F, expected: &[F]) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + #[test] + #[should_panic] + fn zipf_s_too_small() { + Zipf::new(10., -1.).unwrap(); + } + + #[test] + #[should_panic] + fn zipf_n_too_small() { + Zipf::new(0., 1.).unwrap(); + } + + #[test] + #[should_panic] + fn zipf_nan() { + Zipf::new(10., f64::NAN).unwrap(); + } + + #[test] + fn zipf_sample() { + let d = Zipf::new(10., 0.5).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zipf_sample_s_1() { + let d = Zipf::new(10., 1.).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zipf_sample_s_0() { + let d = Zipf::new(10., 0.).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + // TODO: verify that this is a uniform distribution + } + + #[test] + fn zipf_sample_large_n() { + let d = Zipf::new(f64::MAX, 1.5).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + // TODO: verify that this is a zeta distribution + } + + #[test] + fn zipf_value_stability() { + test_samples(Zipf::new(10., 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]); + test_samples(Zipf::new(10., 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]); + } + + #[test] + fn zipf_distributions_can_be_compared() { + assert_eq!(Zipf::new(1.0, 2.0), Zipf::new(1.0, 2.0)); + } +} diff --git a/rand_distr/tests/uniformity.rs b/rand_distr/tests/uniformity.rs deleted file mode 100644 index 7d359c7d733..00000000000 --- a/rand_distr/tests/uniformity.rs +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use average::Histogram; -use rand::prelude::*; - -const N_BINS: usize = 100; -const N_SAMPLES: u32 = 1_000_000; -const TOL: f64 = 1e-3; -average::define_histogram!(hist, 100); -use hist::Histogram as Histogram100; - -#[test] -fn unit_sphere() { - const N_DIM: usize = 3; - let h = Histogram100::with_const_width(-1., 1.); - let mut histograms = [h.clone(), h.clone(), h]; - let dist = rand_distr::UnitSphere; - let mut rng = rand_pcg::Pcg32::from_entropy(); - for _ in 0..N_SAMPLES { - let v: [f64; 3] = dist.sample(&mut rng); - for i in 0..N_DIM { - histograms[i] - .add(v[i]) - .map_err(|e| { - println!("v: {}", v[i]); - e - }) - .unwrap(); - } - } - for h in &histograms { - let sum: u64 = h.bins().iter().sum(); - println!("{:?}", h); - for &b in h.bins() { - let p = (b as f64) / (sum as f64); - assert!((p - 1.0 / (N_BINS as f64)).abs() < TOL, "{}", p); - } - } -} - -#[test] -fn unit_circle() { - use std::f64::consts::PI; - let mut h = Histogram100::with_const_width(-PI, PI); - let dist = rand_distr::UnitCircle; - let mut rng = rand_pcg::Pcg32::from_entropy(); - for _ in 0..N_SAMPLES { - let v: [f64; 2] = dist.sample(&mut rng); - h.add(v[0].atan2(v[1])).unwrap(); - } - let sum: u64 = h.bins().iter().sum(); - println!("{:?}", h); - for &b in h.bins() { - let p = (b as f64) / (sum as f64); - assert!((p - 1.0 / (N_BINS as f64)).abs() < TOL, "{}", p); - } -} diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index 65c49644a41..330119b68f6 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -11,7 +11,7 @@ use core::fmt::Debug; use rand::Rng; use rand_distr::*; -fn get_rng(seed: u64) -> impl rand::Rng { +fn get_rng(seed: u64) -> impl Rng { // For tests, we want a statistically good, fast, reproducible RNG. // PCG32 will do fine, and will be easy to embed if we ever need to. const INC: u64 = 11634580027462260723; @@ -53,9 +53,7 @@ impl ApproxEq for [T; 3] { } } -fn test_samples>( - seed: u64, distr: D, expected: &[F], -) { +fn test_samples>(seed: u64, distr: D, expected: &[F]) { let mut rng = get_rng(seed); for val in expected { let x = rng.sample(&distr); @@ -64,283 +62,439 @@ fn test_samples>( } #[test] -fn binominal_stability() { +fn binomial_stability() { // We have multiple code paths: np < 10, p > 0.5 test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); - test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]); + test_samples( + 353, + Binomial::new(2000, 0.6).unwrap(), + &[1194, 1208, 1192, 1210], + ); } #[test] fn geometric_stability() { test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]); - + test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]); - test_samples(464, Geometric::new(0.05).unwrap(), &[24, 51, 81, 67, 27, 11, 7, 6]); - test_samples(464, Geometric::new(0.95).unwrap(), &[0, 0, 0, 0, 1, 0, 0, 0]); + test_samples( + 464, + Geometric::new(0.05).unwrap(), + &[24, 51, 81, 67, 27, 11, 7, 6], + ); + test_samples( + 464, + Geometric::new(0.95).unwrap(), + &[0, 0, 0, 0, 1, 0, 0, 0], + ); // expect non-random behaviour for series of pre-determined trials - test_samples(464, Geometric::new(0.0).unwrap(), &[u64::max_value(); 100][..]); + test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]); test_samples(464, Geometric::new(1.0).unwrap(), &[0; 100][..]); } #[test] fn hypergeometric_stability() { // We have multiple code paths based on the distribution's mode and sample_size - test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN - test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE + test_samples( + 7221, + Hypergeometric::new(99, 33, 8).unwrap(), + &[4, 3, 2, 2, 3, 2, 3, 1], + ); // Algorithm HIN + test_samples( + 7221, + Hypergeometric::new(100, 50, 50).unwrap(), + &[23, 27, 26, 27, 22, 25, 31, 25], + ); // Algorithm H2PE } #[test] fn unit_ball_stability() { - test_samples(2, UnitBall, &[ - [0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706], - [0.10588569388223945, -0.4734350111375454, -0.7392104908825501], - [0.11060237642041049, -0.16065642822852677, -0.8444043930440075] - ]); + test_samples( + 2, + UnitBall, + &[ + [ + 0.018035709265959987f64, + -0.4348771383120438, + -0.07982762085055706, + ], + [ + 0.10588569388223945, + -0.4734350111375454, + -0.7392104908825501, + ], + [ + 0.11060237642041049, + -0.16065642822852677, + -0.8444043930440075, + ], + ], + ); } #[test] fn unit_circle_stability() { - test_samples(2, UnitCircle, &[ - [-0.9965658683520504f64, -0.08280380447614634], - [-0.9790853270389644, -0.20345004884984505], - [-0.8449189758898707, 0.5348943112253227], - ]); + test_samples( + 2, + UnitCircle, + &[ + [-0.9965658683520504f64, -0.08280380447614634], + [-0.9790853270389644, -0.20345004884984505], + [-0.8449189758898707, 0.5348943112253227], + ], + ); } #[test] fn unit_sphere_stability() { - test_samples(2, UnitSphere, &[ - [0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027], - [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], - [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], - ]); + test_samples( + 2, + UnitSphere, + &[ + [ + 0.03247542860231647f64, + -0.7830477442152738, + 0.6211131755296027, + ], + [ + -0.09978440840914075, + 0.9706650829833128, + -0.21875184231323952, + ], + [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], + ], + ); } #[test] fn unit_disc_stability() { - test_samples(2, UnitDisc, &[ - [0.018035709265959987f64, -0.4348771383120438], - [-0.07982762085055706, 0.7765329819820659], - [0.21450745997299503, 0.7398636984333291], - ]); + test_samples( + 2, + UnitDisc, + &[ + [0.018035709265959987f64, -0.4348771383120438], + [-0.07982762085055706, 0.7765329819820659], + [0.21450745997299503, 0.7398636984333291], + ], + ); } #[test] fn pareto_stability() { - test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[ - 1.0423688f32, 2.1235929, 4.132709, 1.4679428, - ]); - test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[ - 9.019295276219136f64, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ]); + test_samples( + 213, + Pareto::new(1.0, 1.0).unwrap(), + &[1.0423688f32, 2.1235929, 4.132709, 1.4679428], + ); + test_samples( + 213, + Pareto::new(2.0, 0.5).unwrap(), + &[ + 9.019295276219136f64, + 4.3097126018270595, + 6.837815045397157, + 105.8826669383772, + ], + ); } #[test] fn poisson_stability() { test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); - test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]); + test_samples( + 223, + Poisson::new(27.0).unwrap(), + &[28.0f32, 32.0, 36.0, 36.0], + ); } - #[test] fn triangular_stability() { - test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[ - 5.74373257511361f64, - 7.890059162791258f64, - 4.7256280652553455f64, - 2.9474808121184077f64, - 3.058301946314053f64, - ]); + test_samples( + 860, + Triangular::new(2., 10., 3.).unwrap(), + &[ + 5.74373257511361f64, + 7.890059162791258f64, + 4.7256280652553455f64, + 2.9474808121184077f64, + 3.058301946314053f64, + ], + ); } - #[test] fn normal_inverse_gaussian_stability() { - test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6568966f32, 1.3744819, 2.216063, 0.11488572, - ]); - test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ - 0.6838707059642927f64, - 2.4447306460569784, - 0.2361045023235968, - 1.7774534624785319, - ]); + test_samples( + 213, + NormalInverseGaussian::new(2.0, 1.0).unwrap(), + &[0.6568966f32, 1.3744819, 2.216063, 0.11488572], + ); + test_samples( + 213, + NormalInverseGaussian::new(2.0, 1.0).unwrap(), + &[ + 0.6838707059642927f64, + 2.4447306460569784, + 0.2361045023235968, + 1.7774534624785319, + ], + ); } #[test] fn pert_stability() { // mean = 4, var = 12/7 - test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[ - 4.908681667460367, - 4.014196196158352, - 2.6489397149197234, - 3.4569780580044727, - 4.242864311947118, - ]); + test_samples( + 860, + Pert::new(2., 10.).with_mode(3.).unwrap(), + &[ + 4.908681667460367, + 4.014196196158352, + 2.6489397149197234, + 3.4569780580044727, + 4.242864311947118, + ], + ); } #[test] fn inverse_gaussian_stability() { - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[ - 0.9339157f32, 1.108113, 0.50864697, 0.39849377, - ]); - test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ - 1.0707604954722476f64, - 0.9628140605340697, - 0.4069687656468226, - 0.660283852985818, - ]); + test_samples( + 213, + InverseGaussian::new(1.0, 3.0).unwrap(), + &[0.9339157f32, 1.108113, 0.50864697, 0.39849377], + ); + test_samples( + 213, + InverseGaussian::new(1.0, 3.0).unwrap(), + &[ + 1.0707604954722476f64, + 0.9628140605340697, + 0.4069687656468226, + 0.660283852985818, + ], + ); } #[test] fn gamma_stability() { // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 - test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[ - 5.398085f32, 9.162783, 0.2300583, 1.7235851, - ]); - test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[ - 0.5051203f32, 0.9048302, 3.095812, 1.8566116, - ]); - test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[ - 7.783878094584059f64, - 1.4939528171618057, - 8.638017638857592, - 3.0949337228829004, - ]); + test_samples( + 223, + Gamma::new(1.0, 5.0).unwrap(), + &[5.398085f32, 9.162783, 0.2300583, 1.7235851], + ); + test_samples( + 223, + Gamma::new(0.8, 5.0).unwrap(), + &[0.5051203f32, 0.9048302, 3.095812, 1.8566116], + ); + test_samples( + 223, + Gamma::new(1.1, 5.0).unwrap(), + &[ + 7.783878094584059f64, + 1.4939528171618057, + 8.638017638857592, + 3.0949337228829004, + ], + ); // ChiSquared has 2 cases: k == 1, k != 1 - test_samples(223, ChiSquared::new(1.0).unwrap(), &[ - 0.4893526200348249f64, - 1.635249736808788, - 0.5013580219361969, - 0.1457735613733489, - ]); - test_samples(223, ChiSquared::new(0.1).unwrap(), &[ - 0.014824404726978617f64, - 0.021602123937134326, - 0.0000003431429746851693, - 0.00000002291755769542258, - ]); - test_samples(223, ChiSquared::new(10.0).unwrap(), &[ - 12.693656f32, 6.812016, 11.082001, 12.436167, - ]); + test_samples( + 223, + ChiSquared::new(1.0).unwrap(), + &[ + 0.4893526200348249f64, + 1.635249736808788, + 0.5013580219361969, + 0.1457735613733489, + ], + ); + test_samples( + 223, + ChiSquared::new(0.1).unwrap(), + &[ + 0.014824404726978617f64, + 0.021602123937134326, + 0.0000003431429746851693, + 0.00000002291755769542258, + ], + ); + test_samples( + 223, + ChiSquared::new(10.0).unwrap(), + &[12.693656f32, 6.812016, 11.082001, 12.436167], + ); // FisherF has same special cases as ChiSquared on each param - test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[ - 0.32283646f32, 0.048049655, 0.0788893, 1.817178, - ]); - test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[ - 0.29925257f32, 3.4392934, 9.567652, 0.020074, - ]); - test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[ - 3.3196593155045124f64, - 0.3409169916262829, - 0.03377989856426519, - 0.00004041672861036937, - ]); + test_samples( + 223, + FisherF::new(1.0, 13.5).unwrap(), + &[0.32283646f32, 0.048049655, 0.0788893, 1.817178], + ); + test_samples( + 223, + FisherF::new(1.0, 1.0).unwrap(), + &[0.29925257f32, 3.4392934, 9.567652, 0.020074], + ); + test_samples( + 223, + FisherF::new(0.7, 13.5).unwrap(), + &[ + 3.3196593155045124f64, + 0.3409169916262829, + 0.03377989856426519, + 0.00004041672861036937, + ], + ); // StudentT has same special cases as ChiSquared - test_samples(223, StudentT::new(1.0).unwrap(), &[ - 0.54703987f32, -1.8545331, 3.093162, -0.14168274, - ]); - test_samples(223, StudentT::new(1.1).unwrap(), &[ - 0.7729195887949754f64, - 1.2606210611616204, - -1.7553606501113175, - -2.377641221169782, - ]); + test_samples( + 223, + StudentT::new(1.0).unwrap(), + &[0.54703987f32, -1.8545331, 3.093162, -0.14168274], + ); + test_samples( + 223, + StudentT::new(1.1).unwrap(), + &[ + 0.7729195887949754f64, + 1.2606210611616204, + -1.7553606501113175, + -2.377641221169782, + ], + ); // Beta has two special cases: // // 1. min(alpha, beta) <= 1 // 2. min(alpha, beta) > 1 - test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[ - 0.8300703726659456, - 0.8134131062097899, - 0.47912589330631555, - 0.25323238071138526, - ]); - test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[ - 0.49563509121756827, - 0.9551305482256759, - 0.5151181353461637, - 0.7551732971235077, - ]); + test_samples( + 223, + Beta::new(1.0, 0.8).unwrap(), + &[ + 0.8300703726659456, + 0.8134131062097899, + 0.47912589330631555, + 0.25323238071138526, + ], + ); + test_samples( + 223, + Beta::new(3.0, 1.2).unwrap(), + &[ + 0.49563509121756827, + 0.9551305482256759, + 0.5151181353461637, + 0.7551732971235077, + ], + ); } #[test] fn exponential_stability() { - test_samples(223, Exp1, &[ - 1.079617f32, 1.8325565, 0.04601166, 0.34471703, - ]); - test_samples(223, Exp1, &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); - - test_samples(223, Exp::new(2.0).unwrap(), &[ - 0.5398085f32, 0.91627824, 0.02300583, 0.17235851, - ]); - test_samples(223, Exp::new(1.0).unwrap(), &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ]); + test_samples(223, Exp1, &[1.079617f32, 1.8325565, 0.04601166, 0.34471703]); + test_samples( + 223, + Exp1, + &[ + 1.0796170642388276f64, + 1.8325565304274, + 0.04601166186842716, + 0.3447170217100157, + ], + ); + + test_samples( + 223, + Exp::new(2.0).unwrap(), + &[0.5398085f32, 0.91627824, 0.02300583, 0.17235851], + ); + test_samples( + 223, + Exp::new(1.0).unwrap(), + &[ + 1.0796170642388276f64, + 1.8325565304274, + 0.04601166186842716, + 0.3447170217100157, + ], + ); } #[test] fn normal_stability() { - test_samples(213, StandardNormal, &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, - ]); - test_samples(213, StandardNormal, &[ - -0.11844188827977231f64, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ]); - - test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[ - -0.11844189f32, 0.781378, 0.06563994, -1.1932899, - ]); - test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[ - 1.940779055860114f64, - 2.3906889818886174, - 2.0328199698479, - 1.4033550497906813, - ]); - - test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[ - 0.88830346f32, 2.1844804, 1.0678421, 0.30322206, - ]); - test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[ - 6.964174338639032f64, - 10.921015733601452, - 7.6355881556915906, - 4.068828213584092, - ]); + test_samples( + 213, + StandardNormal, + &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899], + ); + test_samples( + 213, + StandardNormal, + &[ + -0.11844188827977231f64, + 0.7813779637772346, + 0.06563993969580051, + -1.1932899004186373, + ], + ); + + test_samples( + 213, + Normal::new(0.0, 1.0).unwrap(), + &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899], + ); + test_samples( + 213, + Normal::new(2.0, 0.5).unwrap(), + &[ + 1.940779055860114f64, + 2.3906889818886174, + 2.0328199698479, + 1.4033550497906813, + ], + ); + + test_samples( + 213, + LogNormal::new(0.0, 1.0).unwrap(), + &[0.88830346f32, 2.1844804, 1.0678421, 0.30322206], + ); + test_samples( + 213, + LogNormal::new(2.0, 0.5).unwrap(), + &[ + 6.964174338639032f64, + 10.921015733601452, + 7.6355881556915906, + 4.068828213584092, + ], + ); } #[test] fn weibull_stability() { - test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[ - 0.041495778f32, 0.7531094, 1.4189332, 0.38386202, - ]); - test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[ - 1.1343478702739669f64, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ]); + test_samples( + 213, + Weibull::new(1.0, 1.0).unwrap(), + &[0.041495778f32, 0.7531094, 1.4189332, 0.38386202], + ); + test_samples( + 213, + Weibull::new(2.0, 0.5).unwrap(), + &[ + 1.1343478702739669f64, + 0.29470010050655226, + 0.7556151370284702, + 7.877212340241561, + ], + ); } #[cfg(feature = "alloc")] @@ -348,26 +502,43 @@ fn weibull_stability() { fn dirichlet_stability() { let mut rng = get_rng(223); assert_eq!( - rng.sample(Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap()), - vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146] - ); - assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()), vec![ - 0.17684200044809556, - 0.29915953935953055, - 0.1832858056608014, - 0.1425623503573967, - 0.19815030417417595 - ]); + rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), + [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] + ); + assert_eq!( + rng.sample(Dirichlet::new([8.0; 5]).unwrap()), + [ + 0.17684200044809556, + 0.29915953935953055, + 0.1832858056608014, + 0.1425623503573967, + 0.19815030417417595 + ] + ); + // Test stability for the case where all alphas are less than 0.1. + assert_eq!( + rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), + [ + 0.00027580456855692104, + 2.296135759821706e-20, + 3.004118281150937e-9, + 0.9997241924273248 + ] + ); } #[test] fn cauchy_stability() { - test_samples(353, Cauchy::new(100f64, 10.0).unwrap(), &[ - 77.93369152808678f64, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925, - ]); + test_samples( + 353, + Cauchy::new(100f64, 10.0).unwrap(), + &[ + 77.93369152808678f64, + 90.1606912098641, + 125.31516221323625, + 86.10217834773925, + ], + ); // Unfortunately this test is not fully portable due to reliance on the // system's implementation of tanf (see doc on Cauchy struct). @@ -376,7 +547,7 @@ fn cauchy_stability() { let mut rng = get_rng(353); let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; for &a in expected.iter() { - let b = rng.sample(&distr); + let b = rng.sample(distr); assert_almost_eq!(a, b, 1e-5); } } diff --git a/rand_hc/CHANGELOG.md b/rand_hc/CHANGELOG.md deleted file mode 100644 index ad9fe4dfb40..00000000000 --- a/rand_hc/CHANGELOG.md +++ /dev/null @@ -1,22 +0,0 @@ -# Changelog -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [0.3.0] - 2020-12-08 -- Bump `rand_core` version to 0.6.0 -- Bump MSRV to 1.36 (#1011) -- impl PartialEq+Eq for Hc128Rng and Hc128Core (#979) -- Drop some unsafe code, fixing an unsound internal function (#960) - -## [0.2.0] - 2019-06-12 -- Bump minor crate version since rand_core bump is a breaking change -- Switch to Edition 2018 - -## [0.1.1] - 2019-06-06 - yanked -- Bump `rand_core` version -- Adjust usage of `#[inline]` - -## [0.1.0] - 2018-10-17 -- Pulled out of the Rand crate diff --git a/rand_hc/COPYRIGHT b/rand_hc/COPYRIGHT deleted file mode 100644 index 468d907caf9..00000000000 --- a/rand_hc/COPYRIGHT +++ /dev/null @@ -1,12 +0,0 @@ -Copyrights in the Rand project are retained by their contributors. No -copyright assignment is required to contribute to the Rand project. - -For full authorship information, see the version control history. - -Except as otherwise noted (below and/or in individual files), Rand is -licensed under the Apache License, Version 2.0 or - or the MIT license - or , at your option. - -The Rand project includes code from the Rust project -published under these same licenses. diff --git a/rand_hc/Cargo.toml b/rand_hc/Cargo.toml deleted file mode 100644 index 403f9f0fb1f..00000000000 --- a/rand_hc/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "rand_hc" -version = "0.3.0" -authors = ["The Rand Project Developers"] -license = "MIT OR Apache-2.0" -readme = "README.md" -repository = "https://github.com/rust-random/rand" -documentation = "https://docs.rs/rand_hc" -homepage = "https://rust-random.github.io/book" -description = """ -HC128 random number generator -""" -keywords = ["random", "rng", "hc128"] -categories = ["algorithms", "no-std"] -edition = "2018" - -[dependencies] -rand_core = { path = "../rand_core", version = "0.6.0" } diff --git a/rand_hc/LICENSE-APACHE b/rand_hc/LICENSE-APACHE deleted file mode 100644 index 17d74680f8c..00000000000 --- a/rand_hc/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - https://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rand_hc/LICENSE-MIT b/rand_hc/LICENSE-MIT deleted file mode 100644 index cf656074cbf..00000000000 --- a/rand_hc/LICENSE-MIT +++ /dev/null @@ -1,25 +0,0 @@ -Copyright 2018 Developers of the Rand project - -Permission is hereby granted, free of charge, to any -person obtaining a copy of this software and associated -documentation files (the "Software"), to deal in the -Software without restriction, including without -limitation the rights to use, copy, modify, merge, -publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software -is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice -shall be included in all copies or substantial portions -of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/rand_hc/README.md b/rand_hc/README.md deleted file mode 100644 index 87f1c8915c1..00000000000 --- a/rand_hc/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# rand_hc - -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) -[![Latest version](https://img.shields.io/crates/v/rand_hc.svg)](https://crates.io/crates/rand_hc) -[[![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) -[![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_hc) -[![API](https://docs.rs/rand_hc/badge.svg)](https://docs.rs/rand_hc) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) - -A cryptographically secure random number generator that uses the HC-128 -algorithm. - -HC-128 is a stream cipher designed by Hongjun Wu[^1], that we use as an -RNG. It is selected as one of the "stream ciphers suitable for widespread -adoption" by eSTREAM[^2]. - -Links: - -- [API documentation (master)](https://rust-random.github.io/rand/rand_hc) -- [API documentation (docs.rs)](https://docs.rs/rand_hc) -- [Changelog](https://github.com/rust-random/rand/blob/master/rand_hc/CHANGELOG.md) - -[rand]: https://crates.io/crates/rand -[^1]: Hongjun Wu (2008). ["The Stream Cipher HC-128"]( - http://www.ecrypt.eu.org/stream/p3ciphers/hc/hc128_p3.pdf). - *The eSTREAM Finalists*, LNCS 4986, pp. 39–47, Springer-Verlag. - -[^2]: [eSTREAM: the ECRYPT Stream Cipher Project]( - http://www.ecrypt.eu.org/stream/) - - -## Crate Features - -`rand_hc` is `no_std` compatible. It does not require any functionality -outside of the `core` lib, thus there are no features to configure. - - -# License - -`rand_hc` is distributed under the terms of both the MIT license and the -Apache License (Version 2.0). - -See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and -[COPYRIGHT](COPYRIGHT) for details. diff --git a/rand_hc/src/hc128.rs b/rand_hc/src/hc128.rs deleted file mode 100644 index 94d75778f75..00000000000 --- a/rand_hc/src/hc128.rs +++ /dev/null @@ -1,513 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The HC-128 random number generator. - -use core::fmt; -use rand_core::block::{BlockRng, BlockRngCore}; -use rand_core::{le, CryptoRng, Error, RngCore, SeedableRng}; - -const SEED_WORDS: usize = 8; // 128 bit key followed by 128 bit iv - -/// A cryptographically secure random number generator that uses the HC-128 -/// algorithm. -/// -/// HC-128 is a stream cipher designed by Hongjun Wu[^1], that we use as an -/// RNG. It is selected as one of the "stream ciphers suitable for widespread -/// adoption" by eSTREAM[^2]. -/// -/// HC-128 is an array based RNG. In this it is similar to RC-4 and ISAAC before -/// it, but those have never been proven cryptographically secure (or have even -/// been significantly compromised, as in the case of RC-4[^5]). -/// -/// Because HC-128 works with simple indexing into a large array and with a few -/// operations that parallelize well, it has very good performance. The size of -/// the array it needs, 4kb, can however be a disadvantage. -/// -/// This implementation is not based on the version of HC-128 submitted to the -/// eSTREAM contest, but on a later version by the author with a few small -/// improvements from December 15, 2009[^3]. -/// -/// HC-128 has no known weaknesses that are easier to exploit than doing a -/// brute-force search of 2128. A very comprehensive analysis of the -/// current state of known attacks / weaknesses of HC-128 is given in *Some -/// Results On Analysis And Implementation Of HC-128 Stream Cipher*[^4]. -/// -/// The average cycle length is expected to be -/// 21024*32+10-1 = 232777. -/// We support seeding with a 256-bit array, which matches the 128-bit key -/// concatenated with a 128-bit IV from the stream cipher. -/// -/// This implementation uses an output buffer of sixteen `u32` words, and uses -/// [`BlockRng`] to implement the [`RngCore`] methods. -/// -/// ## References -/// [^1]: Hongjun Wu (2008). ["The Stream Cipher HC-128"]( -/// http://www.ecrypt.eu.org/stream/p3ciphers/hc/hc128_p3.pdf). -/// *The eSTREAM Finalists*, LNCS 4986, pp. 39–47, Springer-Verlag. -/// -/// [^2]: [eSTREAM: the ECRYPT Stream Cipher Project]( -/// http://www.ecrypt.eu.org/stream/) -/// -/// [^3]: Hongjun Wu, [Stream Ciphers HC-128 and HC-256]( -/// https://www.ntu.edu.sg/home/wuhj/research/hc/index.html) -/// -/// [^4]: Shashwat Raizada (January 2015),["Some Results On Analysis And -/// Implementation Of HC-128 Stream Cipher"]( -/// http://library.isical.ac.in:8080/jspui/bitstream/123456789/6636/1/TH431.pdf). -/// -/// [^5]: Internet Engineering Task Force (February 2015), -/// ["Prohibiting RC4 Cipher Suites"](https://tools.ietf.org/html/rfc7465). -#[derive(Clone, Debug)] -pub struct Hc128Rng(BlockRng); - -impl RngCore for Hc128Rng { - #[inline] - fn next_u32(&mut self) -> u32 { - self.0.next_u32() - } - - #[inline] - fn next_u64(&mut self) -> u64 { - self.0.next_u64() - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest) - } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } -} - -impl SeedableRng for Hc128Rng { - type Seed = ::Seed; - - #[inline] - fn from_seed(seed: Self::Seed) -> Self { - Hc128Rng(BlockRng::::from_seed(seed)) - } - - #[inline] - fn from_rng(rng: R) -> Result { - BlockRng::::from_rng(rng).map(Hc128Rng) - } -} - -impl CryptoRng for Hc128Rng {} - -impl PartialEq for Hc128Rng { - fn eq(&self, rhs: &Self) -> bool { - self.0.core == rhs.0.core && self.0.index() == rhs.0.index() - } -} -impl Eq for Hc128Rng {} - -/// The core of `Hc128Rng`, used with `BlockRng`. -#[derive(Clone)] -pub struct Hc128Core { - t: [u32; 1024], - counter1024: usize, -} - -// Custom Debug implementation that does not expose the internal state -impl fmt::Debug for Hc128Core { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Hc128Core {{}}") - } -} - -impl BlockRngCore for Hc128Core { - type Item = u32; - type Results = [u32; 16]; - - fn generate(&mut self, results: &mut Self::Results) { - assert!(self.counter1024 % 16 == 0); - - let cc = self.counter1024 % 512; - let dd = (cc + 16) % 512; - let ee = cc.wrapping_sub(16) % 512; - // These asserts let the compiler optimize out the bounds checks. - // Some of them may be superflous, and that's fine: - // they'll be optimized out if that's the case. - assert!(ee + 15 < 512); - assert!(cc + 15 < 512); - assert!(dd < 512); - - if self.counter1024 & 512 == 0 { - // P block - results[0] = self.step_p(cc+0, cc+1, ee+13, ee+6, ee+4); - results[1] = self.step_p(cc+1, cc+2, ee+14, ee+7, ee+5); - results[2] = self.step_p(cc+2, cc+3, ee+15, ee+8, ee+6); - results[3] = self.step_p(cc+3, cc+4, cc+0, ee+9, ee+7); - results[4] = self.step_p(cc+4, cc+5, cc+1, ee+10, ee+8); - results[5] = self.step_p(cc+5, cc+6, cc+2, ee+11, ee+9); - results[6] = self.step_p(cc+6, cc+7, cc+3, ee+12, ee+10); - results[7] = self.step_p(cc+7, cc+8, cc+4, ee+13, ee+11); - results[8] = self.step_p(cc+8, cc+9, cc+5, ee+14, ee+12); - results[9] = self.step_p(cc+9, cc+10, cc+6, ee+15, ee+13); - results[10] = self.step_p(cc+10, cc+11, cc+7, cc+0, ee+14); - results[11] = self.step_p(cc+11, cc+12, cc+8, cc+1, ee+15); - results[12] = self.step_p(cc+12, cc+13, cc+9, cc+2, cc+0); - results[13] = self.step_p(cc+13, cc+14, cc+10, cc+3, cc+1); - results[14] = self.step_p(cc+14, cc+15, cc+11, cc+4, cc+2); - results[15] = self.step_p(cc+15, dd+0, cc+12, cc+5, cc+3); - } else { - // Q block - results[0] = self.step_q(cc+0, cc+1, ee+13, ee+6, ee+4); - results[1] = self.step_q(cc+1, cc+2, ee+14, ee+7, ee+5); - results[2] = self.step_q(cc+2, cc+3, ee+15, ee+8, ee+6); - results[3] = self.step_q(cc+3, cc+4, cc+0, ee+9, ee+7); - results[4] = self.step_q(cc+4, cc+5, cc+1, ee+10, ee+8); - results[5] = self.step_q(cc+5, cc+6, cc+2, ee+11, ee+9); - results[6] = self.step_q(cc+6, cc+7, cc+3, ee+12, ee+10); - results[7] = self.step_q(cc+7, cc+8, cc+4, ee+13, ee+11); - results[8] = self.step_q(cc+8, cc+9, cc+5, ee+14, ee+12); - results[9] = self.step_q(cc+9, cc+10, cc+6, ee+15, ee+13); - results[10] = self.step_q(cc+10, cc+11, cc+7, cc+0, ee+14); - results[11] = self.step_q(cc+11, cc+12, cc+8, cc+1, ee+15); - results[12] = self.step_q(cc+12, cc+13, cc+9, cc+2, cc+0); - results[13] = self.step_q(cc+13, cc+14, cc+10, cc+3, cc+1); - results[14] = self.step_q(cc+14, cc+15, cc+11, cc+4, cc+2); - results[15] = self.step_q(cc+15, dd+0, cc+12, cc+5, cc+3); - } - self.counter1024 = self.counter1024.wrapping_add(16); - } -} - -impl Hc128Core { - // One step of HC-128, update P and generate 32 bits keystream - #[inline(always)] - fn step_p(&mut self, i: usize, i511: usize, i3: usize, i10: usize, i12: usize) -> u32 { - let (p, q) = self.t.split_at_mut(512); - let temp0 = p[i511].rotate_right(23); - let temp1 = p[i3].rotate_right(10); - let temp2 = p[i10].rotate_right(8); - p[i] = p[i] - .wrapping_add(temp2) - .wrapping_add(temp0 ^ temp1); - let temp3 = { - // The h1 function in HC-128 - let a = p[i12] as u8; - let c = (p[i12] >> 16) as u8; - q[a as usize].wrapping_add(q[256 + c as usize]) - }; - temp3 ^ p[i] - } - - // One step of HC-128, update Q and generate 32 bits keystream - // Similar to `step_p`, but `p` and `q` are swapped, and the rotates are to - // the left instead of to the right. - #[inline(always)] - fn step_q(&mut self, i: usize, i511: usize, i3: usize, i10: usize, i12: usize) -> u32 { - let (p, q) = self.t.split_at_mut(512); - let temp0 = q[i511].rotate_left(23); - let temp1 = q[i3].rotate_left(10); - let temp2 = q[i10].rotate_left(8); - q[i] = q - [i] - .wrapping_add(temp2) - .wrapping_add(temp0 ^ temp1); - let temp3 = { - // The h2 function in HC-128 - let a = q[i12] as u8; - let c = (q[i12] >> 16) as u8; - p[a as usize].wrapping_add(p[256 + c as usize]) - }; - temp3 ^ q[i] - } - - fn sixteen_steps(&mut self) { - assert!(self.counter1024 % 16 == 0); - - let cc = self.counter1024 % 512; - let dd = (cc + 16) % 512; - let ee = cc.wrapping_sub(16) % 512; - // These asserts let the compiler optimize out the bounds checks. - // Some of them may be superflous, and that's fine: - // they'll be optimized out if that's the case. - assert!(ee + 15 < 512); - assert!(cc + 15 < 512); - assert!(dd < 512); - - if self.counter1024 < 512 { - // P block - self.t[cc+0] = self.step_p(cc+0, cc+1, ee+13, ee+6, ee+4); - self.t[cc+1] = self.step_p(cc+1, cc+2, ee+14, ee+7, ee+5); - self.t[cc+2] = self.step_p(cc+2, cc+3, ee+15, ee+8, ee+6); - self.t[cc+3] = self.step_p(cc+3, cc+4, cc+0, ee+9, ee+7); - self.t[cc+4] = self.step_p(cc+4, cc+5, cc+1, ee+10, ee+8); - self.t[cc+5] = self.step_p(cc+5, cc+6, cc+2, ee+11, ee+9); - self.t[cc+6] = self.step_p(cc+6, cc+7, cc+3, ee+12, ee+10); - self.t[cc+7] = self.step_p(cc+7, cc+8, cc+4, ee+13, ee+11); - self.t[cc+8] = self.step_p(cc+8, cc+9, cc+5, ee+14, ee+12); - self.t[cc+9] = self.step_p(cc+9, cc+10, cc+6, ee+15, ee+13); - self.t[cc+10] = self.step_p(cc+10, cc+11, cc+7, cc+0, ee+14); - self.t[cc+11] = self.step_p(cc+11, cc+12, cc+8, cc+1, ee+15); - self.t[cc+12] = self.step_p(cc+12, cc+13, cc+9, cc+2, cc+0); - self.t[cc+13] = self.step_p(cc+13, cc+14, cc+10, cc+3, cc+1); - self.t[cc+14] = self.step_p(cc+14, cc+15, cc+11, cc+4, cc+2); - self.t[cc+15] = self.step_p(cc+15, dd+0, cc+12, cc+5, cc+3); - } else { - // Q block - self.t[cc+512+0] = self.step_q(cc+0, cc+1, ee+13, ee+6, ee+4); - self.t[cc+512+1] = self.step_q(cc+1, cc+2, ee+14, ee+7, ee+5); - self.t[cc+512+2] = self.step_q(cc+2, cc+3, ee+15, ee+8, ee+6); - self.t[cc+512+3] = self.step_q(cc+3, cc+4, cc+0, ee+9, ee+7); - self.t[cc+512+4] = self.step_q(cc+4, cc+5, cc+1, ee+10, ee+8); - self.t[cc+512+5] = self.step_q(cc+5, cc+6, cc+2, ee+11, ee+9); - self.t[cc+512+6] = self.step_q(cc+6, cc+7, cc+3, ee+12, ee+10); - self.t[cc+512+7] = self.step_q(cc+7, cc+8, cc+4, ee+13, ee+11); - self.t[cc+512+8] = self.step_q(cc+8, cc+9, cc+5, ee+14, ee+12); - self.t[cc+512+9] = self.step_q(cc+9, cc+10, cc+6, ee+15, ee+13); - self.t[cc+512+10] = self.step_q(cc+10, cc+11, cc+7, cc+0, ee+14); - self.t[cc+512+11] = self.step_q(cc+11, cc+12, cc+8, cc+1, ee+15); - self.t[cc+512+12] = self.step_q(cc+12, cc+13, cc+9, cc+2, cc+0); - self.t[cc+512+13] = self.step_q(cc+13, cc+14, cc+10, cc+3, cc+1); - self.t[cc+512+14] = self.step_q(cc+14, cc+15, cc+11, cc+4, cc+2); - self.t[cc+512+15] = self.step_q(cc+15, dd+0, cc+12, cc+5, cc+3); - } - self.counter1024 += 16; - } - - // Initialize an HC-128 random number generator. The seed has to be - // 256 bits in length (`[u32; 8]`), matching the 128 bit `key` followed by - // 128 bit `iv` when HC-128 where to be used as a stream cipher. - #[inline(always)] // single use: SeedableRng::from_seed - fn init(seed: [u32; SEED_WORDS]) -> Self { - #[inline] - fn f1(x: u32) -> u32 { - x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) - } - - #[inline] - fn f2(x: u32) -> u32 { - x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) - } - - let mut t = [0u32; 1024]; - - // Expand the key and iv into P and Q - let (key, iv) = seed.split_at(4); - t[..4].copy_from_slice(key); - t[4..8].copy_from_slice(key); - t[8..12].copy_from_slice(iv); - t[12..16].copy_from_slice(iv); - - // Generate the 256 intermediate values W[16] ... W[256+16-1], and - // copy the last 16 generated values to the start op P. - for i in 16..256 + 16 { - t[i] = f2(t[i - 2]) - .wrapping_add(t[i - 7]) - .wrapping_add(f1(t[i - 15])) - .wrapping_add(t[i - 16]) - .wrapping_add(i as u32); - } - { - let (p1, p2) = t.split_at_mut(256); - p1[0..16].copy_from_slice(&p2[0..16]); - } - - // Generate both the P and Q tables - for i in 16..1024 { - t[i] = f2(t[i - 2]) - .wrapping_add(t[i - 7]) - .wrapping_add(f1(t[i - 15])) - .wrapping_add(t[i - 16]) - .wrapping_add(256 + i as u32); - } - - let mut core = Self { t, counter1024: 0 }; - - // run the cipher 1024 steps - for _ in 0..64 { - core.sixteen_steps() - } - core.counter1024 = 0; - core - } -} - -impl SeedableRng for Hc128Core { - type Seed = [u8; SEED_WORDS * 4]; - - /// Create an HC-128 random number generator with a seed. The seed has to be - /// 256 bits in length, matching the 128 bit `key` followed by 128 bit `iv` - /// when HC-128 where to be used as a stream cipher. - fn from_seed(seed: Self::Seed) -> Self { - let mut seed_u32 = [0u32; SEED_WORDS]; - le::read_u32_into(&seed, &mut seed_u32); - Self::init(seed_u32) - } -} - -impl CryptoRng for Hc128Core {} - -// Custom PartialEq implementation as it can't currently be derived from an array of size 1024 -impl PartialEq for Hc128Core { - fn eq(&self, rhs: &Self) -> bool { - &self.t[..] == &rhs.t[..] && self.counter1024 == rhs.counter1024 - } -} -impl Eq for Hc128Core {} - -#[cfg(test)] -mod test { - use super::Hc128Rng; - use ::rand_core::{RngCore, SeedableRng}; - - #[test] - // Test vector 1 from the paper "The Stream Cipher HC-128" - fn test_hc128_true_values_a() { - #[rustfmt::skip] - let seed = [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, // key - 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; // iv - let mut rng = Hc128Rng::from_seed(seed); - - let mut results = [0u32; 16]; - for i in results.iter_mut() { - *i = rng.next_u32(); - } - #[rustfmt::skip] - let expected = [0x73150082, 0x3bfd03a0, 0xfb2fd77f, 0xaa63af0e, - 0xde122fc6, 0xa7dc29b6, 0x62a68527, 0x8b75ec68, - 0x9036db1e, 0x81896005, 0x00ade078, 0x491fbf9a, - 0x1cdc3013, 0x6c3d6e24, 0x90f664b2, 0x9cd57102]; - assert_eq!(results, expected); - } - - #[test] - // Test vector 2 from the paper "The Stream Cipher HC-128" - fn test_hc128_true_values_b() { - #[rustfmt::skip] - let seed = [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, // key - 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; // iv - let mut rng = Hc128Rng::from_seed(seed); - - let mut results = [0u32; 16]; - for i in results.iter_mut() { - *i = rng.next_u32(); - } - #[rustfmt::skip] - let expected = [0xc01893d5, 0xb7dbe958, 0x8f65ec98, 0x64176604, - 0x36fc6724, 0xc82c6eec, 0x1b1c38a7, 0xc9b42a95, - 0x323ef123, 0x0a6a908b, 0xce757b68, 0x9f14f7bb, - 0xe4cde011, 0xaeb5173f, 0x89608c94, 0xb5cf46ca]; - assert_eq!(results, expected); - } - - #[test] - // Test vector 3 from the paper "The Stream Cipher HC-128" - fn test_hc128_true_values_c() { - #[rustfmt::skip] - let seed = [0x55,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, // key - 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; // iv - let mut rng = Hc128Rng::from_seed(seed); - - let mut results = [0u32; 16]; - for i in results.iter_mut() { - *i = rng.next_u32(); - } - #[rustfmt::skip] - let expected = [0x518251a4, 0x04b4930a, 0xb02af931, 0x0639f032, - 0xbcb4a47a, 0x5722480b, 0x2bf99f72, 0xcdc0e566, - 0x310f0c56, 0xd3cc83e8, 0x663db8ef, 0x62dfe07f, - 0x593e1790, 0xc5ceaa9c, 0xab03806f, 0xc9a6e5a0]; - assert_eq!(results, expected); - } - - #[test] - fn test_hc128_true_values_u64() { - #[rustfmt::skip] - let seed = [0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, // key - 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; // iv - let mut rng = Hc128Rng::from_seed(seed); - - let mut results = [0u64; 8]; - for i in results.iter_mut() { - *i = rng.next_u64(); - } - #[rustfmt::skip] - let expected = [0x3bfd03a073150082, 0xaa63af0efb2fd77f, - 0xa7dc29b6de122fc6, 0x8b75ec6862a68527, - 0x818960059036db1e, 0x491fbf9a00ade078, - 0x6c3d6e241cdc3013, 0x9cd5710290f664b2]; - assert_eq!(results, expected); - - // The RNG operates in a P block of 512 results and next a Q block. - // After skipping 2*800 u32 results we end up somewhere in the Q block - // of the second round - for _ in 0..800 { - rng.next_u64(); - } - - for i in results.iter_mut() { - *i = rng.next_u64(); - } - #[rustfmt::skip] - let expected = [0xd8c4d6ca84d0fc10, 0xf16a5d91dc66e8e7, - 0xd800de5bc37a8653, 0x7bae1f88c0dfbb4c, - 0x3bfe1f374e6d4d14, 0x424b55676be3fa06, - 0xe3a1e8758cbff579, 0x417f7198c5652bcd]; - assert_eq!(results, expected); - } - - #[test] - fn test_hc128_true_values_bytes() { - #[rustfmt::skip] - let seed = [0x55,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, // key - 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; // iv - let mut rng = Hc128Rng::from_seed(seed); - #[rustfmt::skip] - let expected = [0x31, 0xf9, 0x2a, 0xb0, 0x32, 0xf0, 0x39, 0x06, - 0x7a, 0xa4, 0xb4, 0xbc, 0x0b, 0x48, 0x22, 0x57, - 0x72, 0x9f, 0xf9, 0x2b, 0x66, 0xe5, 0xc0, 0xcd, - 0x56, 0x0c, 0x0f, 0x31, 0xe8, 0x83, 0xcc, 0xd3, - 0xef, 0xb8, 0x3d, 0x66, 0x7f, 0xe0, 0xdf, 0x62, - 0x90, 0x17, 0x3e, 0x59, 0x9c, 0xaa, 0xce, 0xc5, - 0x6f, 0x80, 0x03, 0xab, 0xa0, 0xe5, 0xa6, 0xc9, - 0x60, 0x95, 0x84, 0x7a, 0xa5, 0x68, 0x5a, 0x84, - 0xea, 0xd5, 0xf3, 0xea, 0x73, 0xa9, 0xad, 0x01, - 0x79, 0x7d, 0xbe, 0x9f, 0xea, 0xe3, 0xf9, 0x74, - 0x0e, 0xda, 0x2f, 0xa0, 0xe4, 0x7b, 0x4b, 0x1b, - 0xdd, 0x17, 0x69, 0x4a, 0xfe, 0x9f, 0x56, 0x95, - 0xad, 0x83, 0x6b, 0x9d, 0x60, 0xa1, 0x99, 0x96, - 0x90, 0x00, 0x66, 0x7f, 0xfa, 0x7e, 0x65, 0xe9, - 0xac, 0x8b, 0x92, 0x34, 0x77, 0xb4, 0x23, 0xd0, - 0xb9, 0xab, 0xb1, 0x47, 0x7d, 0x4a, 0x13, 0x0a]; - - // Pick a somewhat large buffer so we can test filling with the - // remainder from `state.results`, directly filling the buffer, and - // filling the remainder of the buffer. - let mut buffer = [0u8; 16 * 4 * 2]; - // Consume a value so that we have a remainder. - assert!(rng.next_u64() == 0x04b4930a518251a4); - rng.fill_bytes(&mut buffer); - - // [u8; 128] doesn't implement PartialEq - assert_eq!(buffer.len(), expected.len()); - for (b, e) in buffer.iter().zip(expected.iter()) { - assert_eq!(b, e); - } - } - - #[test] - fn test_hc128_clone() { - #[rustfmt::skip] - let seed = [0x55,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0, // key - 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0]; // iv - let mut rng1 = Hc128Rng::from_seed(seed); - let mut rng2 = rng1.clone(); - for _ in 0..16 { - assert_eq!(rng1.next_u32(), rng2.next_u32()); - } - } -} diff --git a/rand_hc/src/lib.rs b/rand_hc/src/lib.rs deleted file mode 100644 index 995cb1d043d..00000000000 --- a/rand_hc/src/lib.rs +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The HC128 random number generator. - -#![doc( - html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", - html_favicon_url = "https://www.rust-lang.org/favicon.ico", - html_root_url = "https://rust-random.github.io/rand/" -)] -#![deny(missing_docs)] -#![deny(missing_debug_implementations)] -#![doc(test(attr(allow(unused_variables), deny(warnings))))] -#![no_std] - -mod hc128; - -pub use hc128::{Hc128Core, Hc128Rng}; diff --git a/rand_pcg/CHANGELOG.md b/rand_pcg/CHANGELOG.md index 7c929789e11..bab1cd0e8c8 100644 --- a/rand_pcg/CHANGELOG.md +++ b/rand_pcg/CHANGELOG.md @@ -4,6 +4,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.0] - 2025-01-27 +### Dependencies and features +- Update to `rand_core` v0.9.0 (#1558) +- Rename feature `serde1` to `serde` (#1477) +- Rename feature `getrandom` to `os_rng` (#1537) + +### Other changes +- Add `Lcg128CmDxsm64` generator compatible with NumPy's `PCG64DXSM` (#1202) +- Add examples for initializing the RNGs (#1352) +- Revise crate docs (#1454) + +## [0.3.1] - 2021-06-15 +- Add `advance` methods to RNGs (#1111) +- Document dependencies between streams (#1122) + ## [0.3.0] - 2020-12-08 - Bump `rand_core` version to 0.6.0 - Bump MSRV to 1.36 (#1011) diff --git a/rand_pcg/Cargo.toml b/rand_pcg/Cargo.toml index 4a3f90eff94..74740950712 100644 --- a/rand_pcg/Cargo.toml +++ b/rand_pcg/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_pcg" -version = "0.3.0" +version = "0.9.0" authors = ["The Rand Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -12,13 +12,19 @@ Selected PCG random number generators """ keywords = ["random", "rng", "pcg"] categories = ["algorithms", "no-std"] -edition = "2018" +edition = "2021" +rust-version = "1.63" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--generate-link-to-definition"] [features] -serde1 = ["serde"] +serde = ["dep:serde"] +os_rng = ["rand_core/os_rng"] [dependencies] -rand_core = { path = "../rand_core", version = "0.6.0" } +rand_core = { path = "../rand_core", version = "0.9.0" } serde = { version = "1", features = ["derive"], optional = true } [dev-dependencies] @@ -26,3 +32,4 @@ serde = { version = "1", features = ["derive"], optional = true } # deps yet, see: https://github.com/rust-lang/cargo/issues/1596 # Versions prior to 1.1.4 had incorrect minimal dependencies. bincode = { version = "1.1.4" } +rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } diff --git a/rand_pcg/LICENSE-APACHE b/rand_pcg/LICENSE-APACHE index 17d74680f8c..455787c2334 100644 --- a/rand_pcg/LICENSE-APACHE +++ b/rand_pcg/LICENSE-APACHE @@ -185,17 +185,3 @@ APPENDIX: How to apply the Apache License to your work. file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/rand_pcg/README.md b/rand_pcg/README.md index 736a789035c..50e91e59795 100644 --- a/rand_pcg/README.md +++ b/rand_pcg/README.md @@ -1,11 +1,10 @@ # rand_pcg -[![Test Status](https://github.com/rust-random/rand/workflows/Tests/badge.svg?event=push)](https://github.com/rust-random/rand/actions) +[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) [![Latest version](https://img.shields.io/crates/v/rand_pcg.svg)](https://crates.io/crates/rand_pcg) [![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) [![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_pcg) [![API](https://docs.rs/rand_pcg/badge.svg)](https://docs.rs/rand_pcg) -[![Minimum rustc version](https://img.shields.io/badge/rustc-1.36+-lightgray.svg)](https://github.com/rust-random/rand#rust-version-requirements) Implements a selection of PCG random number generators. @@ -30,7 +29,7 @@ Links: `rand_pcg` is `no_std` compatible by default. -The `serde1` feature includes implementations of `Serialize` and `Deserialize` +The `serde` feature includes implementations of `Serialize` and `Deserialize` for the included RNGs. ## License diff --git a/rand_pcg/src/lib.rs b/rand_pcg/src/lib.rs index c25bc439a4c..6b9d9d833f0 100644 --- a/rand_pcg/src/lib.rs +++ b/rand_pcg/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2018 Developers of the Rand project. +// Copyright 2018-2023 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -8,38 +8,93 @@ //! The PCG random number generators. //! -//! This is a native Rust implementation of a small selection of PCG generators. +//! This is a native Rust implementation of a small selection of [PCG generators]. //! The primary goal of this crate is simple, minimal, well-tested code; in //! other words it is explicitly not a goal to re-implement all of PCG. //! +//! ## Generators +//! //! This crate provides: //! -//! - `Pcg32` aka `Lcg64Xsh32`, officially known as `pcg32`, a general +//! - [`Pcg32`] aka [`Lcg64Xsh32`], officially known as `pcg32`, a general //! purpose RNG. This is a good choice on both 32-bit and 64-bit CPUs //! (for 32-bit output). -//! - `Pcg64` aka `Lcg128Xsl64`, officially known as `pcg64`, a general +//! - [`Pcg64`] aka [`Lcg128Xsl64`], officially known as `pcg64`, a general //! purpose RNG. This is a good choice on 64-bit CPUs. -//! - `Pcg64Mcg` aka `Mcg128Xsl64`, officially known as `pcg64_fast`, +//! - [`Pcg64Mcg`] aka [`Mcg128Xsl64`], officially known as `pcg64_fast`, //! a general purpose RNG using 128-bit multiplications. This has poor //! performance on 32-bit CPUs but is a good choice on 64-bit CPUs for //! both 32-bit and 64-bit output. //! -//! Both of these use 16 bytes of state and 128-bit seeds, and are considered -//! value-stable (i.e. any change affecting the output given a fixed seed would -//! be considered a breaking change to the crate). +//! These generators are all deterministic and portable (see [Reproducibility] +//! in the book), with testing against reference vectors. +//! +//! ## Seeding (construction) +//! +//! Generators implement the [`SeedableRng`] trait. All methods are suitable for +//! seeding. Some suggestions: +//! +//! 1. To automatically seed with a unique seed, use [`SeedableRng::from_rng`] +//! with a master generator (here [`rand::rng()`](https://docs.rs/rand/latest/rand/fn.rng.html)): +//! ```ignore +//! use rand_core::SeedableRng; +//! use rand_pcg::Pcg64Mcg; +//! let rng = Pcg64Mcg::from_rng(&mut rand::rng()); +//! # let _: Pcg64Mcg = rng; +//! ``` +//! 2. Seed **from an integer** via `seed_from_u64`. This uses a hash function +//! internally to yield a (typically) good seed from any input. +//! ``` +//! # use {rand_core::SeedableRng, rand_pcg::Pcg64Mcg}; +//! let rng = Pcg64Mcg::seed_from_u64(1); +//! # let _: Pcg64Mcg = rng; +//! ``` +//! +//! See also [Seeding RNGs] in the book. +//! +//! ## Generation +//! +//! Generators implement [`RngCore`], whose methods may be used directly to +//! generate unbounded integer or byte values. +//! ``` +//! use rand_core::{SeedableRng, RngCore}; +//! use rand_pcg::Pcg64Mcg; +//! +//! let mut rng = Pcg64Mcg::seed_from_u64(0); +//! let x = rng.next_u64(); +//! assert_eq!(x, 0x5603f242407deca2); +//! ``` +//! +//! It is often more convenient to use the [`rand::Rng`] trait, which provides +//! further functionality. See also the [Random Values] chapter in the book. +//! +//! [PCG generators]: https://www.pcg-random.org/ +//! [Reproducibility]: https://rust-random.github.io/book/crate-reprod.html +//! [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +//! [Random Values]: https://rust-random.github.io/book/guide-values.html +//! [`RngCore`]: rand_core::RngCore +//! [`SeedableRng`]: rand_core::SeedableRng +//! [`SeedableRng::from_rng`]: rand_core::SeedableRng#method.from_rng +//! [`rand::rng`]: https://docs.rs/rand/latest/rand/fn.rng.html +//! [`rand::Rng`]: https://docs.rs/rand/latest/rand/trait.Rng.html +//! [`rand_chacha::ChaCha8Rng`]: https://docs.rs/rand_chacha/latest/rand_chacha/struct.ChaCha8Rng.html #![doc( html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", html_favicon_url = "https://www.rust-lang.org/favicon.ico", html_root_url = "https://rust-random.github.io/rand/" )] +#![forbid(unsafe_code)] #![deny(missing_docs)] #![deny(missing_debug_implementations)] #![no_std] -#[cfg(not(target_os = "emscripten"))] mod pcg128; +mod pcg128; +mod pcg128cm; mod pcg64; -#[cfg(not(target_os = "emscripten"))] +pub use rand_core; + pub use self::pcg128::{Lcg128Xsl64, Mcg128Xsl64, Pcg64, Pcg64Mcg}; +pub use self::pcg128cm::{Lcg128CmDxsm64, Pcg64Dxsm}; pub use self::pcg64::{Lcg64Xsh32, Pcg32}; diff --git a/rand_pcg/src/pcg128.rs b/rand_pcg/src/pcg128.rs index 58a8e36b7c1..990303c41fb 100644 --- a/rand_pcg/src/pcg128.rs +++ b/rand_pcg/src/pcg128.rs @@ -14,8 +14,9 @@ const MULTIPLIER: u128 = 0x2360_ED05_1FC6_5DA4_4385_DF64_9FCC_F645; use core::fmt; -use rand_core::{le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A PCG random number generator (XSL RR 128/64 (LCG) variant). /// @@ -29,8 +30,11 @@ use rand_core::{le, Error, RngCore, SeedableRng}; /// Despite the name, this implementation uses 32 bytes (256 bit) space /// comprising 128 bits of state and 128 bits stream selector. These are both /// set by `SeedableRng`, using a 256-bit seed. +/// +/// Note that two generators with different stream parameters may be closely +/// correlated. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg128Xsl64 { state: u128, increment: u128, @@ -40,9 +44,47 @@ pub struct Lcg128Xsl64 { pub type Pcg64 = Lcg128Xsl64; impl Lcg128Xsl64 { + /// Multi-step advance functions (jump-ahead, jump-back) + /// + /// The method used here is based on Brown, "Random Number Generation + /// with Arbitrary Stride,", Transactions of the American Nuclear + /// Society (Nov. 1994). The algorithm is very similar to fast + /// exponentiation. + /// + /// Even though delta is an unsigned integer, we can pass a + /// signed integer to go backwards, it just goes "the long way round". + /// + /// Using this function is equivalent to calling `next_64()` `delta` + /// number of times. + #[inline] + pub fn advance(&mut self, delta: u128) { + let mut acc_mult: u128 = 1; + let mut acc_plus: u128 = 0; + let mut cur_mult = MULTIPLIER; + let mut cur_plus = self.increment; + let mut mdelta = delta; + + while mdelta > 0 { + if (mdelta & 1) != 0 { + acc_mult = acc_mult.wrapping_mul(cur_mult); + acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus); + } + cur_plus = cur_mult.wrapping_add(1).wrapping_mul(cur_plus); + cur_mult = cur_mult.wrapping_mul(cur_mult); + mdelta /= 2; + } + self.state = acc_mult.wrapping_mul(self.state).wrapping_add(acc_plus); + } + /// Construct an instance compatible with PCG seed and stream. /// - /// Note that PCG specifies default values for both parameters: + /// Note that the highest bit of the `stream` parameter is discarded + /// to simplify upholding internal invariants. + /// + /// Note that two generators with different stream parameters may be closely + /// correlated. + /// + /// PCG specifies the following default values for both parameters: /// /// - `state = 0xcafef00dd15ea5e5` /// - `stream = 0xa02bdbf7bb3c0a7ac28fa16a64abf96` @@ -55,7 +97,7 @@ impl Lcg128Xsl64 { #[inline] fn from_state_incr(state: u128, increment: u128) -> Self { let mut pcg = Lcg128Xsl64 { state, increment }; - // Move away from inital value: + // Move away from initial value: pcg.state = pcg.state.wrapping_add(pcg.increment); pcg.step(); pcg @@ -78,11 +120,11 @@ impl fmt::Debug for Lcg128Xsl64 { } } -/// We use a single 255-bit seed to initialise the state and select a stream. -/// One `seed` bit (lowest bit of `seed[8]`) is ignored. impl SeedableRng for Lcg128Xsl64 { type Seed = [u8; 32]; + /// We use a single 255-bit seed to initialise the state and select a stream. + /// One `seed` bit (lowest bit of `seed[8]`) is ignored. fn from_seed(seed: Self::Seed) -> Self { let mut seed_u64 = [0u64; 4]; le::read_u64_into(&seed, &mut seed_u64); @@ -108,17 +150,10 @@ impl RngCore for Lcg128Xsl64 { #[inline] fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_impl(self, dest) - } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + impls::fill_bytes_via_next(self, dest) } } - /// A PCG random number generator (XSL 128/64 (MCG) variant). /// /// Permuted Congruential Generator with 128-bit state, internal Multiplicative @@ -131,7 +166,7 @@ impl RngCore for Lcg128Xsl64 { /// output function), this RNG is faster, also has a long cycle, and still has /// good performance on statistical tests. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Mcg128Xsl64 { state: u128, } @@ -140,6 +175,38 @@ pub struct Mcg128Xsl64 { pub type Pcg64Mcg = Mcg128Xsl64; impl Mcg128Xsl64 { + /// Multi-step advance functions (jump-ahead, jump-back) + /// + /// The method used here is based on Brown, "Random Number Generation + /// with Arbitrary Stride,", Transactions of the American Nuclear + /// Society (Nov. 1994). The algorithm is very similar to fast + /// exponentiation. + /// + /// Even though delta is an unsigned integer, we can pass a + /// signed integer to go backwards, it just goes "the long way round". + /// + /// Using this function is equivalent to calling `next_64()` `delta` + /// number of times. + #[inline] + pub fn advance(&mut self, delta: u128) { + let mut acc_mult: u128 = 1; + let mut acc_plus: u128 = 0; + let mut cur_mult = MULTIPLIER; + let mut cur_plus: u128 = 0; + let mut mdelta = delta; + + while mdelta > 0 { + if (mdelta & 1) != 0 { + acc_mult = acc_mult.wrapping_mul(cur_mult); + acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus); + } + cur_plus = cur_mult.wrapping_add(1).wrapping_mul(cur_plus); + cur_mult = cur_mult.wrapping_mul(cur_mult); + mdelta /= 2; + } + self.state = acc_mult.wrapping_mul(self.state).wrapping_add(acc_plus); + } + /// Construct an instance compatible with PCG seed. /// /// Note that PCG specifies a default value for the parameter: @@ -167,8 +234,7 @@ impl SeedableRng for Mcg128Xsl64 { // Read as if a little-endian u128 value: let mut seed_u64 = [0u64; 2]; le::read_u64_into(&seed, &mut seed_u64); - let state = u128::from(seed_u64[0]) | - u128::from(seed_u64[1]) << 64; + let state = u128::from(seed_u64[0]) | u128::from(seed_u64[1]) << 64; Mcg128Xsl64::new(state) } } @@ -187,13 +253,7 @@ impl RngCore for Mcg128Xsl64 { #[inline] fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_impl(self, dest) - } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + impls::fill_bytes_via_next(self, dest) } } @@ -208,19 +268,3 @@ fn output_xsl_rr(state: u128) -> u64 { let xsl = ((state >> XSHIFT) as u64) ^ (state as u64); xsl.rotate_right(rot) } - -#[inline(always)] -fn fill_bytes_impl(rng: &mut R, dest: &mut [u8]) { - let mut left = dest; - while left.len() >= 8 { - let (l, r) = { left }.split_at_mut(8); - left = r; - let chunk: [u8; 8] = rng.next_u64().to_le_bytes(); - l.copy_from_slice(&chunk); - } - let n = left.len(); - if n > 0 { - let chunk: [u8; 8] = rng.next_u64().to_le_bytes(); - left.copy_from_slice(&chunk[..n]); - } -} diff --git a/rand_pcg/src/pcg128cm.rs b/rand_pcg/src/pcg128cm.rs new file mode 100644 index 00000000000..a5a2b178795 --- /dev/null +++ b/rand_pcg/src/pcg128cm.rs @@ -0,0 +1,177 @@ +// Copyright 2018-2021 Developers of the Rand project. +// Copyright 2017 Paul Dicker. +// Copyright 2014-2017, 2019 Melissa O'Neill and PCG Project contributors +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! PCG random number generators + +// This is the cheap multiplier used by PCG for 128-bit state. +const MULTIPLIER: u64 = 15750249268501108917; + +use core::fmt; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// A PCG random number generator (CM DXSM 128/64 (LCG) variant). +/// +/// Permuted Congruential Generator with 128-bit state, internal Linear +/// Congruential Generator, and 64-bit output via "double xorshift multiply" +/// output function. +/// +/// This is a 128-bit LCG with explicitly chosen stream with the PCG-DXSM +/// output function. This corresponds to `pcg_engines::cm_setseq_dxsm_128_64` +/// from pcg_cpp and `PCG64DXSM` from NumPy. +/// +/// Despite the name, this implementation uses 32 bytes (256 bit) space +/// comprising 128 bits of state and 128 bits stream selector. These are both +/// set by `SeedableRng`, using a 256-bit seed. +/// +/// Note that while two generators with different stream parameter may be +/// closely correlated, this is [mitigated][upgrading-pcg64] by the DXSM output function. +/// +/// [upgrading-pcg64]: https://numpy.org/doc/stable/reference/random/upgrading-pcg64.html +#[derive(Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Lcg128CmDxsm64 { + state: u128, + increment: u128, +} + +/// [`Lcg128CmDxsm64`] is also known as `PCG64DXSM`. +pub type Pcg64Dxsm = Lcg128CmDxsm64; + +impl Lcg128CmDxsm64 { + /// Multi-step advance functions (jump-ahead, jump-back) + /// + /// The method used here is based on Brown, "Random Number Generation + /// with Arbitrary Stride,", Transactions of the American Nuclear + /// Society (Nov. 1994). The algorithm is very similar to fast + /// exponentiation. + /// + /// Even though delta is an unsigned integer, we can pass a + /// signed integer to go backwards, it just goes "the long way round". + /// + /// Using this function is equivalent to calling `next_64()` `delta` + /// number of times. + #[inline] + pub fn advance(&mut self, delta: u128) { + let mut acc_mult: u128 = 1; + let mut acc_plus: u128 = 0; + let mut cur_mult = MULTIPLIER as u128; + let mut cur_plus = self.increment; + let mut mdelta = delta; + + while mdelta > 0 { + if (mdelta & 1) != 0 { + acc_mult = acc_mult.wrapping_mul(cur_mult); + acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus); + } + cur_plus = cur_mult.wrapping_add(1).wrapping_mul(cur_plus); + cur_mult = cur_mult.wrapping_mul(cur_mult); + mdelta /= 2; + } + self.state = acc_mult.wrapping_mul(self.state).wrapping_add(acc_plus); + } + + /// Construct an instance compatible with PCG seed and stream. + /// + /// Note that the highest bit of the `stream` parameter is discarded + /// to simplify upholding internal invariants. + /// + /// Note that while two generators with different stream parameter may be + /// closely correlated, this is [mitigated][upgrading-pcg64] by the DXSM output function. + /// + /// PCG specifies the following default values for both parameters: + /// + /// - `state = 0xcafef00dd15ea5e5` + /// - `stream = 0xa02bdbf7bb3c0a7ac28fa16a64abf96` + /// + /// [upgrading-pcg64]: https://numpy.org/doc/stable/reference/random/upgrading-pcg64.html + pub fn new(state: u128, stream: u128) -> Self { + // The increment must be odd, hence we discard one bit: + let increment = (stream << 1) | 1; + Self::from_state_incr(state, increment) + } + + #[inline] + fn from_state_incr(state: u128, increment: u128) -> Self { + let mut pcg = Self { state, increment }; + // Move away from initial value: + pcg.state = pcg.state.wrapping_add(pcg.increment); + pcg.step(); + pcg + } + + #[inline(always)] + fn step(&mut self) { + // prepare the LCG for the next round + self.state = self + .state + .wrapping_mul(MULTIPLIER as u128) + .wrapping_add(self.increment); + } +} + +// Custom Debug implementation that does not expose the internal state +impl fmt::Debug for Lcg128CmDxsm64 { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Lcg128CmDxsm64 {{}}") + } +} + +impl SeedableRng for Lcg128CmDxsm64 { + type Seed = [u8; 32]; + + /// We use a single 255-bit seed to initialise the state and select a stream. + /// One `seed` bit (lowest bit of `seed[8]`) is ignored. + fn from_seed(seed: Self::Seed) -> Self { + let mut seed_u64 = [0u64; 4]; + le::read_u64_into(&seed, &mut seed_u64); + let state = u128::from(seed_u64[0]) | (u128::from(seed_u64[1]) << 64); + let incr = u128::from(seed_u64[2]) | (u128::from(seed_u64[3]) << 64); + + // The increment must be odd, hence we discard one bit: + Self::from_state_incr(state, incr | 1) + } +} + +impl RngCore for Lcg128CmDxsm64 { + #[inline] + fn next_u32(&mut self) -> u32 { + self.next_u64() as u32 + } + + #[inline] + fn next_u64(&mut self) -> u64 { + let res = output_dxsm(self.state); + self.step(); + res + } + + #[inline] + fn fill_bytes(&mut self, dest: &mut [u8]) { + impls::fill_bytes_via_next(self, dest) + } +} + +#[inline(always)] +fn output_dxsm(state: u128) -> u64 { + // See https://github.com/imneme/pcg-cpp/blob/ffd522e7188bef30a00c74dc7eb9de5faff90092/include/pcg_random.hpp#L1016 + // for a short discussion of the construction and its original implementation. + let mut hi = (state >> 64) as u64; + let mut lo = state as u64; + + lo |= 1; + hi ^= hi >> 32; + hi = hi.wrapping_mul(MULTIPLIER); + hi ^= hi >> 48; + hi = hi.wrapping_mul(lo); + + hi +} diff --git a/rand_pcg/src/pcg64.rs b/rand_pcg/src/pcg64.rs index ed7442f9670..771a996d28f 100644 --- a/rand_pcg/src/pcg64.rs +++ b/rand_pcg/src/pcg64.rs @@ -11,8 +11,9 @@ //! PCG random number generators use core::fmt; -use rand_core::{impls, le, Error, RngCore, SeedableRng}; -#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; +use rand_core::{impls, le, RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; // This is the default multiplier used by PCG for 64-bit state. const MULTIPLIER: u64 = 6364136223846793005; @@ -29,8 +30,11 @@ const MULTIPLIER: u64 = 6364136223846793005; /// Despite the name, this implementation uses 16 bytes (128 bit) space /// comprising 64 bits of state and 64 bits stream selector. These are both set /// by `SeedableRng`, using a 128-bit seed. +/// +/// Note that two generators with different stream parameter may be closely +/// correlated. #[derive(Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Lcg64Xsh32 { state: u64, increment: u64, @@ -40,9 +44,47 @@ pub struct Lcg64Xsh32 { pub type Pcg32 = Lcg64Xsh32; impl Lcg64Xsh32 { + /// Multi-step advance functions (jump-ahead, jump-back) + /// + /// The method used here is based on Brown, "Random Number Generation + /// with Arbitrary Stride,", Transactions of the American Nuclear + /// Society (Nov. 1994). The algorithm is very similar to fast + /// exponentiation. + /// + /// Even though delta is an unsigned integer, we can pass a + /// signed integer to go backwards, it just goes "the long way round". + /// + /// Using this function is equivalent to calling `next_32()` `delta` + /// number of times. + #[inline] + pub fn advance(&mut self, delta: u64) { + let mut acc_mult: u64 = 1; + let mut acc_plus: u64 = 0; + let mut cur_mult = MULTIPLIER; + let mut cur_plus = self.increment; + let mut mdelta = delta; + + while mdelta > 0 { + if (mdelta & 1) != 0 { + acc_mult = acc_mult.wrapping_mul(cur_mult); + acc_plus = acc_plus.wrapping_mul(cur_mult).wrapping_add(cur_plus); + } + cur_plus = cur_mult.wrapping_add(1).wrapping_mul(cur_plus); + cur_mult = cur_mult.wrapping_mul(cur_mult); + mdelta /= 2; + } + self.state = acc_mult.wrapping_mul(self.state).wrapping_add(acc_plus); + } + /// Construct an instance compatible with PCG seed and stream. /// - /// Note that PCG specifies default values for both parameters: + /// Note that the highest bit of the `stream` parameter is discarded + /// to simplify upholding internal invariants. + /// + /// Note that two generators with different stream parameters may be closely + /// correlated. + /// + /// PCG specifies the following default values for both parameters: /// /// - `state = 0xcafef00dd15ea5e5` /// - `stream = 0xa02bdbf7bb3c0a7` @@ -56,7 +98,7 @@ impl Lcg64Xsh32 { #[inline] fn from_state_incr(state: u64, increment: u64) -> Self { let mut pcg = Lcg64Xsh32 { state, increment }; - // Move away from inital value: + // Move away from initial value: pcg.state = pcg.state.wrapping_add(pcg.increment); pcg.step(); pcg @@ -79,11 +121,11 @@ impl fmt::Debug for Lcg64Xsh32 { } } -/// We use a single 127-bit seed to initialise the state and select a stream. -/// One `seed` bit (lowest bit of `seed[8]`) is ignored. impl SeedableRng for Lcg64Xsh32 { type Seed = [u8; 16]; + /// We use a single 127-bit seed to initialise the state and select a stream. + /// One `seed` bit (lowest bit of `seed[8]`) is ignored. fn from_seed(seed: Self::Seed) -> Self { let mut seed_u64 = [0u64; 2]; le::read_u64_into(&seed, &mut seed_u64); @@ -119,10 +161,4 @@ impl RngCore for Lcg64Xsh32 { fn fill_bytes(&mut self, dest: &mut [u8]) { impls::fill_bytes_via_next(self, dest) } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) - } } diff --git a/rand_pcg/tests/lcg128cmdxsm64.rs b/rand_pcg/tests/lcg128cmdxsm64.rs new file mode 100644 index 00000000000..b5b37f582e0 --- /dev/null +++ b/rand_pcg/tests/lcg128cmdxsm64.rs @@ -0,0 +1,77 @@ +use rand_core::{RngCore, SeedableRng}; +use rand_pcg::{Lcg128CmDxsm64, Pcg64Dxsm}; + +#[test] +fn test_lcg128cmdxsm64_advancing() { + for seed in 0..20 { + let mut rng1 = Lcg128CmDxsm64::seed_from_u64(seed); + let mut rng2 = rng1.clone(); + for _ in 0..20 { + rng1.next_u64(); + } + rng2.advance(20); + assert_eq!(rng1, rng2); + } +} + +#[test] +fn test_lcg128cmdxsm64_construction() { + // Test that various construction techniques produce a working RNG. + #[rustfmt::skip] + let seed = [1,2,3,4, 5,6,7,8, 9,10,11,12, 13,14,15,16, + 17,18,19,20, 21,22,23,24, 25,26,27,28, 29,30,31,32]; + let mut rng1 = Lcg128CmDxsm64::from_seed(seed); + assert_eq!(rng1.next_u64(), 12201417210360370199); + + let mut rng2 = Lcg128CmDxsm64::from_rng(&mut rng1); + assert_eq!(rng2.next_u64(), 11487972556150888383); + + let mut rng3 = Lcg128CmDxsm64::seed_from_u64(0); + assert_eq!(rng3.next_u64(), 4111470453933123814); + + // This is the same as Lcg128CmDxsm64, so we only have a single test: + let mut rng4 = Pcg64Dxsm::seed_from_u64(0); + assert_eq!(rng4.next_u64(), 4111470453933123814); +} + +#[test] +fn test_lcg128cmdxsm64_reference() { + // Numbers determined using `pcg_engines::cm_setseq_dxsm_128_64` from pcg-cpp. + let mut rng = Lcg128CmDxsm64::new(42, 54); + + let mut results = [0u64; 6]; + for i in results.iter_mut() { + *i = rng.next_u64(); + } + let expected: [u64; 6] = [ + 17331114245835578256, + 10267467544499227306, + 9726600296081716989, + 10165951391103677450, + 12131334649314727261, + 10134094537930450875, + ]; + assert_eq!(results, expected); +} + +#[cfg(feature = "serde")] +#[test] +fn test_lcg128cmdxsm64_serde() { + use bincode; + use std::io::{BufReader, BufWriter}; + + let mut rng = Lcg128CmDxsm64::seed_from_u64(0); + + let buf: Vec = Vec::new(); + let mut buf = BufWriter::new(buf); + bincode::serialize_into(&mut buf, &rng).expect("Could not serialize"); + + let buf = buf.into_inner().unwrap(); + let mut read = BufReader::new(&buf[..]); + let mut deserialized: Lcg128CmDxsm64 = + bincode::deserialize_from(&mut read).expect("Could not deserialize"); + + for _ in 0..16 { + assert_eq!(rng.next_u64(), deserialized.next_u64()); + } +} diff --git a/rand_pcg/tests/lcg128xsl64.rs b/rand_pcg/tests/lcg128xsl64.rs index ac238b51623..07bd6137da9 100644 --- a/rand_pcg/tests/lcg128xsl64.rs +++ b/rand_pcg/tests/lcg128xsl64.rs @@ -1,6 +1,19 @@ use rand_core::{RngCore, SeedableRng}; use rand_pcg::{Lcg128Xsl64, Pcg64}; +#[test] +fn test_lcg128xsl64_advancing() { + for seed in 0..20 { + let mut rng1 = Lcg128Xsl64::seed_from_u64(seed); + let mut rng2 = rng1.clone(); + for _ in 0..20 { + rng1.next_u64(); + } + rng2.advance(20); + assert_eq!(rng1, rng2); + } +} + #[test] fn test_lcg128xsl64_construction() { // Test that various construction techniques produce a working RNG. @@ -10,7 +23,7 @@ fn test_lcg128xsl64_construction() { let mut rng1 = Lcg128Xsl64::from_seed(seed); assert_eq!(rng1.next_u64(), 8740028313290271629); - let mut rng2 = Lcg128Xsl64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg128Xsl64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 1922280315005786345); let mut rng3 = Lcg128Xsl64::seed_from_u64(0); @@ -22,7 +35,7 @@ fn test_lcg128xsl64_construction() { } #[test] -fn test_lcg128xsl64_true_values() { +fn test_lcg128xsl64_reference() { // Numbers copied from official test suite (C version). let mut rng = Lcg128Xsl64::new(42, 54); @@ -41,7 +54,7 @@ fn test_lcg128xsl64_true_values() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg128xsl64_serde() { use bincode; diff --git a/rand_pcg/tests/lcg64xsh32.rs b/rand_pcg/tests/lcg64xsh32.rs index 24a06d32483..ea704a50f6f 100644 --- a/rand_pcg/tests/lcg64xsh32.rs +++ b/rand_pcg/tests/lcg64xsh32.rs @@ -1,6 +1,19 @@ use rand_core::{RngCore, SeedableRng}; use rand_pcg::{Lcg64Xsh32, Pcg32}; +#[test] +fn test_lcg64xsh32_advancing() { + for seed in 0..20 { + let mut rng1 = Lcg64Xsh32::seed_from_u64(seed); + let mut rng2 = rng1.clone(); + for _ in 0..20 { + rng1.next_u32(); + } + rng2.advance(20); + assert_eq!(rng1, rng2); + } +} + #[test] fn test_lcg64xsh32_construction() { // Test that various construction techniques produce a working RNG. @@ -8,7 +21,7 @@ fn test_lcg64xsh32_construction() { let mut rng1 = Lcg64Xsh32::from_seed(seed); assert_eq!(rng1.next_u64(), 1204678643940597513); - let mut rng2 = Lcg64Xsh32::from_rng(&mut rng1).unwrap(); + let mut rng2 = Lcg64Xsh32::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 12384929573776311845); let mut rng3 = Lcg64Xsh32::seed_from_u64(0); @@ -20,7 +33,7 @@ fn test_lcg64xsh32_construction() { } #[test] -fn test_lcg64xsh32_true_values() { +fn test_lcg64xsh32_reference() { // Numbers copied from official test suite. let mut rng = Lcg64Xsh32::new(42, 54); @@ -34,7 +47,7 @@ fn test_lcg64xsh32_true_values() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_lcg64xsh32_serde() { use bincode; diff --git a/rand_pcg/tests/mcg128xsl64.rs b/rand_pcg/tests/mcg128xsl64.rs index 32f363f350f..6125f1998c2 100644 --- a/rand_pcg/tests/mcg128xsl64.rs +++ b/rand_pcg/tests/mcg128xsl64.rs @@ -1,6 +1,19 @@ use rand_core::{RngCore, SeedableRng}; use rand_pcg::{Mcg128Xsl64, Pcg64Mcg}; +#[test] +fn test_mcg128xsl64_advancing() { + for seed in 0..20 { + let mut rng1 = Mcg128Xsl64::seed_from_u64(seed); + let mut rng2 = rng1.clone(); + for _ in 0..20 { + rng1.next_u64(); + } + rng2.advance(20); + assert_eq!(rng1, rng2); + } +} + #[test] fn test_mcg128xsl64_construction() { // Test that various construction techniques produce a working RNG. @@ -8,7 +21,7 @@ fn test_mcg128xsl64_construction() { let mut rng1 = Mcg128Xsl64::from_seed(seed); assert_eq!(rng1.next_u64(), 7071994460355047496); - let mut rng2 = Mcg128Xsl64::from_rng(&mut rng1).unwrap(); + let mut rng2 = Mcg128Xsl64::from_rng(&mut rng1); assert_eq!(rng2.next_u64(), 12300796107712034932); let mut rng3 = Mcg128Xsl64::seed_from_u64(0); @@ -20,7 +33,7 @@ fn test_mcg128xsl64_construction() { } #[test] -fn test_mcg128xsl64_true_values() { +fn test_mcg128xsl64_reference() { // Numbers copied from official test suite (C version). let mut rng = Mcg128Xsl64::new(42); @@ -39,7 +52,7 @@ fn test_mcg128xsl64_true_values() { assert_eq!(results, expected); } -#[cfg(feature = "serde1")] +#[cfg(feature = "serde")] #[test] fn test_mcg128xsl64_serde() { use bincode; diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 6a2d9d48215..00000000000 --- a/rustfmt.toml +++ /dev/null @@ -1,32 +0,0 @@ -# This rustfmt file is added for configuration, but in practice much of our -# code is hand-formatted, frequently with more readable results. - -# Comments: -normalize_comments = true -wrap_comments = false -comment_width = 90 # small excess is okay but prefer 80 - -# Arguments: -use_small_heuristics = "Default" -# TODO: single line functions only where short, please? -# https://github.com/rust-lang/rustfmt/issues/3358 -fn_single_line = false -fn_args_layout = "Compressed" -overflow_delimited_expr = true -where_single_line = true - -# enum_discrim_align_threshold = 20 -# struct_field_align_threshold = 20 - -# Compatibility: -edition = "2018" # we require compatibility back to 1.32.0 - -# Misc: -inline_attribute_width = 80 -blank_lines_upper_bound = 2 -reorder_impl_items = true -# report_todo = "Unnumbered" -# report_fixme = "Unnumbered" - -# Ignored files: -ignore = [] diff --git a/src/distributions/bernoulli.rs b/src/distr/bernoulli.rs similarity index 76% rename from src/distributions/bernoulli.rs rename to src/distr/bernoulli.rs index b968ca046ed..6803518e376 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distr/bernoulli.rs @@ -6,25 +6,35 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Bernoulli distribution. +//! The Bernoulli distribution `Bernoulli(p)`. -use crate::distributions::Distribution; +use crate::distr::Distribution; use crate::Rng; -use core::{fmt, u64}; +use core::fmt; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; -/// The Bernoulli distribution. +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The [Bernoulli distribution](https://en.wikipedia.org/wiki/Bernoulli_distribution) `Bernoulli(p)`. +/// +/// This distribution describes a single boolean random variable, which is true +/// with probability `p` and false with probability `1 - p`. +/// It is a special case of the Binomial distribution with `n = 1`. +/// +/// # Plot +/// +/// The following plot shows the Bernoulli distribution with `p = 0.1`, +/// `p = 0.5`, and `p = 0.9`. /// -/// This is a special case of the Binomial distribution where `n = 1`. +/// ![Bernoulli distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/bernoulli.svg) /// /// # Example /// /// ```rust -/// use rand::distributions::{Bernoulli, Distribution}; +/// use rand::distr::{Bernoulli, Distribution}; /// /// let d = Bernoulli::new(0.3).unwrap(); -/// let v = d.sample(&mut rand::thread_rng()); +/// let v = d.sample(&mut rand::rng()); /// println!("{} is from a Bernoulli distribution", v); /// ``` /// @@ -33,8 +43,8 @@ use serde::{Serialize, Deserialize}; /// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`), /// so only probabilities that are multiples of 2-64 can be /// represented. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Bernoulli { /// Probability of success, relative to the maximal integer. p_int: u64, @@ -49,7 +59,7 @@ pub struct Bernoulli { // `f64` only has 53 bits of precision, and the next largest value of `p` will // result in `2^64 - 2048`. // -// Also there is a 100% theoretical concern: if someone consistenly wants to +// Also there is a 100% theoretical concern: if someone consistently wants to // generate `true` using the Bernoulli distribution (i.e. by using a probability // of `1.0`), just using `u64::MAX` is not enough. On average it would return // false once every 2^64 iterations. Some people apparently care about this @@ -65,7 +75,7 @@ const ALWAYS_TRUE: u64 = u64::MAX; // in `no_std` mode. const SCALE: f64 = 2.0 * (1u64 << 63) as f64; -/// Error type returned from `Bernoulli::new`. +/// Error type returned from [`Bernoulli::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BernoulliError { /// `p < 0` or `p > 1`. @@ -81,7 +91,7 @@ impl fmt::Display for BernoulliError { } #[cfg(feature = "std")] -impl ::std::error::Error for BernoulliError {} +impl std::error::Error for BernoulliError {} impl Bernoulli { /// Construct a new `Bernoulli` with the given probability of success `p`. @@ -96,7 +106,7 @@ impl Bernoulli { /// 2-64 in `[0, 1]` can be represented as a `f64`.) #[inline] pub fn new(p: f64) -> Result { - if !(p >= 0.0 && p < 1.0) { + if !(0.0..1.0).contains(&p) { if p == 1.0 { return Ok(Bernoulli { p_int: ALWAYS_TRUE }); } @@ -126,6 +136,18 @@ impl Bernoulli { let p_int = ((f64::from(numerator) / f64::from(denominator)) * SCALE) as u64; Ok(Bernoulli { p_int }) } + + #[inline] + /// Returns the probability (`p`) of the distribution. + /// + /// This value may differ slightly from the input due to loss of precision. + pub fn p(&self) -> f64 { + if self.p_int == ALWAYS_TRUE { + 1.0 + } else { + (self.p_int as f64) / SCALE + } + } } impl Distribution for Bernoulli { @@ -135,7 +157,7 @@ impl Distribution for Bernoulli { if self.p_int == ALWAYS_TRUE { return true; } - let v: u64 = rng.gen(); + let v: u64 = rng.random(); v < self.p_int } } @@ -143,20 +165,24 @@ impl Distribution for Bernoulli { #[cfg(test)] mod test { use super::Bernoulli; - use crate::distributions::Distribution; + use crate::distr::Distribution; use crate::Rng; #[test] - #[cfg(feature="serde1")] + #[cfg(feature = "serde")] fn test_serializing_deserializing_bernoulli() { let coin_flip = Bernoulli::new(0.5).unwrap(); - let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); + let de_coin_flip: Bernoulli = + bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); assert_eq!(coin_flip.p_int, de_coin_flip.p_int); } #[test] fn test_trivial() { + // We prefer to be explicit here. + #![allow(clippy::bool_assert_comparison)] + let mut r = crate::test::rng(1); let always_false = Bernoulli::new(0.0).unwrap(); let always_true = Bernoulli::new(1.0).unwrap(); @@ -202,10 +228,16 @@ mod test { let distr = Bernoulli::new(0.4532).unwrap(); let mut buf = [false; 10]; for x in &mut buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } - assert_eq!(buf, [ - true, false, false, true, false, false, true, true, true, true - ]); + assert_eq!( + buf, + [true, false, false, true, false, false, true, true, true, true] + ); + } + + #[test] + fn bernoulli_distributions_can_be_compared() { + assert_eq!(Bernoulli::new(1.0), Bernoulli::new(1.0)); } } diff --git a/src/distr/distribution.rs b/src/distr/distribution.rs new file mode 100644 index 00000000000..6f4e202647e --- /dev/null +++ b/src/distr/distribution.rs @@ -0,0 +1,265 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Distribution trait and associates + +use crate::Rng; +#[cfg(feature = "alloc")] +use alloc::string::String; +use core::iter; + +/// Types (distributions) that can be used to create a random instance of `T`. +/// +/// It is possible to sample from a distribution through both the +/// `Distribution` and [`Rng`] traits, via `distr.sample(&mut rng)` and +/// `rng.sample(distr)`. They also both offer the [`sample_iter`] method, which +/// produces an iterator that samples from the distribution. +/// +/// All implementations are expected to be immutable; this has the significant +/// advantage of not needing to consider thread safety, and for most +/// distributions efficient state-less sampling algorithms are available. +/// +/// Implementations are typically expected to be portable with reproducible +/// results when used with a PRNG with fixed seed; see the +/// [portability chapter](https://rust-random.github.io/book/portability.html) +/// of The Rust Rand Book. In some cases this does not apply, e.g. the `usize` +/// type requires different sampling on 32-bit and 64-bit machines. +/// +/// [`sample_iter`]: Distribution::sample_iter +pub trait Distribution { + /// Generate a random value of `T`, using `rng` as the source of randomness. + fn sample(&self, rng: &mut R) -> T; + + /// Create an iterator that generates random values of `T`, using `rng` as + /// the source of randomness. + /// + /// Note that this function takes `self` by value. This works since + /// `Distribution` is impl'd for `&D` where `D: Distribution`, + /// however borrowing is not automatic hence `distr.sample_iter(...)` may + /// need to be replaced with `(&distr).sample_iter(...)` to borrow or + /// `(&*distr).sample_iter(...)` to reborrow an existing reference. + /// + /// # Example + /// + /// ``` + /// use rand::distr::{Distribution, Alphanumeric, Uniform, StandardUniform}; + /// + /// let mut rng = rand::rng(); + /// + /// // Vec of 16 x f32: + /// let v: Vec = StandardUniform.sample_iter(&mut rng).take(16).collect(); + /// + /// // String: + /// let s: String = Alphanumeric + /// .sample_iter(&mut rng) + /// .take(7) + /// .map(char::from) + /// .collect(); + /// + /// // Dice-rolling: + /// let die_range = Uniform::new_inclusive(1, 6).unwrap(); + /// let mut roll_die = die_range.sample_iter(&mut rng); + /// while roll_die.next().unwrap() != 6 { + /// println!("Not a 6; rolling again!"); + /// } + /// ``` + fn sample_iter(self, rng: R) -> Iter + where + R: Rng, + Self: Sized, + { + Iter { + distr: self, + rng, + phantom: core::marker::PhantomData, + } + } + + /// Map sampled values to type `S` + /// + /// # Example + /// + /// ``` + /// use rand::distr::{Distribution, Uniform}; + /// + /// let die = Uniform::new_inclusive(1, 6).unwrap(); + /// let even_number = die.map(|num| num % 2 == 0); + /// while !even_number.sample(&mut rand::rng()) { + /// println!("Still odd; rolling again!"); + /// } + /// ``` + fn map(self, func: F) -> Map + where + F: Fn(T) -> S, + Self: Sized, + { + Map { + distr: self, + func, + phantom: core::marker::PhantomData, + } + } +} + +impl + ?Sized> Distribution for &D { + fn sample(&self, rng: &mut R) -> T { + (*self).sample(rng) + } +} + +/// An iterator over a [`Distribution`] +/// +/// This iterator yields random values of type `T` with distribution `D` +/// from a random generator of type `R`. +/// +/// Construct this `struct` using [`Distribution::sample_iter`] or +/// [`Rng::sample_iter`]. It is also used by [`Rng::random_iter`] and +/// [`crate::random_iter`]. +#[derive(Debug)] +pub struct Iter { + distr: D, + rng: R, + phantom: core::marker::PhantomData, +} + +impl Iterator for Iter +where + D: Distribution, + R: Rng, +{ + type Item = T; + + #[inline(always)] + fn next(&mut self) -> Option { + // Here, self.rng may be a reference, but we must take &mut anyway. + // Even if sample could take an R: Rng by value, we would need to do this + // since Rng is not copyable and we cannot enforce that this is "reborrowable". + Some(self.distr.sample(&mut self.rng)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::MAX, None) + } +} + +impl iter::FusedIterator for Iter +where + D: Distribution, + R: Rng, +{ +} + +/// A [`Distribution`] which maps sampled values to type `S` +/// +/// This `struct` is created by the [`Distribution::map`] method. +/// See its documentation for more. +#[derive(Debug)] +pub struct Map { + distr: D, + func: F, + phantom: core::marker::PhantomData S>, +} + +impl Distribution for Map +where + D: Distribution, + F: Fn(T) -> S, +{ + fn sample(&self, rng: &mut R) -> S { + (self.func)(self.distr.sample(rng)) + } +} + +/// Sample or extend a [`String`] +/// +/// Helper methods to extend a [`String`] or sample a new [`String`]. +#[cfg(feature = "alloc")] +pub trait SampleString { + /// Append `len` random chars to `string` + /// + /// Note: implementations may leave `string` with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. + fn append_string(&self, rng: &mut R, string: &mut String, len: usize); + + /// Generate a [`String`] of `len` random chars + /// + /// Note: implementations may leave the string with excess capacity. If this + /// is undesirable, consider calling [`String::shrink_to_fit`] after this + /// method. + #[inline] + fn sample_string(&self, rng: &mut R, len: usize) -> String { + let mut s = String::new(); + self.append_string(rng, &mut s, len); + s + } +} + +#[cfg(test)] +mod tests { + use crate::distr::{Distribution, Uniform}; + use crate::Rng; + + #[test] + fn test_distributions_iter() { + use crate::distr::Open01; + let mut rng = crate::test::rng(210); + let distr = Open01; + let mut iter = Distribution::::sample_iter(distr, &mut rng); + let mut sum: f32 = 0.; + for _ in 0..100 { + sum += iter.next().unwrap(); + } + assert!(0. < sum && sum < 100.); + } + + #[test] + fn test_distributions_map() { + let dist = Uniform::new_inclusive(0, 5).unwrap().map(|val| val + 15); + + let mut rng = crate::test::rng(212); + let val = dist.sample(&mut rng); + assert!((15..=20).contains(&val)); + } + + #[test] + fn test_make_an_iter() { + fn ten_dice_rolls_other_than_five(rng: &mut R) -> impl Iterator + '_ { + Uniform::new_inclusive(1, 6) + .unwrap() + .sample_iter(rng) + .filter(|x| *x != 5) + .take(10) + } + + let mut rng = crate::test::rng(211); + let mut count = 0; + for val in ten_dice_rolls_other_than_five(&mut rng) { + assert!((1..=6).contains(&val) && val != 5); + count += 1; + } + assert_eq!(count, 10); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_dist_string() { + use crate::distr::{Alphanumeric, SampleString, StandardUniform}; + use core::str; + let mut rng = crate::test::rng(213); + + let s1 = Alphanumeric.sample_string(&mut rng, 20); + assert_eq!(s1.len(), 20); + assert_eq!(str::from_utf8(s1.as_bytes()), Ok(s1.as_str())); + + let s2 = StandardUniform.sample_string(&mut rng, 20); + assert_eq!(s2.chars().count(), 20); + assert_eq!(str::from_utf8(s2.as_bytes()), Ok(s2.as_str())); + } +} diff --git a/src/distributions/float.rs b/src/distr/float.rs similarity index 58% rename from src/distributions/float.rs rename to src/distr/float.rs index 733a40394dd..ec380b4bd4d 100644 --- a/src/distributions/float.rs +++ b/src/distr/float.rs @@ -8,14 +8,15 @@ //! Basic floating-point number distributions -use crate::distributions::utils::FloatSIMDUtils; -use crate::distributions::{Distribution, Standard}; +use crate::distr::utils::{FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; +use crate::distr::{Distribution, StandardUniform}; use crate::Rng; use core::mem; -#[cfg(feature = "simd_support")] use packed_simd::*; +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A distribution to sample floating point numbers uniformly in the half-open /// interval `(0, 1]`, i.e. including 1 but not 0. @@ -25,24 +26,24 @@ use serde::{Serialize, Deserialize}; /// 53 most significant bits of a `u64` are used. The conversion uses the /// multiplicative method. /// -/// See also: [`Standard`] which samples from `[0, 1)`, [`Open01`] +/// See also: [`StandardUniform`] which samples from `[0, 1)`, [`Open01`] /// which samples from `(0, 1)` and [`Uniform`] which samples from arbitrary /// ranges. /// /// # Example /// ``` -/// use rand::{thread_rng, Rng}; -/// use rand::distributions::OpenClosed01; +/// use rand::Rng; +/// use rand::distr::OpenClosed01; /// -/// let val: f32 = thread_rng().sample(OpenClosed01); +/// let val: f32 = rand::rng().sample(OpenClosed01); /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: crate::distributions::Standard -/// [`Open01`]: crate::distributions::Open01 -/// [`Uniform`]: crate::distributions::uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`Open01`]: crate::distr::Open01 +/// [`Uniform`]: crate::distr::uniform::Uniform +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct OpenClosed01; /// A distribution to sample floating point numbers uniformly in the open @@ -52,33 +53,32 @@ pub struct OpenClosed01; /// the 23 most significant random bits of an `u32` are used, for `f64` 52 from /// an `u64`. The conversion uses a transmute-based method. /// -/// See also: [`Standard`] which samples from `[0, 1)`, [`OpenClosed01`] +/// See also: [`StandardUniform`] which samples from `[0, 1)`, [`OpenClosed01`] /// which samples from `(0, 1]` and [`Uniform`] which samples from arbitrary /// ranges. /// /// # Example /// ``` -/// use rand::{thread_rng, Rng}; -/// use rand::distributions::Open01; +/// use rand::Rng; +/// use rand::distr::Open01; /// -/// let val: f32 = thread_rng().sample(Open01); +/// let val: f32 = rand::rng().sample(Open01); /// println!("f32 from (0, 1): {}", val); /// ``` /// -/// [`Standard`]: crate::distributions::Standard -/// [`OpenClosed01`]: crate::distributions::OpenClosed01 -/// [`Uniform`]: crate::distributions::uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`OpenClosed01`]: crate::distr::OpenClosed01 +/// [`Uniform`]: crate::distr::uniform::Uniform +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Open01; - // This trait is needed by both this lib and rand_distr hence is a hidden export #[doc(hidden)] pub trait IntoFloat { type F; - /// Helper method to combine the fraction and a contant exponent into a + /// Helper method to combine the fraction and a constant exponent into a /// float. /// /// Only the least significant bits of `self` may be set, 23 for `f32` and @@ -90,8 +90,9 @@ pub trait IntoFloat { } macro_rules! float_impls { - ($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, + ($($meta:meta)?, $ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty, $fraction_bits:expr, $exponent_bias:expr) => { + $(#[cfg($meta)])? impl IntoFloat for $uty { type F = $ty; #[inline(always)] @@ -99,112 +100,121 @@ macro_rules! float_impls { // The exponent is encoded using an offset-binary representation let exponent_bits: $u_scalar = (($exponent_bias + exponent) as $u_scalar) << $fraction_bits; - $ty::from_bits(self | exponent_bits) + $ty::from_bits(self | $uty::splat(exponent_bits)) } } - impl Distribution<$ty> for Standard { + $(#[cfg($meta)])? + impl Distribution<$ty> for StandardUniform { fn sample(&self, rng: &mut R) -> $ty { // Multiply-based method; 24/53 random bits; [0, 1) interval. // We use the most significant bits because for simple RNGs // those are usually more random. - let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; let precision = $fraction_bits + 1; let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); - let value: $uty = rng.gen(); - let value = value >> (float_size - precision); - scale * $ty::cast_from_int(value) + let value: $uty = rng.random(); + let value = value >> $uty::splat(float_size - precision); + $ty::splat(scale) * $ty::cast_from_int(value) } } + $(#[cfg($meta)])? impl Distribution<$ty> for OpenClosed01 { fn sample(&self, rng: &mut R) -> $ty { // Multiply-based method; 24/53 random bits; (0, 1] interval. // We use the most significant bits because for simple RNGs // those are usually more random. - let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; let precision = $fraction_bits + 1; let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar); - let value: $uty = rng.gen(); - let value = value >> (float_size - precision); + let value: $uty = rng.random(); + let value = value >> $uty::splat(float_size - precision); // Add 1 to shift up; will not overflow because of right-shift: - scale * $ty::cast_from_int(value + 1) + $ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1)) } } + $(#[cfg($meta)])? impl Distribution<$ty> for Open01 { fn sample(&self, rng: &mut R) -> $ty { // Transmute-based method; 23/52 random bits; (0, 1) interval. // We use the most significant bits because for simple RNGs // those are usually more random. - use core::$f_scalar::EPSILON; - let float_size = mem::size_of::<$f_scalar>() as u32 * 8; + let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; - let value: $uty = rng.gen(); - let fraction = value >> (float_size - $fraction_bits); - fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0) + let value: $uty = rng.random(); + let fraction = value >> $uty::splat(float_size - $fraction_bits); + fraction.into_float_with_exponent(0) - $ty::splat(1.0 - $f_scalar::EPSILON / 2.0) } } } } -float_impls! { f32, u32, f32, u32, 23, 127 } -float_impls! { f64, u64, f64, u64, 52, 1023 } +float_impls! { , f32, u32, f32, u32, 23, 127 } +float_impls! { , f64, u64, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f32x2, u32x2, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x2, u32x2, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x4, u32x4, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x4, u32x4, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x8, u32x8, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x8, u32x8, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f32x16, u32x16, f32, u32, 23, 127 } +float_impls! { feature = "simd_support", f32x16, u32x16, f32, u32, 23, 127 } #[cfg(feature = "simd_support")] -float_impls! { f64x2, u64x2, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x2, u64x2, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f64x4, u64x4, f64, u64, 52, 1023 } +float_impls! { feature = "simd_support", f64x4, u64x4, f64, u64, 52, 1023 } #[cfg(feature = "simd_support")] -float_impls! { f64x8, u64x8, f64, u64, 52, 1023 } - +float_impls! { feature = "simd_support", f64x8, u64x8, f64, u64, 52, 1023 } #[cfg(test)] mod tests { use super::*; use crate::rngs::mock::StepRng; - const EPSILON32: f32 = ::core::f32::EPSILON; - const EPSILON64: f64 = ::core::f64::EPSILON; + const EPSILON32: f32 = f32::EPSILON; + const EPSILON64: f64 = f64::EPSILON; macro_rules! test_f32 { ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { #[test] fn $fnn() { - // Standard + let two = $ty::splat(2.0); + + // StandardUniform let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.gen::<$ty>(), $ZERO); + assert_eq!(zeros.random::<$ty>(), $ZERO); let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); - assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0); + assert_eq!(one.random::<$ty>(), $EPSILON / two); let mut max = StepRng::new(!0, 0); - assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0); + assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0); + assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -222,29 +232,37 @@ mod tests { ($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => { #[test] fn $fnn() { - // Standard + let two = $ty::splat(2.0); + + // StandardUniform let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.gen::<$ty>(), $ZERO); + assert_eq!(zeros.random::<$ty>(), $ZERO); let mut one = StepRng::new(1 << 11, 0); - assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0); + assert_eq!(one.random::<$ty>(), $EPSILON / two); let mut max = StepRng::new(!0, 0); - assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0); + assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 11, 0); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0); + assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 let mut zeros = StepRng::new(0, 0); - assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0); + assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); let mut one = StepRng::new(1 << 12, 0); - assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0); + assert_eq!( + one.sample::<$ty, _>(Open01), + $EPSILON / two * $ty::splat(3.0) + ); let mut max = StepRng::new(!0, 0); - assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0); + assert_eq!( + max.sample::<$ty, _>(Open01), + $ty::splat(1.0) - $EPSILON / two + ); } }; } @@ -259,36 +277,42 @@ mod tests { #[test] fn value_stability() { fn test_samples>( - distr: &D, zero: T, expected: &[T], + distr: &D, + zero: T, + expected: &[T], ) { let mut rng = crate::test::rng(0x6f44f5646c2a7334); let mut buf = [zero; 3]; for x in &mut buf { - *x = rng.sample(&distr); + *x = rng.sample(distr); } assert_eq!(&buf, expected); } - test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]); - test_samples(&Standard, 0f64, &[ - 0.7346051961657583, - 0.20298547462974248, - 0.8166436635290655, - ]); + test_samples( + &StandardUniform, + 0f32, + &[0.0035963655, 0.7346052, 0.09778172], + ); + test_samples( + &StandardUniform, + 0f64, + &[0.7346051961657583, 0.20298547462974248, 0.8166436635290655], + ); test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]); - test_samples(&OpenClosed01, 0f64, &[ - 0.7346051961657584, - 0.2029854746297426, - 0.8166436635290656, - ]); + test_samples( + &OpenClosed01, + 0f64, + &[0.7346051961657584, 0.2029854746297426, 0.8166436635290656], + ); test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]); - test_samples(&Open01, 0f64, &[ - 0.7346051961657584, - 0.20298547462974248, - 0.8166436635290656, - ]); + test_samples( + &Open01, + 0f64, + &[0.7346051961657584, 0.20298547462974248, 0.8166436635290656], + ); #[cfg(feature = "simd_support")] { @@ -296,17 +320,25 @@ mod tests { // non-SIMD types; we assume this pattern continues across all // SIMD types. - test_samples(&Standard, f32x2::new(0.0, 0.0), &[ - f32x2::new(0.0035963655, 0.7346052), - f32x2::new(0.09778172, 0.20298547), - f32x2::new(0.34296435, 0.81664366), - ]); - - test_samples(&Standard, f64x2::new(0.0, 0.0), &[ - f64x2::new(0.7346051961657583, 0.20298547462974248), - f64x2::new(0.8166436635290655, 0.7423708925400552), - f64x2::new(0.16387782224016323, 0.9087068770169618), - ]); + test_samples( + &StandardUniform, + f32x2::from([0.0, 0.0]), + &[ + f32x2::from([0.0035963655, 0.7346052]), + f32x2::from([0.09778172, 0.20298547]), + f32x2::from([0.34296435, 0.81664366]), + ], + ); + + test_samples( + &StandardUniform, + f64x2::from([0.0, 0.0]), + &[ + f64x2::from([0.7346051961657583, 0.20298547462974248]), + f64x2::from([0.8166436635290655, 0.7423708925400552]), + f64x2::from([0.16387782224016323, 0.9087068770169618]), + ], + ); } } } diff --git a/src/distr/integer.rs b/src/distr/integer.rs new file mode 100644 index 00000000000..d0040e69e7e --- /dev/null +++ b/src/distr/integer.rs @@ -0,0 +1,296 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The implementations of the `StandardUniform` distribution for integer types. + +use crate::distr::{Distribution, StandardUniform}; +use crate::Rng; +#[cfg(all(target_arch = "x86", feature = "simd_support"))] +use core::arch::x86::__m512i; +#[cfg(target_arch = "x86")] +use core::arch::x86::{__m128i, __m256i}; +#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] +use core::arch::x86_64::__m512i; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::{__m128i, __m256i}; +use core::num::{ + NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroU128, NonZeroU16, + NonZeroU32, NonZeroU64, NonZeroU8, +}; +#[cfg(feature = "simd_support")] +use core::simd::*; + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u8 { + rng.next_u32() as u8 + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u16 { + rng.next_u32() as u16 + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u32 { + rng.next_u32() + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + rng.next_u64() + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> u128 { + // Use LE; we explicitly generate one value before the next. + let x = u128::from(rng.next_u64()); + let y = u128::from(rng.next_u64()); + (y << 64) | x + } +} + +macro_rules! impl_int_from_uint { + ($ty:ty, $uty:ty) => { + impl Distribution<$ty> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> $ty { + rng.random::<$uty>() as $ty + } + } + }; +} + +impl_int_from_uint! { i8, u8 } +impl_int_from_uint! { i16, u16 } +impl_int_from_uint! { i32, u32 } +impl_int_from_uint! { i64, u64 } +impl_int_from_uint! { i128, u128 } + +macro_rules! impl_nzint { + ($ty:ty, $new:path) => { + impl Distribution<$ty> for StandardUniform { + fn sample(&self, rng: &mut R) -> $ty { + loop { + if let Some(nz) = $new(rng.random()) { + break nz; + } + } + } + } + }; +} + +impl_nzint!(NonZeroU8, NonZeroU8::new); +impl_nzint!(NonZeroU16, NonZeroU16::new); +impl_nzint!(NonZeroU32, NonZeroU32::new); +impl_nzint!(NonZeroU64, NonZeroU64::new); +impl_nzint!(NonZeroU128, NonZeroU128::new); + +impl_nzint!(NonZeroI8, NonZeroI8::new); +impl_nzint!(NonZeroI16, NonZeroI16::new); +impl_nzint!(NonZeroI32, NonZeroI32::new); +impl_nzint!(NonZeroI64, NonZeroI64::new); +impl_nzint!(NonZeroI128, NonZeroI128::new); + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +macro_rules! x86_intrinsic_impl { + ($meta:meta, $($intrinsic:ident),+) => {$( + #[cfg($meta)] + impl Distribution<$intrinsic> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> $intrinsic { + // On proper hardware, this should compile to SIMD instructions + // Verified on x86 Haswell with __m128i, __m256i + let mut buf = [0_u8; core::mem::size_of::<$intrinsic>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + zerocopy::transmute!(buf) + } + } + )+}; +} + +#[cfg(feature = "simd_support")] +macro_rules! simd_impl { + ($($ty:ty),+) => {$( + /// Requires nightly Rust and the [`simd_support`] feature + /// + /// [`simd_support`]: https://github.com/rust-random/rand#crate-features + #[cfg(feature = "simd_support")] + impl Distribution> for StandardUniform + where + LaneCount: SupportedLaneCount, + { + #[inline] + fn sample(&self, rng: &mut R) -> Simd<$ty, LANES> { + let mut vec = Simd::default(); + rng.fill(vec.as_mut_array().as_mut_slice()); + vec + } + } + )+}; +} + +#[cfg(feature = "simd_support")] +simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64); + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +x86_intrinsic_impl!( + any(target_arch = "x86", target_arch = "x86_64"), + __m128i, + __m256i +); +#[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + feature = "simd_support" +))] +x86_intrinsic_impl!( + all( + any(target_arch = "x86", target_arch = "x86_64"), + feature = "simd_support" + ), + __m512i +); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_integers() { + let mut rng = crate::test::rng(806); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[test] + fn x86_integers() { + let mut rng = crate::test::rng(807); + + rng.sample::<__m128i, _>(StandardUniform); + rng.sample::<__m256i, _>(StandardUniform); + #[cfg(feature = "simd_support")] + rng.sample::<__m512i, _>(StandardUniform); + } + + #[test] + fn value_stability() { + fn test_samples(zero: T, expected: &[T]) + where + StandardUniform: Distribution, + { + let mut rng = crate::test::rng(807); + let mut buf = [zero; 3]; + for x in &mut buf { + *x = rng.sample(StandardUniform); + } + assert_eq!(&buf, expected); + } + + test_samples(0u8, &[9, 247, 111]); + test_samples(0u16, &[32265, 42999, 38255]); + test_samples(0u32, &[2220326409, 2575017975, 2018088303]); + test_samples( + 0u64, + &[ + 11059617991457472009, + 16096616328739788143, + 1487364411147516184, + ], + ); + test_samples( + 0u128, + &[ + 296930161868957086625409848350820761097, + 145644820879247630242265036535529306392, + 111087889832015897993126088499035356354, + ], + ); + + test_samples(0i8, &[9, -9, 111]); + // Skip further i* types: they are simple reinterpretation of u* samples + + #[cfg(feature = "simd_support")] + { + // We only test a sub-set of types here and make assumptions about the rest. + + test_samples( + u8x4::default(), + &[ + u8x4::from([9, 126, 87, 132]), + u8x4::from([247, 167, 123, 153]), + u8x4::from([111, 149, 73, 120]), + ], + ); + test_samples( + u8x8::default(), + &[ + u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]), + u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]), + u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]), + ], + ); + + test_samples( + i64x8::default(), + &[ + i64x8::from([ + -7387126082252079607, + -2350127744969763473, + 1487364411147516184, + 7895421560427121838, + 602190064936008898, + 6022086574635100741, + -5080089175222015595, + -4066367846667249123, + ]), + i64x8::from([ + 9180885022207963908, + 3095981199532211089, + 6586075293021332726, + 419343203796414657, + 3186951873057035255, + 5287129228749947252, + 444726432079249540, + -1587028029513790706, + ]), + i64x8::from([ + 6075236523189346388, + 1351763722368165432, + -6192309979959753740, + -7697775502176768592, + -4482022114172078123, + 7522501477800909500, + -1837258847956201231, + -586926753024886735, + ]), + ], + ); + } + } +} diff --git a/src/distr/mod.rs b/src/distr/mod.rs new file mode 100644 index 00000000000..10016119ba2 --- /dev/null +++ b/src/distr/mod.rs @@ -0,0 +1,210 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013-2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating random samples from probability distributions +//! +//! This module is the home of the [`Distribution`] trait and several of its +//! implementations. It is the workhorse behind some of the convenient +//! functionality of the [`Rng`] trait, e.g. [`Rng::random`] and of course +//! [`Rng::sample`]. +//! +//! Abstractly, a [probability distribution] describes the probability of +//! occurrence of each value in its sample space. +//! +//! More concretely, an implementation of `Distribution` for type `X` is an +//! algorithm for choosing values from the sample space (a subset of `T`) +//! according to the distribution `X` represents, using an external source of +//! randomness (an RNG supplied to the `sample` function). +//! +//! A type `X` may implement `Distribution` for multiple types `T`. +//! Any type implementing [`Distribution`] is stateless (i.e. immutable), +//! but it may have internal parameters set at construction time (for example, +//! [`Uniform`] allows specification of its sample space as a range within `T`). +//! +//! +//! # The Standard Uniform distribution +//! +//! The [`StandardUniform`] distribution is important to mention. This is the +//! distribution used by [`Rng::random`] and represents the "default" way to +//! produce a random value for many different types, including most primitive +//! types, tuples, arrays, and a few derived types. See the documentation of +//! [`StandardUniform`] for more details. +//! +//! Implementing [`Distribution`] for [`StandardUniform`] for user types `T` makes it +//! possible to generate type `T` with [`Rng::random`], and by extension also +//! with the [`random`] function. +//! +//! ## Other standard uniform distributions +//! +//! [`Alphanumeric`] is a simple distribution to sample random letters and +//! numbers of the `char` type; in contrast [`StandardUniform`] may sample any valid +//! `char`. +//! +//! For floats (`f32`, `f64`), [`StandardUniform`] samples from `[0, 1)`. Also +//! provided are [`Open01`] (samples from `(0, 1)`) and [`OpenClosed01`] +//! (samples from `(0, 1]`). No option is provided to sample from `[0, 1]`; it +//! is suggested to use one of the above half-open ranges since the failure to +//! sample a value which would have a low chance of being sampled anyway is +//! rarely an issue in practice. +//! +//! # Parameterized Uniform distributions +//! +//! The [`Uniform`] distribution provides uniform sampling over a specified +//! range on a subset of the types supported by the above distributions. +//! +//! Implementations support single-value-sampling via +//! [`Rng::random_range(Range)`](Rng::random_range). +//! Where a fixed (non-`const`) range will be sampled many times, it is likely +//! faster to pre-construct a [`Distribution`] object using +//! [`Uniform::new`], [`Uniform::new_inclusive`] or `From`. +//! +//! # Non-uniform sampling +//! +//! Sampling a simple true/false outcome with a given probability has a name: +//! the [`Bernoulli`] distribution (this is used by [`Rng::random_bool`]). +//! +//! For weighted sampling of discrete values see the [`weighted`] module. +//! +//! This crate no longer includes other non-uniform distributions; instead +//! it is recommended that you use either [`rand_distr`] or [`statrs`]. +//! +//! +//! [probability distribution]: https://en.wikipedia.org/wiki/Probability_distribution +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +//! [`random`]: crate::random +//! [`rand_distr`]: https://crates.io/crates/rand_distr +//! [`statrs`]: https://crates.io/crates/statrs + +mod bernoulli; +mod distribution; +mod float; +mod integer; +mod other; +mod utils; + +#[doc(hidden)] +pub mod hidden_export { + pub use super::float::IntoFloat; // used by rand_distr +} +pub mod slice; +pub mod uniform; +#[cfg(feature = "alloc")] +pub mod weighted; + +pub use self::bernoulli::{Bernoulli, BernoulliError}; +#[cfg(feature = "alloc")] +pub use self::distribution::SampleString; +pub use self::distribution::{Distribution, Iter, Map}; +pub use self::float::{Open01, OpenClosed01}; +pub use self::other::Alphanumeric; +#[doc(inline)] +pub use self::uniform::Uniform; + +#[allow(unused)] +use crate::Rng; + +/// The Standard Uniform distribution +/// +/// This [`Distribution`] is the *standard* parameterization of [`Uniform`]. Bounds +/// are selected according to the output type. +/// +/// Assuming the provided `Rng` is well-behaved, these implementations +/// generate values with the following ranges and distributions: +/// +/// * Integers (`i8`, `i32`, `u64`, etc.) are uniformly distributed +/// over the whole range of the type (thus each possible value may be sampled +/// with equal probability). +/// * `char` is uniformly distributed over all Unicode scalar values, i.e. all +/// code points in the range `0...0x10_FFFF`, except for the range +/// `0xD800...0xDFFF` (the surrogate code points). This includes +/// unassigned/reserved code points. +/// For some uses, the [`Alphanumeric`] distribution will be more appropriate. +/// * `bool` samples `false` or `true`, each with probability 0.5. +/// * Floating point types (`f32` and `f64`) are uniformly distributed in the +/// half-open range `[0, 1)`. See also the [notes below](#floating-point-implementation). +/// * Wrapping integers ([`Wrapping`]), besides the type identical to their +/// normal integer variants. +/// * Non-zero integers ([`NonZeroU8`]), which are like their normal integer +/// variants but cannot sample zero. +/// +/// The `StandardUniform` distribution also supports generation of the following +/// compound types where all component types are supported: +/// +/// * Tuples (up to 12 elements): each element is sampled sequentially and +/// independently (thus, assuming a well-behaved RNG, there is no correlation +/// between elements). +/// * Arrays `[T; n]` where `T` is supported. Each element is sampled +/// sequentially and independently. Note that for small `T` this usually +/// results in the RNG discarding random bits; see also [`Rng::fill`] which +/// offers a more efficient approach to filling an array of integer types +/// with random data. +/// * SIMD types (requires [`simd_support`] feature) like x86's [`__m128i`] +/// and `std::simd`'s [`u32x4`], [`f32x4`] and [`mask32x4`] types are +/// effectively arrays of integer or floating-point types. Each lane is +/// sampled independently, potentially with more efficient random-bit-usage +/// (and a different resulting value) than would be achieved with sequential +/// sampling (as with the array types above). +/// +/// ## Custom implementations +/// +/// The [`StandardUniform`] distribution may be implemented for user types as follows: +/// +/// ``` +/// # #![allow(dead_code)] +/// use rand::Rng; +/// use rand::distr::{Distribution, StandardUniform}; +/// +/// struct MyF32 { +/// x: f32, +/// } +/// +/// impl Distribution for StandardUniform { +/// fn sample(&self, rng: &mut R) -> MyF32 { +/// MyF32 { x: rng.random() } +/// } +/// } +/// ``` +/// +/// ## Example usage +/// ``` +/// use rand::prelude::*; +/// use rand::distr::StandardUniform; +/// +/// let val: f32 = rand::rng().sample(StandardUniform); +/// println!("f32 from [0, 1): {}", val); +/// ``` +/// +/// # Floating point implementation +/// The floating point implementations for `StandardUniform` generate a random value in +/// the half-open interval `[0, 1)`, i.e. including 0 but not 1. +/// +/// All values that can be generated are of the form `n * ε/2`. For `f32` +/// the 24 most significant random bits of a `u32` are used and for `f64` the +/// 53 most significant bits of a `u64` are used. The conversion uses the +/// multiplicative method: `(rng.gen::<$uty>() >> N) as $ty * (ε/2)`. +/// +/// See also: [`Open01`] which samples from `(0, 1)`, [`OpenClosed01`] which +/// samples from `(0, 1]` and `Rng::random_range(0..1)` which also samples from +/// `[0, 1)`. Note that `Open01` uses transmute-based methods which yield 1 bit +/// less precision but may perform faster on some architectures (on modern Intel +/// CPUs all methods have approximately equal performance). +/// +/// [`Uniform`]: uniform::Uniform +/// [`Wrapping`]: std::num::Wrapping +/// [`NonZeroU8`]: std::num::NonZeroU8 +/// [`__m128i`]: https://doc.rust-lang.org/core/arch/x86/struct.__m128i.html +/// [`u32x4`]: std::simd::u32x4 +/// [`f32x4`]: std::simd::f32x4 +/// [`mask32x4`]: std::simd::mask32x4 +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[derive(Clone, Copy, Debug, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct StandardUniform; diff --git a/src/distr/other.rs b/src/distr/other.rs new file mode 100644 index 00000000000..9890bdafe6d --- /dev/null +++ b/src/distr/other.rs @@ -0,0 +1,375 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The implementations of the `StandardUniform` distribution for other built-in types. + +#[cfg(feature = "alloc")] +use alloc::string::String; +use core::char; +use core::num::Wrapping; + +#[cfg(feature = "alloc")] +use crate::distr::SampleString; +use crate::distr::{Distribution, StandardUniform, Uniform}; +use crate::Rng; + +use core::mem::{self, MaybeUninit}; +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +// ----- Sampling distributions ----- + +/// Sample a `u8`, uniformly distributed over ASCII letters and numbers: +/// a-z, A-Z and 0-9. +/// +/// # Example +/// +/// ``` +/// use rand::Rng; +/// use rand::distr::Alphanumeric; +/// +/// let mut rng = rand::rng(); +/// let chars: String = (0..7).map(|_| rng.sample(Alphanumeric) as char).collect(); +/// println!("Random chars: {}", chars); +/// ``` +/// +/// The [`SampleString`] trait provides an easier method of generating +/// a random [`String`], and offers more efficient allocation: +/// ``` +/// use rand::distr::{Alphanumeric, SampleString}; +/// let string = Alphanumeric.sample_string(&mut rand::rng(), 16); +/// println!("Random string: {}", string); +/// ``` +/// +/// # Passwords +/// +/// Users sometimes ask whether it is safe to use a string of random characters +/// as a password. In principle, all RNGs in Rand implementing `CryptoRng` are +/// suitable as a source of randomness for generating passwords (if they are +/// properly seeded), but it is more conservative to only use randomness +/// directly from the operating system via the `getrandom` crate, or the +/// corresponding bindings of a crypto library. +/// +/// When generating passwords or keys, it is important to consider the threat +/// model and in some cases the memorability of the password. This is out of +/// scope of the Rand project, and therefore we defer to the following +/// references: +/// +/// - [Wikipedia article on Password Strength](https://en.wikipedia.org/wiki/Password_strength) +/// - [Diceware for generating memorable passwords](https://en.wikipedia.org/wiki/Diceware) +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Alphanumeric; + +// ----- Implementations of distributions ----- + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> char { + // A valid `char` is either in the interval `[0, 0xD800)` or + // `(0xDFFF, 0x11_0000)`. All `char`s must therefore be in + // `[0, 0x11_0000)` but not in the "gap" `[0xD800, 0xDFFF]` which is + // reserved for surrogates. This is the size of that gap. + const GAP_SIZE: u32 = 0xDFFF - 0xD800 + 1; + + // Uniform::new(0, 0x11_0000 - GAP_SIZE) can also be used, but it + // seemed slower. + let range = Uniform::new(GAP_SIZE, 0x11_0000).unwrap(); + + let mut n = range.sample(rng); + if n <= 0xDFFF { + n -= GAP_SIZE; + } + unsafe { char::from_u32_unchecked(n) } + } +} + +#[cfg(feature = "alloc")] +impl SampleString for StandardUniform { + fn append_string(&self, rng: &mut R, s: &mut String, len: usize) { + // A char is encoded with at most four bytes, thus this reservation is + // guaranteed to be sufficient. We do not shrink_to_fit afterwards so + // that repeated usage on the same `String` buffer does not reallocate. + s.reserve(4 * len); + s.extend(Distribution::::sample_iter(self, rng).take(len)); + } +} + +impl Distribution for Alphanumeric { + fn sample(&self, rng: &mut R) -> u8 { + const RANGE: u32 = 26 + 26 + 10; + const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ + abcdefghijklmnopqrstuvwxyz\ + 0123456789"; + // We can pick from 62 characters. This is so close to a power of 2, 64, + // that we can do better than `Uniform`. Use a simple bitshift and + // rejection sampling. We do not use a bitmask, because for small RNGs + // the most significant bits are usually of higher quality. + loop { + let var = rng.next_u32() >> (32 - 6); + if var < RANGE { + return GEN_ASCII_STR_CHARSET[var as usize]; + } + } + } +} + +#[cfg(feature = "alloc")] +impl SampleString for Alphanumeric { + fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + unsafe { + let v = string.as_mut_vec(); + v.extend(self.sample_iter(rng).take(len)); + } + } +} + +impl Distribution for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> bool { + // We can compare against an arbitrary bit of an u32 to get a bool. + // Because the least significant bits of a lower quality RNG can have + // simple patterns, we compare against the most significant bit. This is + // easiest done using a sign test. + (rng.next_u32() as i32) < 0 + } +} + +/// Note that on some hardware like x86/64 mask operations like [`_mm_blendv_epi8`] +/// only care about a single bit. This means that you could use uniform random bits +/// directly: +/// +/// ```ignore +/// // this may be faster... +/// let x = unsafe { _mm_blendv_epi8(a.into(), b.into(), rng.random::<__m128i>()) }; +/// +/// // ...than this +/// let x = rng.random::().select(b, a); +/// ``` +/// +/// Since most bits are unused you could also generate only as many bits as you need, i.e.: +/// ``` +/// #![feature(portable_simd)] +/// use std::simd::prelude::*; +/// use rand::prelude::*; +/// let mut rng = rand::rng(); +/// +/// let x = u16x8::splat(rng.random::() as u16); +/// let mask = u16x8::splat(1) << u16x8::from([0, 1, 2, 3, 4, 5, 6, 7]); +/// let rand_mask = (x & mask).simd_eq(mask); +/// ``` +/// +/// [`_mm_blendv_epi8`]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_epi8&ig_expand=514/ +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[cfg(feature = "simd_support")] +impl Distribution> for StandardUniform +where + T: MaskElement + Default, + LaneCount: SupportedLaneCount, + StandardUniform: Distribution>, + Simd: SimdPartialOrd>, +{ + #[inline] + fn sample(&self, rng: &mut R) -> Mask { + // `MaskElement` must be a signed integer, so this is equivalent + // to the scalar `i32 < 0` method + let var = rng.random::>(); + var.simd_lt(Simd::default()) + } +} + +/// Implement `Distribution<(A, B, C, ...)> for StandardUniform`, using the list of +/// identifiers +macro_rules! tuple_impl { + ($($tyvar:ident)*) => { + impl< $($tyvar,)* > Distribution<($($tyvar,)*)> for StandardUniform + where $( + StandardUniform: Distribution< $tyvar >, + )* + { + #[inline] + fn sample(&self, rng: &mut R) -> ( $($tyvar,)* ) { + let out = ($( + // use the $tyvar's to get the appropriate number of + // repeats (they're not actually needed) + rng.random::<$tyvar>() + ,)*); + + // Suppress the unused variable warning for empty tuple + let _rng = rng; + + out + } + } + } +} + +/// Looping wrapper for `tuple_impl`. Given (A, B, C), it also generates +/// implementations for (A, B) and (A,) +macro_rules! tuple_impls { + ($($tyvar:ident)*) => {tuple_impls!{[] $($tyvar)*}}; + + ([$($prefix:ident)*] $head:ident $($tail:ident)*) => { + tuple_impl!{$($prefix)*} + tuple_impls!{[$($prefix)* $head] $($tail)*} + }; + + + ([$($prefix:ident)*]) => { + tuple_impl!{$($prefix)*} + }; + +} + +tuple_impls! {A B C D E F G H I J K L} + +impl Distribution<[T; N]> for StandardUniform +where + StandardUniform: Distribution, +{ + #[inline] + fn sample(&self, _rng: &mut R) -> [T; N] { + let mut buff: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; + + for elem in &mut buff { + *elem = MaybeUninit::new(_rng.random()); + } + + unsafe { mem::transmute_copy::<_, _>(&buff) } + } +} + +impl Distribution> for StandardUniform +where + StandardUniform: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> Wrapping { + Wrapping(rng.random()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::RngCore; + + #[test] + fn test_misc() { + let rng: &mut dyn RngCore = &mut crate::test::rng(820); + + rng.sample::(StandardUniform); + rng.sample::(StandardUniform); + } + + #[cfg(feature = "alloc")] + #[test] + fn test_chars() { + use core::iter; + let mut rng = crate::test::rng(805); + + // Test by generating a relatively large number of chars, so we also + // take the rejection sampling path. + let word: String = iter::repeat(()) + .map(|()| rng.random::()) + .take(1000) + .collect(); + assert!(!word.is_empty()); + } + + #[test] + fn test_alphanumeric() { + let mut rng = crate::test::rng(806); + + // Test by generating a relatively large number of chars, so we also + // take the rejection sampling path. + let mut incorrect = false; + for _ in 0..100 { + let c: char = rng.sample(Alphanumeric).into(); + incorrect |= !c.is_ascii_alphanumeric(); + } + assert!(!incorrect); + } + + #[test] + fn value_stability() { + fn test_samples>( + distr: &D, + zero: T, + expected: &[T], + ) { + let mut rng = crate::test::rng(807); + let mut buf = [zero; 5]; + for x in &mut buf { + *x = rng.sample(distr); + } + assert_eq!(&buf, expected); + } + + test_samples( + &StandardUniform, + 'a', + &[ + '\u{8cdac}', + '\u{a346a}', + '\u{80120}', + '\u{ed692}', + '\u{35888}', + ], + ); + test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); + test_samples(&StandardUniform, false, &[true, true, false, true, false]); + test_samples( + &StandardUniform, + Wrapping(0i32), + &[ + Wrapping(-2074640887), + Wrapping(-1719949321), + Wrapping(2018088303), + Wrapping(-547181756), + Wrapping(838957336), + ], + ); + + // We test only sub-sets of tuple and array impls + test_samples(&StandardUniform, (), &[(), (), (), (), ()]); + test_samples( + &StandardUniform, + (false,), + &[(true,), (true,), (false,), (true,), (false,)], + ); + test_samples( + &StandardUniform, + (false, false), + &[ + (true, true), + (false, true), + (false, false), + (true, false), + (false, false), + ], + ); + + test_samples(&StandardUniform, [0u8; 0], &[[], [], [], [], []]); + test_samples( + &StandardUniform, + [0u8; 3], + &[ + [9, 247, 111], + [68, 24, 13], + [174, 19, 194], + [172, 69, 213], + [149, 207, 29], + ], + ); + } +} diff --git a/src/distr/slice.rs b/src/distr/slice.rs new file mode 100644 index 00000000000..07e243fec5d --- /dev/null +++ b/src/distr/slice.rs @@ -0,0 +1,167 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Distributions over slices + +use core::num::NonZeroUsize; + +use crate::distr::uniform::{UniformSampler, UniformUsize}; +use crate::distr::Distribution; +#[cfg(feature = "alloc")] +use alloc::string::String; + +/// A distribution to uniformly sample elements of a slice +/// +/// Like [`IndexedRandom::choose`], this uniformly samples elements of a slice +/// without modification of the slice (so called "sampling with replacement"). +/// This distribution object may be a little faster for repeated sampling (but +/// slower for small numbers of samples). +/// +/// ## Examples +/// +/// Since this is a distribution, [`Rng::sample_iter`] and +/// [`Distribution::sample_iter`] may be used, for example: +/// ``` +/// use rand::distr::{Distribution, slice::Choose}; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let vowels_dist = Choose::new(&vowels).unwrap(); +/// +/// // build a string of 10 vowels +/// let vowel_string: String = vowels_dist +/// .sample_iter(&mut rand::rng()) +/// .take(10) +/// .collect(); +/// +/// println!("{}", vowel_string); +/// assert_eq!(vowel_string.len(), 10); +/// assert!(vowel_string.chars().all(|c| vowels.contains(&c))); +/// ``` +/// +/// For a single sample, [`IndexedRandom::choose`] may be preferred: +/// ``` +/// use rand::seq::IndexedRandom; +/// +/// let vowels = ['a', 'e', 'i', 'o', 'u']; +/// let mut rng = rand::rng(); +/// +/// println!("{}", vowels.choose(&mut rng).unwrap()); +/// ``` +/// +/// [`IndexedRandom::choose`]: crate::seq::IndexedRandom::choose +/// [`Rng::sample_iter`]: crate::Rng::sample_iter +#[derive(Debug, Clone, Copy)] +pub struct Choose<'a, T> { + slice: &'a [T], + range: UniformUsize, + num_choices: NonZeroUsize, +} + +impl<'a, T> Choose<'a, T> { + /// Create a new `Choose` instance which samples uniformly from the slice. + /// + /// Returns error [`Empty`] if the slice is empty. + pub fn new(slice: &'a [T]) -> Result { + let num_choices = NonZeroUsize::new(slice.len()).ok_or(Empty)?; + + Ok(Self { + slice, + range: UniformUsize::new(0, num_choices.get()).unwrap(), + num_choices, + }) + } + + /// Returns the count of choices in this distribution + pub fn num_choices(&self) -> NonZeroUsize { + self.num_choices + } +} + +impl<'a, T> Distribution<&'a T> for Choose<'a, T> { + fn sample(&self, rng: &mut R) -> &'a T { + let idx = self.range.sample(rng); + + debug_assert!( + idx < self.slice.len(), + "Uniform::new(0, {}) somehow returned {}", + self.slice.len(), + idx + ); + + // Safety: at construction time, it was ensured that the slice was + // non-empty, and that the `Uniform` range produces values in range + // for the slice + unsafe { self.slice.get_unchecked(idx) } + } +} + +/// Error: empty slice +/// +/// This error is returned when [`Choose::new`] is given an empty slice. +#[derive(Debug, Clone, Copy)] +pub struct Empty; + +impl core::fmt::Display for Empty { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "Tried to create a `rand::distr::slice::Choose` with an empty slice" + ) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Empty {} + +#[cfg(feature = "alloc")] +impl super::SampleString for Choose<'_, char> { + fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + // Get the max char length to minimize extra space. + // Limit this check to avoid searching for long slice. + let max_char_len = if self.slice.len() < 200 { + self.slice + .iter() + .try_fold(1, |max_len, char| { + // When the current max_len is 4, the result max_char_len will be 4. + Some(max_len.max(char.len_utf8())).filter(|len| *len < 4) + }) + .unwrap_or(4) + } else { + 4 + }; + + // Split the extension of string to reuse the unused capacities. + // Skip the split for small length or only ascii slice. + let mut extend_len = if max_char_len == 1 || len < 100 { + len + } else { + len / 4 + }; + let mut remain_len = len; + while extend_len > 0 { + string.reserve(max_char_len * extend_len); + string.extend(self.sample_iter(&mut *rng).take(extend_len)); + remain_len -= extend_len; + extend_len = extend_len.min(remain_len); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use core::iter; + + #[test] + fn value_stability() { + let rng = crate::test::rng(651); + let slice = Choose::new(b"escaped emus explore extensively").unwrap(); + let expected = b"eaxee"; + assert!(iter::zip(slice.sample_iter(rng), expected).all(|(a, b)| a == b)); + } +} diff --git a/src/distr/uniform.rs b/src/distr/uniform.rs new file mode 100644 index 00000000000..b59fdbf790b --- /dev/null +++ b/src/distr/uniform.rs @@ -0,0 +1,622 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! A distribution uniformly sampling numbers within a given range. +//! +//! [`Uniform`] is the standard distribution to sample uniformly from a range; +//! e.g. `Uniform::new_inclusive(1, 6).unwrap()` can sample integers from 1 to 6, like a +//! standard die. [`Rng::random_range`] is implemented over [`Uniform`]. +//! +//! # Example usage +//! +//! ``` +//! use rand::Rng; +//! use rand::distr::Uniform; +//! +//! let mut rng = rand::rng(); +//! let side = Uniform::new(-10.0, 10.0).unwrap(); +//! +//! // sample between 1 and 10 points +//! for _ in 0..rng.random_range(1..=10) { +//! // sample a point from the square with sides -10 - 10 in two dimensions +//! let (x, y) = (rng.sample(side), rng.sample(side)); +//! println!("Point: {}, {}", x, y); +//! } +//! ``` +//! +//! # Extending `Uniform` to support a custom type +//! +//! To extend [`Uniform`] to support your own types, write a back-end which +//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] +//! helper trait to "register" your back-end. See the `MyF32` example below. +//! +//! At a minimum, the back-end needs to store any parameters needed for sampling +//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. +//! Those methods should include an assertion to check the range is valid (i.e. +//! `low < high`). The example below merely wraps another back-end. +//! +//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive` +//! functions use arguments of +//! type `SampleBorrow` to support passing in values by reference or +//! by value. In the implementation of these functions, you can choose to +//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose +//! to copy or clone the value, whatever is appropriate for your type. +//! +//! ``` +//! use rand::prelude::*; +//! use rand::distr::uniform::{Uniform, SampleUniform, +//! UniformSampler, UniformFloat, SampleBorrow, Error}; +//! +//! struct MyF32(f32); +//! +//! #[derive(Clone, Copy, Debug)] +//! struct UniformMyF32(UniformFloat); +//! +//! impl UniformSampler for UniformMyF32 { +//! type X = MyF32; +//! +//! fn new(low: B1, high: B2) -> Result +//! where B1: SampleBorrow + Sized, +//! B2: SampleBorrow + Sized +//! { +//! UniformFloat::::new(low.borrow().0, high.borrow().0).map(UniformMyF32) +//! } +//! fn new_inclusive(low: B1, high: B2) -> Result +//! where B1: SampleBorrow + Sized, +//! B2: SampleBorrow + Sized +//! { +//! UniformFloat::::new_inclusive(low.borrow().0, high.borrow().0).map(UniformMyF32) +//! } +//! fn sample(&self, rng: &mut R) -> Self::X { +//! MyF32(self.0.sample(rng)) +//! } +//! } +//! +//! impl SampleUniform for MyF32 { +//! type Sampler = UniformMyF32; +//! } +//! +//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); +//! let uniform = Uniform::new(low, high).unwrap(); +//! let x = uniform.sample(&mut rand::rng()); +//! ``` +//! +//! [`SampleUniform`]: crate::distr::uniform::SampleUniform +//! [`UniformSampler`]: crate::distr::uniform::UniformSampler +//! [`UniformInt`]: crate::distr::uniform::UniformInt +//! [`UniformFloat`]: crate::distr::uniform::UniformFloat +//! [`UniformDuration`]: crate::distr::uniform::UniformDuration +//! [`SampleBorrow::borrow`]: crate::distr::uniform::SampleBorrow::borrow + +#[path = "uniform_float.rs"] +mod float; +#[doc(inline)] +pub use float::UniformFloat; + +#[path = "uniform_int.rs"] +mod int; +#[doc(inline)] +pub use int::{UniformInt, UniformUsize}; + +#[path = "uniform_other.rs"] +mod other; +#[doc(inline)] +pub use other::{UniformChar, UniformDuration}; + +use core::fmt; +use core::ops::{Range, RangeInclusive, RangeTo, RangeToInclusive}; + +use crate::distr::Distribution; +use crate::{Rng, RngCore}; + +/// Error type returned from [`Uniform::new`] and `new_inclusive`. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `low > high`, or equal in case of exclusive range. + EmptyRange, + /// Input or range `high - low` is non-finite. Not relevant to integer types. + NonFinite, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::EmptyRange => "low > high (or equal if exclusive) in uniform distribution", + Error::NonFinite => "Non-finite range in uniform distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Sample values uniformly between two bounds. +/// +/// # Construction +/// +/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform +/// distribution sampling from the given `low` and `high` limits. `Uniform` may +/// also be constructed via [`TryFrom`] as in `Uniform::try_from(1..=6).unwrap()`. +/// +/// Constructors may do extra work up front to allow faster sampling of multiple +/// values. Where only a single sample is required it is suggested to use +/// [`Rng::random_range`] or one of the `sample_single` methods instead. +/// +/// When sampling from a constant range, many calculations can happen at +/// compile-time and all methods should be fast; for floating-point ranges and +/// the full range of integer types, this should have comparable performance to +/// the [`StandardUniform`](super::StandardUniform) distribution. +/// +/// # Provided implementations +/// +/// - `char` ([`UniformChar`]): samples a range over the implementation for `u32` +/// - `f32`, `f64` ([`UniformFloat`]): samples approximately uniformly within a +/// range; bias may be present in the least-significant bit of the significand +/// and the limits of the input range may be sampled even when an open +/// (exclusive) range is used +/// - Integer types ([`UniformInt`]) may show a small bias relative to the +/// expected uniform distribution of output. In the worst case, bias affects +/// 1 in `2^n` samples where n is 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 +/// (`i32` and `u32`), 64 (`i64` and `u64`), 128 (`i128` and `u128`). +/// The `unbiased` feature flag fixes this bias. +/// - `usize` ([`UniformUsize`]) is handled specially, using the `u32` +/// implementation where possible to enable portable results across 32-bit and +/// 64-bit CPU architectures. +/// - `Duration` ([`UniformDuration`]): samples a range over the implementation +/// for `u32` or `u64` +/// - SIMD types (requires [`simd_support`] feature) like x86's [`__m128i`] +/// and `std::simd`'s [`u32x4`], [`f32x4`] and [`mask32x4`] types are +/// effectively arrays of integer or floating-point types. Each lane is +/// sampled independently from its own range, potentially with more efficient +/// random-bit-usage than would be achieved with sequential sampling. +/// +/// # Example +/// +/// ``` +/// use rand::distr::{Distribution, Uniform}; +/// +/// let between = Uniform::try_from(10..10000).unwrap(); +/// let mut rng = rand::rng(); +/// let mut sum = 0; +/// for _ in 0..1000 { +/// sum += between.sample(&mut rng); +/// } +/// println!("{}", sum); +/// ``` +/// +/// For a single sample, [`Rng::random_range`] may be preferred: +/// +/// ``` +/// use rand::Rng; +/// +/// let mut rng = rand::rng(); +/// println!("{}", rng.random_range(0..10)); +/// ``` +/// +/// [`new`]: Uniform::new +/// [`new_inclusive`]: Uniform::new_inclusive +/// [`Rng::random_range`]: Rng::random_range +/// [`__m128i`]: https://doc.rust-lang.org/core/arch/x86/struct.__m128i.html +/// [`u32x4`]: std::simd::u32x4 +/// [`f32x4`]: std::simd::f32x4 +/// [`mask32x4`]: std::simd::mask32x4 +/// [`simd_support`]: https://github.com/rust-random/rand#crate-features +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", serde(bound(serialize = "X::Sampler: Serialize")))] +#[cfg_attr( + feature = "serde", + serde(bound(deserialize = "X::Sampler: Deserialize<'de>")) +)] +pub struct Uniform(X::Sampler); + +impl Uniform { + /// Create a new `Uniform` instance, which samples uniformly from the half + /// open range `[low, high)` (excluding `high`). + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is + /// non-finite. In release mode, only the range is checked. + pub fn new(low: B1, high: B2) -> Result, Error> + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + X::Sampler::new(low, high).map(Uniform) + } + + /// Create a new `Uniform` instance, which samples uniformly from the closed + /// range `[low, high]` (inclusive). + /// + /// Fails if `low > high`, or if `low`, `high` or the range `high - low` is + /// non-finite. In release mode, only the range is checked. + pub fn new_inclusive(low: B1, high: B2) -> Result, Error> + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + X::Sampler::new_inclusive(low, high).map(Uniform) + } +} + +impl Distribution for Uniform { + fn sample(&self, rng: &mut R) -> X { + self.0.sample(rng) + } +} + +/// Helper trait for creating objects using the correct implementation of +/// [`UniformSampler`] for the sampling type. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// [module documentation]: crate::distr::uniform +pub trait SampleUniform: Sized { + /// The `UniformSampler` implementation supporting type `X`. + type Sampler: UniformSampler; +} + +/// Helper trait handling actual uniform sampling. +/// +/// See the [module documentation] on how to implement [`Uniform`] range +/// sampling for a custom type. +/// +/// Implementation of [`sample_single`] is optional, and is only useful when +/// the implementation can be faster than `Self::new(low, high).sample(rng)`. +/// +/// [module documentation]: crate::distr::uniform +/// [`sample_single`]: UniformSampler::sample_single +pub trait UniformSampler: Sized { + /// The type sampled by this implementation. + type X; + + /// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`. + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// Usually users should not call this directly but prefer to use + /// [`Uniform::new`]. + fn new(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized; + + /// Construct self, with inclusive bounds `[low, high]`. + /// + /// Usually users should not call this directly but prefer to use + /// [`Uniform::new_inclusive`]. + fn new_inclusive(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized; + + /// Sample a value. + fn sample(&self, rng: &mut R) -> Self::X; + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and exclusive upper bound `[low, high)`. + /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// + /// By default this is implemented using + /// `UniformSampler::new(low, high).sample(rng)`. However, for some types + /// more optimal implementations for single usage may be provided via this + /// method (which is the case for integers and floats). + /// Results may not be identical. + /// + /// Note that to use this method in a generic context, the type needs to be + /// retrieved via `SampleUniform::Sampler` as follows: + /// ``` + /// use rand::distr::uniform::{SampleUniform, UniformSampler}; + /// # #[allow(unused)] + /// fn sample_from_range(lb: T, ub: T) -> T { + /// let mut rng = rand::rng(); + /// ::Sampler::sample_single(lb, ub, &mut rng).unwrap() + /// } + /// ``` + fn sample_single( + low: B1, + high: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let uniform: Self = UniformSampler::new(low, high)?; + Ok(uniform.sample(rng)) + } + + /// Sample a single value uniformly from a range with inclusive lower bound + /// and inclusive upper bound `[low, high]`. + /// + /// By default this is implemented using + /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for + /// some types more optimal implementations for single usage may be provided + /// via this method. + /// Results may not be identical. + fn sample_single_inclusive( + low: B1, + high: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let uniform: Self = UniformSampler::new_inclusive(low, high)?; + Ok(uniform.sample(rng)) + } +} + +impl TryFrom> for Uniform { + type Error = Error; + + fn try_from(r: Range) -> Result, Error> { + Uniform::new(r.start, r.end) + } +} + +impl TryFrom> for Uniform { + type Error = Error; + + fn try_from(r: ::core::ops::RangeInclusive) -> Result, Error> { + Uniform::new_inclusive(r.start(), r.end()) + } +} + +/// Helper trait similar to [`Borrow`] but implemented +/// only for [`SampleUniform`] and references to [`SampleUniform`] +/// in order to resolve ambiguity issues. +/// +/// [`Borrow`]: std::borrow::Borrow +pub trait SampleBorrow { + /// Immutably borrows from an owned value. See [`Borrow::borrow`] + /// + /// [`Borrow::borrow`]: std::borrow::Borrow::borrow + fn borrow(&self) -> &Borrowed; +} +impl SampleBorrow for Borrowed +where + Borrowed: SampleUniform, +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} +impl SampleBorrow for &Borrowed +where + Borrowed: SampleUniform, +{ + #[inline(always)] + fn borrow(&self) -> &Borrowed { + self + } +} + +/// Range that supports generating a single sample efficiently. +/// +/// Any type implementing this trait can be used to specify the sampled range +/// for `Rng::random_range`. +pub trait SampleRange { + /// Generate a sample from the given range. + fn sample_single(self, rng: &mut R) -> Result; + + /// Check whether the range is empty. + fn is_empty(&self) -> bool; +} + +impl SampleRange for Range { + #[inline] + fn sample_single(self, rng: &mut R) -> Result { + T::Sampler::sample_single(self.start, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start < self.end) + } +} + +impl SampleRange for RangeInclusive { + #[inline] + fn sample_single(self, rng: &mut R) -> Result { + T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) + } + + #[inline] + fn is_empty(&self) -> bool { + !(self.start() <= self.end()) + } +} + +macro_rules! impl_sample_range_u { + ($t:ty) => { + impl SampleRange<$t> for RangeTo<$t> { + #[inline] + fn sample_single(self, rng: &mut R) -> Result<$t, Error> { + <$t as SampleUniform>::Sampler::sample_single(0, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + 0 == self.end + } + } + + impl SampleRange<$t> for RangeToInclusive<$t> { + #[inline] + fn sample_single(self, rng: &mut R) -> Result<$t, Error> { + <$t as SampleUniform>::Sampler::sample_single_inclusive(0, self.end, rng) + } + + #[inline] + fn is_empty(&self) -> bool { + false + } + } + }; +} + +impl_sample_range_u!(u8); +impl_sample_range_u!(u16); +impl_sample_range_u!(u32); +impl_sample_range_u!(u64); +impl_sample_range_u!(u128); +impl_sample_range_u!(usize); + +#[cfg(test)] +mod tests { + use super::*; + use core::time::Duration; + + #[test] + #[cfg(feature = "serde")] + fn test_uniform_serialization() { + let unit_box: Uniform = Uniform::new(-1, 1).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + assert_eq!(unit_box.0, de_unit_box.0); + + let unit_box: Uniform = Uniform::new(-1., 1.).unwrap(); + let de_unit_box: Uniform = + bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); + assert_eq!(unit_box.0, de_unit_box.0); + } + + #[test] + fn test_custom_uniform() { + use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformFloat, UniformSampler}; + #[derive(Clone, Copy, PartialEq, PartialOrd)] + struct MyF32 { + x: f32, + } + #[derive(Clone, Copy, Debug)] + struct UniformMyF32(UniformFloat); + impl UniformSampler for UniformMyF32 { + type X = MyF32; + + fn new(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + UniformFloat::::new(low.borrow().x, high.borrow().x).map(UniformMyF32) + } + + fn new_inclusive(low: B1, high: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + UniformSampler::new(low, high) + } + + fn sample(&self, rng: &mut R) -> Self::X { + MyF32 { + x: self.0.sample(rng), + } + } + } + impl SampleUniform for MyF32 { + type Sampler = UniformMyF32; + } + + let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); + let uniform = Uniform::new(low, high).unwrap(); + let mut rng = crate::test::rng(804); + for _ in 0..100 { + let x: MyF32 = rng.sample(uniform); + assert!(low <= x && x < high); + } + } + + #[test] + fn value_stability() { + fn test_samples( + lb: T, + ub: T, + expected_single: &[T], + expected_multiple: &[T], + ) where + Uniform: Distribution, + { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 3]; + + for x in &mut buf { + *x = T::Sampler::sample_single(lb, ub, &mut rng).unwrap(); + } + assert_eq!(&buf, expected_single); + + let distr = Uniform::new(lb, ub).unwrap(); + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected_multiple); + } + + test_samples( + 0f32, + 1e-2f32, + &[0.0003070104, 0.0026630748, 0.00979833], + &[0.008194133, 0.00398172, 0.007428536], + ); + test_samples( + -1e10f64, + 1e10f64, + &[-4673848682.871551, 6388267422.932352, 4857075081.198343], + &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], + ); + + test_samples( + Duration::new(2, 0), + Duration::new(4, 0), + &[ + Duration::new(2, 532615131), + Duration::new(3, 638826742), + Duration::new(3, 485707508), + ], + &[ + Duration::new(3, 117337521), + Duration::new(3, 191764285), + Duration::new(3, 236507617), + ], + ); + } + + #[test] + fn uniform_distributions_can_be_compared() { + assert_eq!( + Uniform::new(1.0, 2.0).unwrap(), + Uniform::new(1.0, 2.0).unwrap() + ); + + // To cover UniformInt + assert_eq!( + Uniform::new(1_u32, 2_u32).unwrap(), + Uniform::new(1_u32, 2_u32).unwrap() + ); + } +} diff --git a/src/distr/uniform_float.rs b/src/distr/uniform_float.rs new file mode 100644 index 00000000000..adcc7b710d6 --- /dev/null +++ b/src/distr/uniform_float.rs @@ -0,0 +1,453 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformFloat` implementation + +use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::float::IntoFloat; +use crate::distr::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD}; +use crate::Rng; + +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +// #[cfg(feature = "simd_support")] +// use core::simd::{LaneCount, SupportedLaneCount}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The back-end implementing [`UniformSampler`] for floating-point types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// `UniformFloat` implementations convert RNG output to a float in the range +/// `[1, 2)` via transmutation, map this to `[0, 1)`, then scale and translate +/// to the desired range. Values produced this way have what equals 23 bits of +/// random digits for an `f32` and 52 for an `f64`. +/// +/// # Bias and range errors +/// +/// Bias may be expected within the least-significant bit of the significand. +/// It is not guaranteed that exclusive limits of a range are respected; i.e. +/// when sampling the range `[a, b)` it is not guaranteed that `b` is never +/// sampled. +/// +/// [`new`]: UniformSampler::new +/// [`new_inclusive`]: UniformSampler::new_inclusive +/// [`StandardUniform`]: crate::distr::StandardUniform +/// [`Uniform`]: super::Uniform +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformFloat { + low: X, + scale: X, +} + +macro_rules! uniform_float_impl { + ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { + $(#[cfg($meta)])? + impl UniformFloat<$ty> { + /// Construct, reducing `scale` as required to ensure that rounding + /// can never yield values greater than `high`. + /// + /// Note: though it may be tempting to use a variant of this method + /// to ensure that samples from `[low, high)` are always strictly + /// less than `high`, this approach may be very slow where + /// `scale.abs()` is much smaller than `high.abs()` + /// (example: `low=0.99999999997819644, high=1.`). + fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self { + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + + loop { + let mask = (scale * max_rand + low).gt_mask(high); + if !mask.any() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + } + + $(#[cfg($meta)])? + impl SampleUniform for $ty { + type Sampler = UniformFloat<$ty>; + } + + $(#[cfg($meta)])? + impl UniformSampler for UniformFloat<$ty> { + type X = $ty; + + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !(low.all_finite()) || !(high.all_finite()) { + return Err(Error::NonFinite); + } + if !(low.all_lt(high)) { + return Err(Error::EmptyRange); + } + + let scale = high - low; + if !(scale.all_finite()) { + return Err(Error::NonFinite); + } + + Ok(Self::new_bounded(low, high, scale)) + } + + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !(low.all_finite()) || !(high.all_finite()) { + return Err(Error::NonFinite); + } + if !low.all_le(high) { + return Err(Error::EmptyRange); + } + + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + let scale = (high - low) / max_rand; + if !scale.all_finite() { + return Err(Error::NonFinite); + } + + Ok(Self::new_bounded(low, high, scale)) + } + + fn sample(&self, rng: &mut R) -> Self::X { + // Generate a value in the range [1, 2) + let value1_2 = (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + + // Get a value in the range [0, 1) to avoid overflow when multiplying by scale + let value0_1 = value1_2 - <$ty>::splat(1.0); + + // We don't use `f64::mul_add`, because it is not available with + // `no_std`. Furthermore, it is slower for some targets (but + // faster for others). However, the order of multiplication and + // addition is important, because on some platforms (e.g. ARM) + // it will be optimized to a single (non-FMA) instruction. + value0_1 * self.scale + self.low + } + + #[inline] + fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + Self::sample_single_inclusive(low_b, high_b, rng) + } + + #[inline] + fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + #[cfg(debug_assertions)] + if !low.all_finite() || !high.all_finite() { + return Err(Error::NonFinite); + } + if !low.all_le(high) { + return Err(Error::EmptyRange); + } + let scale = high - low; + if !scale.all_finite() { + return Err(Error::NonFinite); + } + + // Generate a value in the range [1, 2) + let value1_2 = + (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); + + // Get a value in the range [0, 1) to avoid overflow when multiplying by scale + let value0_1 = value1_2 - <$ty>::splat(1.0); + + // Doing multiply before addition allows some architectures + // to use a single instruction. + Ok(value0_1 * scale + low) + } + } + }; +} + +uniform_float_impl! { , f32, u32, f32, u32, 32 - 23 } +uniform_float_impl! { , f64, u64, f64, u64, 64 - 52 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x2, u32x2, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x4, u32x4, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x8, u32x8, f32, u32, 32 - 23 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f32x16, u32x16, f32, u32, 32 - 23 } + +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x2, u64x2, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52 } +#[cfg(feature = "simd_support")] +uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 } + +#[cfg(test)] +mod tests { + use super::*; + use crate::distr::{utils::FloatSIMDScalarUtils, Uniform}; + use crate::rngs::mock::StepRng; + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_floats() { + let mut rng = crate::test::rng(252); + let mut zero_rng = StepRng::new(0, 0); + let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); + macro_rules! t { + ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + (0.0, 100.0), + (-1e35, -1e25), + (1e-35, 1e-25), + (-1e35, 1e35), + (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), + (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), + (-<$f_scalar>::from_bits(5), 0.0), + (-<$f_scalar>::from_bits(7), -0.0), + (0.1 * $f_scalar::MAX, $f_scalar::MAX), + (-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::LEN { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + let my_uniform = Uniform::new(low, high).unwrap(); + let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); + for _ in 0..100 { + let v = rng.sample(my_uniform).extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = rng.sample(my_incl_uniform).extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = + <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) + .unwrap() + .extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, high, &mut rng, + ) + .unwrap() + .extract(lane); + assert!(low_scalar <= v && v <= high_scalar); + } + + assert_eq!( + rng.sample(Uniform::new_inclusive(low, low).unwrap()) + .extract(lane), + low_scalar + ); + + assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); + assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut zero_rng + ) + .unwrap() + .extract(lane), + low_scalar + ); + assert_eq!( + <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, + high, + &mut zero_rng + ) + .unwrap() + .extract(lane), + low_scalar + ); + + assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar); + assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); + // sample_single cannot cope with max_rng: + // assert!(<$ty as SampleUniform>::Sampler + // ::sample_single(low, high, &mut max_rng).unwrap() + // .extract(lane) <= high_scalar); + assert!( + <$ty as SampleUniform>::Sampler::sample_single_inclusive( + low, + high, + &mut max_rng + ) + .unwrap() + .extract(lane) + <= high_scalar + ); + + // Don't run this test for really tiny differences between high and low + // since for those rounding might result in selecting high for a very + // long time. + if (high_scalar - low_scalar) > 0.0001 { + let mut lowering_max_rng = StepRng::new( + 0xffff_ffff_ffff_ffff, + (-1i64 << $bits_shifted) as u64, + ); + assert!( + <$ty as SampleUniform>::Sampler::sample_single( + low, + high, + &mut lowering_max_rng + ) + .unwrap() + .extract(lane) + <= high_scalar + ); + } + } + } + + assert_eq!( + rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()), + $f_scalar::MAX + ); + assert_eq!( + rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()), + -$f_scalar::MAX + ); + }}; + } + + t!(f32, f32, 32 - 23); + t!(f64, f64, 64 - 52); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32, 32 - 23); + t!(f32x4, f32, 32 - 23); + t!(f32x8, f32, 32 - 23); + t!(f32x16, f32, 32 - 23); + t!(f64x2, f64, 64 - 52); + t!(f64x4, f64, 64 - 52); + t!(f64x8, f64, 64 - 52); + } + } + + #[test] + fn test_float_overflow() { + assert_eq!(Uniform::try_from(f64::MIN..f64::MAX), Err(Error::NonFinite)); + } + + #[test] + #[should_panic] + fn test_float_overflow_single() { + let mut rng = crate::test::rng(252); + rng.random_range(f64::MIN..f64::MAX); + } + + #[test] + #[cfg(all(feature = "std", panic = "unwind"))] + fn test_float_assertions() { + use super::SampleUniform; + fn range(low: T, high: T) -> Result { + let mut rng = crate::test::rng(253); + T::Sampler::sample_single(low, high, &mut rng) + } + + macro_rules! t { + ($ty:ident, $f_scalar:ident) => {{ + let v: &[($f_scalar, $f_scalar)] = &[ + ($f_scalar::NAN, 0.0), + (1.0, $f_scalar::NAN), + ($f_scalar::NAN, $f_scalar::NAN), + (1.0, 0.5), + ($f_scalar::MAX, -$f_scalar::MAX), + ($f_scalar::INFINITY, $f_scalar::INFINITY), + ($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY), + ($f_scalar::NEG_INFINITY, 5.0), + (5.0, $f_scalar::INFINITY), + ($f_scalar::NAN, $f_scalar::INFINITY), + ($f_scalar::NEG_INFINITY, $f_scalar::NAN), + ($f_scalar::NEG_INFINITY, $f_scalar::INFINITY), + ]; + for &(low_scalar, high_scalar) in v.iter() { + for lane in 0..<$ty>::LEN { + let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); + let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); + assert!(range(low, high).is_err()); + assert!(Uniform::new(low, high).is_err()); + assert!(Uniform::new_inclusive(low, high).is_err()); + assert!(Uniform::new(low, low).is_err()); + } + } + }}; + } + + t!(f32, f32); + t!(f64, f64); + #[cfg(feature = "simd_support")] + { + t!(f32x2, f32); + t!(f32x4, f32); + t!(f32x8, f32); + t!(f32x16, f32); + t!(f64x2, f64); + t!(f64x4, f64); + t!(f64x8, f64); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::try_from(2.0f64..7.0).unwrap(); + assert_eq!(r.0.low, 2.0); + assert_eq!(r.0.scale, 5.0); + } + + #[test] + fn test_uniform_from_std_range_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100.0..10.0).is_err()); + assert!(Uniform::try_from(100.0..100.0).is_err()); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::try_from(2.0f64..=7.0).unwrap(); + assert_eq!(r.0.low, 2.0); + assert!(r.0.scale > 5.0); + assert!(r.0.scale < 5.0 + 1e-14); + } + + #[test] + fn test_uniform_from_std_range_inclusive_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100.0..=10.0).is_err()); + assert!(Uniform::try_from(100.0..=99.0).is_err()); + } +} diff --git a/src/distr/uniform_int.rs b/src/distr/uniform_int.rs new file mode 100644 index 00000000000..d5c56b02a0b --- /dev/null +++ b/src/distr/uniform_int.rs @@ -0,0 +1,796 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformInt` implementation + +use super::{Error, SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::utils::WideningMultiply; +#[cfg(feature = "simd_support")] +use crate::distr::{Distribution, StandardUniform}; +use crate::Rng; + +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, SupportedLaneCount}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The back-end implementing [`UniformSampler`] for integer types. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// # Implementation notes +/// +/// For simplicity, we use the same generic struct `UniformInt` for all +/// integer types `X`. This gives us only one field type, `X`; to store unsigned +/// values of this size, we take use the fact that these conversions are no-ops. +/// +/// For a closed range, the number of possible numbers we should generate is +/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of +/// our sample space, `zone`, is a multiple of `range`; other values must be +/// rejected (by replacing with a new random sample). +/// +/// As a special case, we use `range = 0` to represent the full range of the +/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). +/// +/// The optimum `zone` is the largest product of `range` which fits in our +/// (unsigned) target type. We calculate this by calculating how many numbers we +/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) +/// product of `range` will suffice, thus in `sample_single` we multiply by a +/// power of 2 via bit-shifting (faster but may cause more rejections). +/// +/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we +/// use `u32` for our `zone` and samples (because it's not slower and because +/// it reduces the chance of having to reject a sample). In this case we cannot +/// store `zone` in the target type since it is too large, however we know +/// `ints_to_reject < range <= $uty::MAX`. +/// +/// An alternative to using a modulus is widening multiply: After a widening +/// multiply by `range`, the result is in the high word. Then comparing the low +/// word against `zone` makes sure our distribution is uniform. +/// +/// # Bias +/// +/// Unless the `unbiased` feature flag is used, outputs may have a small bias. +/// In the worst case, bias affects 1 in `2^n` samples where n is +/// 56 (`i8` and `u8`), 48 (`i16` and `u16`), 96 (`i32` and `u32`), 64 (`i64` +/// and `u64`), 128 (`i128` and `u128`). +/// +/// [`Uniform`]: super::Uniform +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformInt { + pub(super) low: X, + pub(super) range: X, + thresh: X, // effectively 2.pow(max(64, uty_bits)) % range +} + +macro_rules! uniform_int_impl { + ($ty:ty, $uty:ty, $sample_ty:ident) => { + impl SampleUniform for $ty { + type Sampler = UniformInt<$ty>; + } + + impl UniformSampler for UniformInt<$ty> { + // We play free and fast with unsigned vs signed here + // (when $ty is signed), but that's fine, since the + // contract of this macro is for $ty and $uty to be + // "bit-equal", so casting between them is a no-op. + + type X = $ty; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + let range = high.wrapping_sub(low).wrapping_add(1) as $uty; + let thresh = if range > 0 { + let range = $sample_ty::from(range); + (range.wrapping_neg() % range) + } else { + 0 + }; + + Ok(UniformInt { + low, + range: range as $ty, // type: $uty + thresh: thresh as $uty as $ty, // type: $sample_ty + }) + } + + /// Sample from distribution, Lemire's method, unbiased + #[inline] + fn sample(&self, rng: &mut R) -> Self::X { + let range = self.range as $uty as $sample_ty; + if range == 0 { + return rng.random(); + } + + let thresh = self.thresh as $uty as $sample_ty; + let hi = loop { + let (hi, lo) = rng.random::<$sample_ty>().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as $ty) + } + + #[inline] + fn sample_single( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + Self::sample_single_inclusive(low, high - 1, rng) + } + + /// Sample single value, Canon's method, biased + /// + /// In the worst case, bias affects 1 in `2^n` samples where n is + /// 56 (`i8`), 48 (`i16`), 96 (`i32`), 64 (`i64`), 128 (`i128`). + #[cfg(not(feature = "unbiased"))] + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.random()); + } + + // generate a sample using a sensible integer type + let (mut result, lo_order) = rng.random::<$sample_ty>().wmul(range); + + // if the sample is biased... + if lo_order > range.wrapping_neg() { + // ...generate a new sample to reduce bias... + let (new_hi_order, _) = (rng.random::<$sample_ty>()).wmul(range as $sample_ty); + // ... incrementing result on overflow + let is_overflow = lo_order.checked_add(new_hi_order as $sample_ty).is_none(); + result += is_overflow as $sample_ty; + } + + Ok(low.wrapping_add(result as $ty)) + } + + /// Sample single value, Canon's method, unbiased + #[cfg(feature = "unbiased")] + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow<$ty> + Sized, + B2: SampleBorrow<$ty> + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; + if range == 0 { + // Range is MAX+1 (unrepresentable), so we need a special case + return Ok(rng.random()); + } + + let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range); + + // In contrast to the biased sampler, we use a loop: + while lo > range.wrapping_neg() { + let (new_hi, new_lo) = (rng.random::<$sample_ty>()).wmul(range); + match lo.checked_add(new_hi) { + Some(x) if x < $sample_ty::MAX => { + // Anything less than MAX: last term is 0 + break; + } + None => { + // Overflow: last term is 1 + result += 1; + break; + } + _ => { + // Unlikely case: must check next sample + lo = new_lo; + continue; + } + } + } + + Ok(low.wrapping_add(result as $ty)) + } + } + }; +} + +uniform_int_impl! { i8, u8, u32 } +uniform_int_impl! { i16, u16, u32 } +uniform_int_impl! { i32, u32, u32 } +uniform_int_impl! { i64, u64, u64 } +uniform_int_impl! { i128, u128, u128 } +uniform_int_impl! { u8, u8, u32 } +uniform_int_impl! { u16, u16, u32 } +uniform_int_impl! { u32, u32, u32 } +uniform_int_impl! { u64, u64, u64 } +uniform_int_impl! { u128, u128, u128 } + +#[cfg(feature = "simd_support")] +macro_rules! uniform_simd_int_impl { + ($ty:ident, $unsigned:ident) => { + // The "pick the largest zone that can fit in an `u32`" optimization + // is less useful here. Multiple lanes complicate things, we don't + // know the PRNG's minimal output size, and casting to a larger vector + // is generally a bad idea for SIMD performance. The user can still + // implement it manually. + + #[cfg(feature = "simd_support")] + impl SampleUniform for Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + Simd<$unsigned, LANES>: + WideningMultiply, Simd<$unsigned, LANES>)>, + StandardUniform: Distribution>, + { + type Sampler = UniformInt>; + } + + #[cfg(feature = "simd_support")] + impl UniformSampler for UniformInt> + where + LaneCount: SupportedLaneCount, + Simd<$unsigned, LANES>: + WideningMultiply, Simd<$unsigned, LANES>)>, + StandardUniform: Distribution>, + { + type X = Simd<$ty, LANES>; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low.simd_lt(high).all()) { + return Err(Error::EmptyRange); + } + UniformSampler::new_inclusive(low, high - Simd::splat(1)) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low.simd_le(high).all()) { + return Err(Error::EmptyRange); + } + + // NOTE: all `Simd` operations are inherently wrapping, + // see https://doc.rust-lang.org/std/simd/struct.Simd.html + let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast(); + + // We must avoid divide-by-zero by using 0 % 1 == 0. + let not_full_range = range.simd_gt(Simd::splat(0)); + let modulo = not_full_range.select(range, Simd::splat(1)); + let ints_to_reject = range.wrapping_neg() % modulo; + + Ok(UniformInt { + low, + // These are really $unsigned values, but store as $ty: + range: range.cast(), + thresh: ints_to_reject.cast(), + }) + } + + fn sample(&self, rng: &mut R) -> Self::X { + let range: Simd<$unsigned, LANES> = self.range.cast(); + let thresh: Simd<$unsigned, LANES> = self.thresh.cast(); + + // This might seem very slow, generating a whole new + // SIMD vector for every sample rejection. For most uses + // though, the chance of rejection is small and provides good + // general performance. With multiple lanes, that chance is + // multiplied. To mitigate this, we replace only the lanes of + // the vector which fail, iteratively reducing the chance of + // rejection. The replacement method does however add a little + // overhead. Benchmarking or calculating probabilities might + // reveal contexts where this replacement method is slower. + let mut v: Simd<$unsigned, LANES> = rng.random(); + loop { + let (hi, lo) = v.wmul(range); + let mask = lo.simd_ge(thresh); + if mask.all() { + let hi: Simd<$ty, LANES> = hi.cast(); + // wrapping addition + let result = self.low + hi; + // `select` here compiles to a blend operation + // When `range.eq(0).none()` the compare and blend + // operations are avoided. + let v: Simd<$ty, LANES> = v.cast(); + return range.simd_gt(Simd::splat(0)).select(result, v); + } + // Replace only the failing lanes + v = mask.select(v, rng.random()); + } + } + } + }; + + // bulk implementation + ($(($unsigned:ident, $signed:ident)),+) => { + $( + uniform_simd_int_impl!($unsigned, $unsigned); + uniform_simd_int_impl!($signed, $unsigned); + )+ + }; +} + +#[cfg(feature = "simd_support")] +uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) } + +/// The back-end implementing [`UniformSampler`] for `usize`. +/// +/// # Implementation notes +/// +/// Sampling a `usize` value is usually used in relation to the length of an +/// array or other memory structure, thus it is reasonable to assume that the +/// vast majority of use-cases will have a maximum size under [`u32::MAX`]. +/// In part to optimise for this use-case, but mostly to ensure that results +/// are portable across 32-bit and 64-bit architectures (as far as is possible), +/// this implementation will use 32-bit sampling when possible. +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct UniformUsize { + low: usize, + range: usize, + thresh: usize, + #[cfg(target_pointer_width = "64")] + mode64: bool, +} + +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl SampleUniform for usize { + type Sampler = UniformUsize; +} + +#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] +impl UniformSampler for UniformUsize { + type X = usize; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + + UniformSampler::new_inclusive(low, high - 1) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + #[cfg(target_pointer_width = "64")] + let mode64 = high > (u32::MAX as usize); + #[cfg(target_pointer_width = "32")] + let mode64 = false; + + let (range, thresh); + if cfg!(target_pointer_width = "64") && !mode64 { + let range32 = (high as u32).wrapping_sub(low as u32).wrapping_add(1); + range = range32 as usize; + thresh = if range32 > 0 { + (range32.wrapping_neg() % range32) as usize + } else { + 0 + }; + } else { + range = high.wrapping_sub(low).wrapping_add(1); + thresh = if range > 0 { + range.wrapping_neg() % range + } else { + 0 + }; + } + + Ok(UniformUsize { + low, + range, + thresh, + #[cfg(target_pointer_width = "64")] + mode64, + }) + } + + #[inline] + fn sample(&self, rng: &mut R) -> usize { + #[cfg(target_pointer_width = "32")] + let mode32 = true; + #[cfg(target_pointer_width = "64")] + let mode32 = !self.mode64; + + if mode32 { + let range = self.range as u32; + if range == 0 { + return rng.random::() as usize; + } + + let thresh = self.thresh as u32; + let hi = loop { + let (hi, lo) = rng.random::().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as usize) + } else { + let range = self.range as u64; + if range == 0 { + return rng.random::() as usize; + } + + let thresh = self.thresh as u64; + let hi = loop { + let (hi, lo) = rng.random::().wmul(range); + if lo >= thresh { + break hi; + } + }; + self.low.wrapping_add(hi as usize) + } + } + + #[inline] + fn sample_single( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + + if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) { + return UniformInt::::sample_single(low as u64, high as u64, rng) + .map(|x| x as usize); + } + + UniformInt::::sample_single(low as u32, high as u32, rng).map(|x| x as usize) + } + + #[inline] + fn sample_single_inclusive( + low_b: B1, + high_b: B2, + rng: &mut R, + ) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) { + return UniformInt::::sample_single_inclusive(low as u64, high as u64, rng) + .map(|x| x as usize); + } + + UniformInt::::sample_single_inclusive(low as u32, high as u32, rng).map(|x| x as usize) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::distr::{Distribution, Uniform}; + use core::fmt::Debug; + use core::ops::Add; + + #[test] + fn test_uniform_bad_limits_equal_int() { + assert_eq!(Uniform::new(10, 10), Err(Error::EmptyRange)); + } + + #[test] + fn test_uniform_good_limits_equal_int() { + let mut rng = crate::test::rng(804); + let dist = Uniform::new_inclusive(10, 10).unwrap(); + for _ in 0..20 { + assert_eq!(rng.sample(dist), 10); + } + } + + #[test] + fn test_uniform_bad_limits_flipped_int() { + assert_eq!(Uniform::new(10, 5), Err(Error::EmptyRange)); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_integers() { + let mut rng = crate::test::rng(251); + macro_rules! t { + ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ + for &(low, high) in $v.iter() { + let my_uniform = Uniform::new(low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + let my_uniform = Uniform::new(&low, high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $lt(v, high)); + } + + let my_uniform = Uniform::new_inclusive(&low, &high).unwrap(); + for _ in 0..1000 { + let v: $ty = rng.sample(my_uniform); + assert!($le(low, v) && $le(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng).unwrap(); + assert!($le(low, v) && $lt(v, high)); + } + + for _ in 0..1000 { + let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(low, high, &mut rng).unwrap(); + assert!($le(low, v) && $le(v, high)); + } + } + }}; + + // scalar bulk + ($($ty:ident),*) => {{ + $(t!( + $ty, + [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], + |x, y| x <= y, + |x, y| x < y + );)* + }}; + + // simd bulk + ($($ty:ident),* => $scalar:ident) => {{ + $(t!( + $ty, + [ + ($ty::splat(0), $ty::splat(10)), + ($ty::splat(10), $ty::splat(127)), + ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), + ], + |x: $ty, y| x.simd_le(y).all(), + |x: $ty, y| x.simd_lt(y).all() + );)* + }}; + } + t!(i8, i16, i32, i64, i128, u8, u16, u32, u64, usize, u128); + + #[cfg(feature = "simd_support")] + { + t!(u8x4, u8x8, u8x16, u8x32, u8x64 => u8); + t!(i8x4, i8x8, i8x16, i8x32, i8x64 => i8); + t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); + t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); + t!(u32x2, u32x4, u32x8, u32x16 => u32); + t!(i32x2, i32x4, i32x8, i32x16 => i32); + t!(u64x2, u64x4, u64x8 => u64); + t!(i64x2, i64x4, i64x8 => i64); + } + } + + #[test] + fn test_uniform_from_std_range() { + let r = Uniform::try_from(2u32..7).unwrap(); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + } + + #[test] + fn test_uniform_from_std_range_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100..10).is_err()); + assert!(Uniform::try_from(100..100).is_err()); + } + + #[test] + fn test_uniform_from_std_range_inclusive() { + let r = Uniform::try_from(2u32..=6).unwrap(); + assert_eq!(r.0.low, 2); + assert_eq!(r.0.range, 5); + } + + #[test] + fn test_uniform_from_std_range_inclusive_bad_limits() { + #![allow(clippy::reversed_empty_ranges)] + assert!(Uniform::try_from(100..=10).is_err()); + assert!(Uniform::try_from(100..=99).is_err()); + } + + #[test] + fn value_stability() { + fn test_samples>( + lb: T, + ub: T, + ub_excl: T, + expected: &[T], + ) where + Uniform: Distribution, + { + let mut rng = crate::test::rng(897); + let mut buf = [lb; 6]; + + for x in &mut buf[0..3] { + *x = T::Sampler::sample_single_inclusive(lb, ub, &mut rng).unwrap(); + } + + let distr = Uniform::new_inclusive(lb, ub).unwrap(); + for x in &mut buf[3..6] { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + + let mut rng = crate::test::rng(897); + + for x in &mut buf[0..3] { + *x = T::Sampler::sample_single(lb, ub_excl, &mut rng).unwrap(); + } + + let distr = Uniform::new(lb, ub_excl).unwrap(); + for x in &mut buf[3..6] { + *x = rng.sample(&distr); + } + assert_eq!(&buf, expected); + } + + test_samples(-105i8, 111, 112, &[-99, -48, 107, 72, -19, 56]); + test_samples(2i16, 1352, 1353, &[43, 361, 1325, 1109, 539, 1005]); + test_samples( + -313853i32, + 13513, + 13514, + &[-303803, -226673, 6912, -45605, -183505, -70668], + ); + test_samples( + 131521i64, + 6542165, + 6542166, + &[1838724, 5384489, 4893692, 3712948, 3951509, 4094926], + ); + test_samples( + -0x8000_0000_0000_0000_0000_0000_0000_0000i128, + -1, + 0, + &[ + -30725222750250982319765550926688025855, + -75088619368053423329503924805178012357, + -64950748766625548510467638647674468829, + -41794017901603587121582892414659436495, + -63623852319608406524605295913876414006, + -17404679390297612013597359206379189023, + ], + ); + test_samples(11u8, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u16, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u32, 218, 219, &[17, 66, 214, 181, 93, 165]); + test_samples(11u64, 218, 219, &[66, 181, 165, 127, 134, 139]); + test_samples(11u128, 218, 219, &[181, 127, 139, 167, 141, 197]); + test_samples(11usize, 218, 219, &[17, 66, 214, 181, 93, 165]); + + #[cfg(feature = "simd_support")] + { + let lb = Simd::from([11u8, 0, 128, 127]); + let ub = Simd::from([218, 254, 254, 254]); + let ub_excl = ub + Simd::splat(1); + test_samples( + lb, + ub, + ub_excl, + &[ + Simd::from([13, 5, 237, 130]), + Simd::from([126, 186, 149, 161]), + Simd::from([103, 86, 234, 252]), + Simd::from([35, 18, 225, 231]), + Simd::from([106, 153, 246, 177]), + Simd::from([195, 168, 149, 222]), + ], + ); + } + } +} diff --git a/src/distr/uniform_other.rs b/src/distr/uniform_other.rs new file mode 100644 index 00000000000..03533debcd8 --- /dev/null +++ b/src/distr/uniform_other.rs @@ -0,0 +1,319 @@ +// Copyright 2018-2020 Developers of the Rand project. +// Copyright 2017 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `UniformChar`, `UniformDuration` implementations + +use super::{Error, SampleBorrow, SampleUniform, Uniform, UniformInt, UniformSampler}; +use crate::distr::Distribution; +use crate::Rng; +use core::time::Duration; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +impl SampleUniform for char { + type Sampler = UniformChar; +} + +/// The back-end implementing [`UniformSampler`] for `char`. +/// +/// Unless you are implementing [`UniformSampler`] for your own type, this type +/// should not be used directly, use [`Uniform`] instead. +/// +/// This differs from integer range sampling since the range `0xD800..=0xDFFF` +/// are used for surrogate pairs in UCS and UTF-16, and consequently are not +/// valid Unicode code points. We must therefore avoid sampling values in this +/// range. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformChar { + sampler: UniformInt, +} + +/// UTF-16 surrogate range start +const CHAR_SURROGATE_START: u32 = 0xD800; +/// UTF-16 surrogate range size +const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; + +/// Convert `char` to compressed `u32` +fn char_to_comp_u32(c: char) -> u32 { + match c as u32 { + c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, + c => c, + } +} + +impl UniformSampler for UniformChar { + type X = char; + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new(low, high); + sampler.map(|sampler| UniformChar { sampler }) + } + + #[inline] // if the range is constant, this helps LLVM to do the + // calculations at compile-time. + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = char_to_comp_u32(*low_b.borrow()); + let high = char_to_comp_u32(*high_b.borrow()); + let sampler = UniformInt::::new_inclusive(low, high); + sampler.map(|sampler| UniformChar { sampler }) + } + + fn sample(&self, rng: &mut R) -> Self::X { + let mut x = self.sampler.sample(rng); + if x >= CHAR_SURROGATE_START { + x += CHAR_SURROGATE_LEN; + } + // SAFETY: x must not be in surrogate range or greater than char::MAX. + // This relies on range constructors which accept char arguments. + // Validity of input char values is assumed. + unsafe { core::char::from_u32_unchecked(x) } + } +} + +#[cfg(feature = "alloc")] +impl crate::distr::SampleString for Uniform { + fn append_string( + &self, + rng: &mut R, + string: &mut alloc::string::String, + len: usize, + ) { + // Getting the hi value to assume the required length to reserve in string. + let mut hi = self.0.sampler.low + self.0.sampler.range - 1; + if hi >= CHAR_SURROGATE_START { + hi += CHAR_SURROGATE_LEN; + } + // Get the utf8 length of hi to minimize extra space. + let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4); + string.reserve(max_char_len * len); + string.extend(self.sample_iter(rng).take(len)) + } +} + +/// The back-end implementing [`UniformSampler`] for `Duration`. +/// +/// Unless you are implementing [`UniformSampler`] for your own types, this type +/// should not be used directly, use [`Uniform`] instead. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct UniformDuration { + mode: UniformDurationMode, + offset: u32, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +enum UniformDurationMode { + Small { + secs: u64, + nanos: Uniform, + }, + Medium { + nanos: Uniform, + }, + Large { + max_secs: u64, + max_nanos: u32, + secs: Uniform, + }, +} + +impl SampleUniform for Duration { + type Sampler = UniformDuration; +} + +impl UniformSampler for UniformDuration { + type X = Duration; + + #[inline] + fn new(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low < high) { + return Err(Error::EmptyRange); + } + UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) + } + + #[inline] + fn new_inclusive(low_b: B1, high_b: B2) -> Result + where + B1: SampleBorrow + Sized, + B2: SampleBorrow + Sized, + { + let low = *low_b.borrow(); + let high = *high_b.borrow(); + if !(low <= high) { + return Err(Error::EmptyRange); + } + + let low_s = low.as_secs(); + let low_n = low.subsec_nanos(); + let mut high_s = high.as_secs(); + let mut high_n = high.subsec_nanos(); + + if high_n < low_n { + high_s -= 1; + high_n += 1_000_000_000; + } + + let mode = if low_s == high_s { + UniformDurationMode::Small { + secs: low_s, + nanos: Uniform::new_inclusive(low_n, high_n)?, + } + } else { + let max = high_s + .checked_mul(1_000_000_000) + .and_then(|n| n.checked_add(u64::from(high_n))); + + if let Some(higher_bound) = max { + let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); + UniformDurationMode::Medium { + nanos: Uniform::new_inclusive(lower_bound, higher_bound)?, + } + } else { + // An offset is applied to simplify generation of nanoseconds + let max_nanos = high_n - low_n; + UniformDurationMode::Large { + max_secs: high_s, + max_nanos, + secs: Uniform::new_inclusive(low_s, high_s)?, + } + } + }; + Ok(UniformDuration { + mode, + offset: low_n, + }) + } + + #[inline] + fn sample(&self, rng: &mut R) -> Duration { + match self.mode { + UniformDurationMode::Small { secs, nanos } => { + let n = nanos.sample(rng); + Duration::new(secs, n) + } + UniformDurationMode::Medium { nanos } => { + let nanos = nanos.sample(rng); + Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) + } + UniformDurationMode::Large { + max_secs, + max_nanos, + secs, + } => { + // constant folding means this is at least as fast as `Rng::sample(Range)` + let nano_range = Uniform::new(0, 1_000_000_000).unwrap(); + loop { + let s = secs.sample(rng); + let n = nano_range.sample(rng); + if !(s == max_secs && n > max_nanos) { + let sum = n + self.offset; + break Duration::new(s, sum); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[cfg(feature = "serde")] + fn test_serialization_uniform_duration() { + let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap(); + let de_distr: UniformDuration = + bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); + assert_eq!(distr, de_distr); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_char() { + let mut rng = crate::test::rng(891); + let mut max = core::char::from_u32(0).unwrap(); + for _ in 0..100 { + let c = rng.random_range('A'..='Z'); + assert!(c.is_ascii_uppercase()); + max = max.max(c); + } + assert_eq!(max, 'Z'); + let d = Uniform::new( + core::char::from_u32(0xD7F0).unwrap(), + core::char::from_u32(0xE010).unwrap(), + ) + .unwrap(); + for _ in 0..100 { + let c = d.sample(&mut rng); + assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); + } + #[cfg(feature = "alloc")] + { + use crate::distr::SampleString; + let string1 = d.sample_string(&mut rng, 100); + assert_eq!(string1.capacity(), 300); + let string2 = Uniform::new( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ) + .unwrap() + .sample_string(&mut rng, 100); + assert_eq!(string2.capacity(), 100); + let string3 = Uniform::new_inclusive( + core::char::from_u32(0x0000).unwrap(), + core::char::from_u32(0x0080).unwrap(), + ) + .unwrap() + .sample_string(&mut rng, 100); + assert_eq!(string3.capacity(), 200); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_durations() { + let mut rng = crate::test::rng(253); + + let v = &[ + (Duration::new(10, 50000), Duration::new(100, 1234)), + (Duration::new(0, 100), Duration::new(1, 50)), + (Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)), + ]; + for &(low, high) in v.iter() { + let my_uniform = Uniform::new(low, high).unwrap(); + for _ in 0..1000 { + let v = rng.sample(my_uniform); + assert!(low <= v && v < high); + } + } + } +} diff --git a/src/distributions/utils.rs b/src/distr/utils.rs similarity index 64% rename from src/distributions/utils.rs rename to src/distr/utils.rs index e3bceb8a96c..b54dc6d6c4e 100644 --- a/src/distributions/utils.rs +++ b/src/distr/utils.rs @@ -8,8 +8,10 @@ //! Math helper functions -#[cfg(feature = "simd_support")] use packed_simd::*; - +#[cfg(feature = "simd_support")] +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, SimdElement, SupportedLaneCount}; pub(crate) trait WideningMultiply { type Output; @@ -31,7 +33,7 @@ macro_rules! wmul_impl { }; // simd bulk implementation - ($(($ty:ident, $wide:ident),)+, $shift:expr) => { + ($(($ty:ident, $wide:ty),)+, $shift:expr) => { $( impl WideningMultiply for $ty { type Output = ($ty, $ty); @@ -45,7 +47,7 @@ macro_rules! wmul_impl { let y: $wide = self.cast(); let x: $wide = x.cast(); let tmp = y * x; - let hi: $ty = (tmp >> $shift).cast(); + let hi: $ty = (tmp >> Simd::splat($shift)).cast(); let lo: $ty = tmp.cast(); (hi, lo) } @@ -56,7 +58,6 @@ macro_rules! wmul_impl { wmul_impl! { u8, u16, 8 } wmul_impl! { u16, u32, 16 } wmul_impl! { u32, u64, 32 } -#[cfg(not(target_os = "emscripten"))] wmul_impl! { u64, u128, 64 } // This code is a translation of the __mulddi3 function in LLVM's @@ -100,19 +101,20 @@ macro_rules! wmul_impl_large { #[inline(always)] fn wmul(self, b: $ty) -> Self::Output { // needs wrapping multiplication - const LOWER_MASK: $scalar = !0 >> $half; - let mut low = (self & LOWER_MASK) * (b & LOWER_MASK); - let mut t = low >> $half; - low &= LOWER_MASK; - t += (self >> $half) * (b & LOWER_MASK); - low += (t & LOWER_MASK) << $half; - let mut high = t >> $half; - t = low >> $half; - low &= LOWER_MASK; - t += (b >> $half) * (self & LOWER_MASK); - low += (t & LOWER_MASK) << $half; - high += t >> $half; - high += (self >> $half) * (b >> $half); + let lower_mask = <$ty>::splat(!0 >> $half); + let half = <$ty>::splat($half); + let mut low = (self & lower_mask) * (b & lower_mask); + let mut t = low >> half; + low &= lower_mask; + t += (self >> half) * (b & lower_mask); + low += (t & lower_mask) << half; + let mut high = t >> half; + t = low >> half; + low &= lower_mask; + t += (b >> half) * (self & lower_mask); + low += (t & lower_mask) << half; + high += t >> half; + high += (self >> half) * (b >> half); (high, low) } @@ -120,9 +122,6 @@ macro_rules! wmul_impl_large { )+ }; } -#[cfg(target_os = "emscripten")] -wmul_impl_large! { u64, 32 } -#[cfg(not(target_os = "emscripten"))] wmul_impl_large! { u128, 64 } macro_rules! wmul_impl_usize { @@ -138,6 +137,8 @@ macro_rules! wmul_impl_usize { } }; } +#[cfg(target_pointer_width = "16")] +wmul_impl_usize! { u16 } #[cfg(target_pointer_width = "32")] wmul_impl_usize! { u32 } #[cfg(target_pointer_width = "64")] @@ -146,15 +147,17 @@ wmul_impl_usize! { u64 } #[cfg(feature = "simd_support")] mod simd_wmul { use super::*; - #[cfg(target_arch = "x86")] use core::arch::x86::*; - #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; wmul_impl! { - (u8x2, u16x2), (u8x4, u16x4), (u8x8, u16x8), (u8x16, u16x16), - (u8x32, u16x32),, + (u8x32, u16x32), + (u8x64, Simd),, 8 } @@ -164,21 +167,21 @@ mod simd_wmul { wmul_impl! { (u16x8, u32x8),, 16 } #[cfg(not(target_feature = "avx2"))] wmul_impl! { (u16x16, u32x16),, 16 } + #[cfg(not(target_feature = "avx512bw"))] + wmul_impl! { (u16x32, Simd),, 16 } // 16-bit lane widths allow use of the x86 `mulhi` instructions, which // means `wmul` can be implemented with only two instructions. #[allow(unused_macros)] macro_rules! wmul_impl_16 { - ($ty:ident, $intrinsic:ident, $mulhi:ident, $mullo:ident) => { + ($ty:ident, $mulhi:ident, $mullo:ident) => { impl WideningMultiply for $ty { type Output = ($ty, $ty); #[inline(always)] fn wmul(self, x: $ty) -> Self::Output { - let b = $intrinsic::from_bits(x); - let a = $intrinsic::from_bits(self); - let hi = $ty::from_bits(unsafe { $mulhi(a, b) }); - let lo = $ty::from_bits(unsafe { $mullo(a, b) }); + let hi = unsafe { $mulhi(self.into(), x.into()) }.into(); + let lo = unsafe { $mullo(self.into(), x.into()) }.into(); (hi, lo) } } @@ -186,23 +189,20 @@ mod simd_wmul { } #[cfg(target_feature = "sse2")] - wmul_impl_16! { u16x8, __m128i, _mm_mulhi_epu16, _mm_mullo_epi16 } + wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 } #[cfg(target_feature = "avx2")] - wmul_impl_16! { u16x16, __m256i, _mm256_mulhi_epu16, _mm256_mullo_epi16 } - // FIXME: there are no `__m512i` types in stdsimd yet, so `wmul::` - // cannot use the same implementation. + wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 } + #[cfg(target_feature = "avx512bw")] + wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 } wmul_impl! { (u32x2, u64x2), (u32x4, u64x4), - (u32x8, u64x8),, + (u32x8, u64x8), + (u32x16, Simd),, 32 } - // TODO: optimize, this seems to seriously slow things down - wmul_impl_large! { (u8x64,) u8, 4 } - wmul_impl_large! { (u16x32,) u16, 8 } - wmul_impl_large! { (u32x16,) u32, 16 } wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 } } @@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils { fn all_finite(self) -> bool; type Mask; - fn finite_mask(self) -> Self::Mask; fn gt_mask(self, other: Self) -> Self::Mask; - fn ge_mask(self, other: Self) -> Self::Mask; // Decrease all lanes where the mask is `true` to the next lower value // representable by the floating-point type. At least one of the lanes @@ -233,40 +231,37 @@ pub(crate) trait FloatSIMDUtils { fn cast_from_int(i: Self::UInt) -> Self; } -/// Implement functions available in std builds but missing from core primitives -#[cfg(not(std))] -pub(crate) trait Float: Sized { - fn is_nan(self) -> bool; - fn is_infinite(self) -> bool; - fn is_finite(self) -> bool; +#[cfg(test)] +pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils { + type Scalar; + + fn replace(self, index: usize, new_value: Self::Scalar) -> Self; + fn extract(self, index: usize) -> Self::Scalar; } /// Implement functions on f32/f64 to give them APIs similar to SIMD types pub(crate) trait FloatAsSIMD: Sized { - #[inline(always)] - fn lanes() -> usize { - 1 - } + #[cfg(test)] + const LEN: usize = 1; + #[inline(always)] fn splat(scalar: Self) -> Self { scalar } +} + +pub(crate) trait IntAsSIMD: Sized { #[inline(always)] - fn extract(self, index: usize) -> Self { - debug_assert_eq!(index, 0); - self - } - #[inline(always)] - fn replace(self, index: usize, new_value: Self) -> Self { - debug_assert_eq!(index, 0); - new_value + fn splat(scalar: Self) -> Self { + scalar } } +impl IntAsSIMD for u32 {} +impl IntAsSIMD for u64 {} + pub(crate) trait BoolAsSIMD: Sized { fn any(self) -> bool; - fn all(self) -> bool; - fn none(self) -> bool; } impl BoolAsSIMD for bool { @@ -274,38 +269,10 @@ impl BoolAsSIMD for bool { fn any(self) -> bool { self } - - #[inline(always)] - fn all(self) -> bool { - self - } - - #[inline(always)] - fn none(self) -> bool { - !self - } } macro_rules! scalar_float_impl { ($ty:ident, $uty:ident) => { - #[cfg(not(std))] - impl Float for $ty { - #[inline] - fn is_nan(self) -> bool { - self != self - } - - #[inline] - fn is_infinite(self) -> bool { - self == ::core::$ty::INFINITY || self == ::core::$ty::NEG_INFINITY - } - - #[inline] - fn is_finite(self) -> bool { - !(self.is_nan() || self.is_infinite()) - } - } - impl FloatSIMDUtils for $ty { type Mask = bool; type UInt = $uty; @@ -325,21 +292,11 @@ macro_rules! scalar_float_impl { self.is_finite() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self > other } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self >= other - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { debug_assert!(mask, "At least one lane must be set"); @@ -352,6 +309,23 @@ macro_rules! scalar_float_impl { } } + #[cfg(test)] + impl FloatSIMDScalarUtils for $ty { + type Scalar = $ty; + + #[inline] + fn replace(self, index: usize, new_value: Self::Scalar) -> Self { + debug_assert_eq!(index, 0); + new_value + } + + #[inline] + fn extract(self, index: usize) -> Self::Scalar { + debug_assert_eq!(index, 0); + self + } + } + impl FloatAsSIMD for $ty {} }; } @@ -359,45 +333,34 @@ macro_rules! scalar_float_impl { scalar_float_impl!(f32, u32); scalar_float_impl!(f64, u64); - #[cfg(feature = "simd_support")] macro_rules! simd_impl { - ($ty:ident, $f_scalar:ident, $mty:ident, $uty:ident) => { - impl FloatSIMDUtils for $ty { - type Mask = $mty; - type UInt = $uty; + ($fty:ident, $uty:ident) => { + impl FloatSIMDUtils for Simd<$fty, LANES> + where + LaneCount: SupportedLaneCount, + { + type Mask = Mask<<$fty as SimdElement>::Mask, LANES>; + type UInt = Simd<$uty, LANES>; #[inline(always)] fn all_lt(self, other: Self) -> bool { - self.lt(other).all() + self.simd_lt(other).all() } #[inline(always)] fn all_le(self, other: Self) -> bool { - self.le(other).all() + self.simd_le(other).all() } #[inline(always)] fn all_finite(self) -> bool { - self.finite_mask().all() - } - - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - // This can possibly be done faster by checking bit patterns - let neg_inf = $ty::splat(::core::$f_scalar::NEG_INFINITY); - let pos_inf = $ty::splat(::core::$f_scalar::INFINITY); - self.gt(neg_inf) & self.lt(pos_inf) + self.is_finite().all() } #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { - self.gt(other) - } - - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self.ge(other) + self.simd_gt(other) } #[inline(always)] @@ -406,10 +369,10 @@ macro_rules! simd_impl { // true, and 0 for false. Adding that to the binary // representation of a float means subtracting one from // the binary representation, resulting in the next lower - // value representable by $ty. This works even when the + // value representable by $fty. This works even when the // current value is infinity. debug_assert!(mask.any(), "At least one lane must be set"); - <$ty>::from_bits(<$uty>::from_bits(self) + <$uty>::from_bits(mask)) + Self::from_bits(self.to_bits() + mask.to_int().cast()) } #[inline] @@ -417,13 +380,29 @@ macro_rules! simd_impl { i.cast() } } + + #[cfg(test)] + impl FloatSIMDScalarUtils for Simd<$fty, LANES> + where + LaneCount: SupportedLaneCount, + { + type Scalar = $fty; + + #[inline] + fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self { + self.as_mut_array()[index] = new_value; + self + } + + #[inline] + fn extract(self, index: usize) -> Self::Scalar { + self.as_array()[index] + } + } }; } -#[cfg(feature="simd_support")] simd_impl! { f32x2, f32, m32x2, u32x2 } -#[cfg(feature="simd_support")] simd_impl! { f32x4, f32, m32x4, u32x4 } -#[cfg(feature="simd_support")] simd_impl! { f32x8, f32, m32x8, u32x8 } -#[cfg(feature="simd_support")] simd_impl! { f32x16, f32, m32x16, u32x16 } -#[cfg(feature="simd_support")] simd_impl! { f64x2, f64, m64x2, u64x2 } -#[cfg(feature="simd_support")] simd_impl! { f64x4, f64, m64x4, u64x4 } -#[cfg(feature="simd_support")] simd_impl! { f64x8, f64, m64x8, u64x8 } +#[cfg(feature = "simd_support")] +simd_impl!(f32, u32); +#[cfg(feature = "simd_support")] +simd_impl!(f64, u64); diff --git a/src/distr/weighted/mod.rs b/src/distr/weighted/mod.rs new file mode 100644 index 00000000000..368c5b0703d --- /dev/null +++ b/src/distr/weighted/mod.rs @@ -0,0 +1,115 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Weighted (index) sampling +//! +//! Primarily, this module houses the [`WeightedIndex`] distribution. +//! See also [`rand_distr::weighted`] for alternative implementations supporting +//! potentially-faster sampling or a more easily modifiable tree structure. +//! +//! [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html + +use core::fmt; +mod weighted_index; + +pub use weighted_index::WeightedIndex; + +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; + + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + #[allow(clippy::result_unit_err)] + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +} + +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } +} +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); + +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + }; +} +impl_weight_float!(f32); +impl_weight_float!(f64); + +/// Invalid weight errors +/// +/// This type represents errors from [`WeightedIndex::new`], +/// [`WeightedIndex::update_weights`] and other weighted distributions. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +// Marked non_exhaustive to allow a new error code in the solution to #1476. +#[non_exhaustive] +pub enum Error { + /// The input weight sequence is empty, too long, or wrongly ordered + InvalidInput, + + /// A weight is negative, too large for the distribution, or not a valid number + InvalidWeight, + + /// Not enough non-zero weights are available to sample values + /// + /// When attempting to sample a single value this implies that all weights + /// are zero. When attempting to sample `amount` values this implies that + /// less than `amount` weights are greater than zero. + InsufficientNonZero, + + /// Overflow when calculating the sum of weights + Overflow, +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + Error::InvalidInput => "Weights sequence is empty/too long/unordered", + Error::InvalidWeight => "A weight is negative, too large or not a valid number", + Error::InsufficientNonZero => "Not enough weights > zero", + Error::Overflow => "Overflow when summing weights", + }) + } +} diff --git a/src/distr/weighted/weighted_index.rs b/src/distr/weighted/weighted_index.rs new file mode 100644 index 00000000000..4bb9d141fb3 --- /dev/null +++ b/src/distr/weighted/weighted_index.rs @@ -0,0 +1,631 @@ +// Copyright 2018 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use super::{Error, Weight}; +use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler}; +use crate::distr::Distribution; +use crate::Rng; + +// Note that this whole module is only imported if feature="alloc" is enabled. +use alloc::vec::Vec; +use core::fmt::{self, Debug}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// A distribution using weighted sampling of discrete items. +/// +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// weight of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform`] exists. The implementation guarantees that +/// elements with zero weight are never picked, even when the weights are +/// floating point numbers. +/// +/// # Performance +/// +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. +/// See also [`rand_distr::weighted`] for alternative implementations supporting +/// potentially-faster sampling or a more easily modifiable tree structure. +/// +/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. +/// +/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains an instance of `X::Sampler`, this might cause additional allocations, +/// though for primitive types, [`Uniform`] doesn't allocate any memory. +/// +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementation of `Uniform::sample`. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distr::weighted::WeightedIndex; +/// +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let mut rng = rand::rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); +/// } +/// ``` +/// +/// [`Uniform`]: crate::distr::Uniform +/// [`RngCore`]: crate::RngCore +/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html +#[derive(Debug, Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct WeightedIndex { + cumulative_weights: Vec, + total_weight: X, + weight_distribution: X::Sampler, +} + +impl WeightedIndex { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementation of [`Uniform`] exists. + /// + /// Error cases: + /// - [`Error::InvalidInput`] when the iterator `weights` is empty. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + /// - [`Error::Overflow`] when the sum of all weights overflows. + /// + /// [`Uniform`]: crate::distr::uniform::Uniform + pub fn new(weights: I) -> Result, Error> + where + I: IntoIterator, + I::Item: SampleBorrow, + X: Weight, + { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone(); + + let zero = X::ZERO; + if !(total_weight >= zero) { + return Err(Error::InvalidWeight); + } + + let mut weights = Vec::::with_capacity(iter.size_hint().0); + for w in iter { + // Note that `!(w >= x)` is not equivalent to `w < x` for partially + // ordered types due to NaNs which are equal to nothing. + if !(w.borrow() >= &zero) { + return Err(Error::InvalidWeight); + } + weights.push(total_weight.clone()); + + if let Err(()) = total_weight.checked_add_assign(w.borrow()) { + return Err(Error::Overflow); + } + } + + if total_weight == zero { + return Err(Error::InsufficientNonZero); + } + let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); + + Ok(WeightedIndex { + cumulative_weights: weights, + total_weight, + weight_distribution: distr, + }) + } + + /// Update a subset of weights, without changing the number of weights. + /// + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. Error cases: + /// - [`Error::InvalidInput`] when `new_weights` are not ordered by + /// index or an index is too large. + /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. + /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. + /// Note that due to floating-point loss of precision, this case is not + /// always correctly detected; usage of a fixed-point weight type may be + /// preferred. + /// + /// Updates take `O(N)` time. If you need to frequently update weights, consider + /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) + /// as an alternative where an update is `O(log N)`. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error> + where + X: for<'a> core::ops::AddAssign<&'a X> + + for<'a> core::ops::SubAssign<&'a X> + + Clone + + Default, + { + if new_weights.is_empty() { + return Ok(()); + } + + let zero = ::default(); + + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(Error::InvalidInput); + } + } + if !(*w >= zero) { + return Err(Error::InvalidWeight); + } + if i > self.cumulative_weights.len() { + return Err(Error::InvalidInput); + } + + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } + + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); + } + if total_weight <= zero { + return Err(Error::InsufficientNonZero); + } + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + } + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap(); + + Ok(()) + } +} + +/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution. +/// This is returned by [`WeightedIndex::weights`]. +pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> { + weighted_index: &'a WeightedIndex, + index: usize, +} + +impl Debug for WeightedIndexIter<'_, X> +where + X: SampleUniform + PartialOrd + Debug, + X::Sampler: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WeightedIndexIter") + .field("weighted_index", &self.weighted_index) + .field("index", &self.index) + .finish() + } +} + +impl Clone for WeightedIndexIter<'_, X> +where + X: SampleUniform + PartialOrd, +{ + fn clone(&self) -> Self { + WeightedIndexIter { + weighted_index: self.weighted_index, + index: self.index, + } + } +} + +impl Iterator for WeightedIndexIter<'_, X> +where + X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone, +{ + type Item = X; + + fn next(&mut self) -> Option { + match self.weighted_index.weight(self.index) { + None => None, + Some(weight) => { + self.index += 1; + Some(weight) + } + } + } +} + +impl WeightedIndex { + /// Returns the weight at the given index, if it exists. + /// + /// If the index is out of bounds, this will return `None`. + /// + /// # Example + /// + /// ``` + /// use rand::distr::weighted::WeightedIndex; + /// + /// let weights = [0, 1, 2]; + /// let dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weight(0), Some(0)); + /// assert_eq!(dist.weight(1), Some(1)); + /// assert_eq!(dist.weight(2), Some(2)); + /// assert_eq!(dist.weight(3), None); + /// ``` + pub fn weight(&self, index: usize) -> Option + where + X: for<'a> core::ops::SubAssign<&'a X>, + { + use core::cmp::Ordering::*; + + let mut weight = match index.cmp(&self.cumulative_weights.len()) { + Less => self.cumulative_weights[index].clone(), + Equal => self.total_weight.clone(), + Greater => return None, + }; + + if index > 0 { + weight -= &self.cumulative_weights[index - 1]; + } + Some(weight) + } + + /// Returns a lazy-loading iterator containing the current weights of this distribution. + /// + /// If this distribution has not been updated since its creation, this will return the + /// same weights as were passed to `new`. + /// + /// # Example + /// + /// ``` + /// use rand::distr::weighted::WeightedIndex; + /// + /// let weights = [1, 2, 3]; + /// let mut dist = WeightedIndex::new(&weights).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![1, 2, 3]); + /// dist.update_weights(&[(0, &2)]).unwrap(); + /// assert_eq!(dist.weights().collect::>(), vec![2, 2, 3]); + /// ``` + pub fn weights(&self) -> WeightedIndexIter<'_, X> + where + X: for<'a> core::ops::SubAssign<&'a X>, + { + WeightedIndexIter { + weighted_index: self, + index: 0, + } + } + + /// Returns the sum of all weights in this distribution. + pub fn total_weight(&self) -> X { + self.total_weight.clone() + } +} + +impl Distribution for WeightedIndex +where + X: SampleUniform + PartialOrd, +{ + fn sample(&self, rng: &mut R) -> usize { + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights + .partition_point(|w| w <= &chosen_weight) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[cfg(feature = "serde")] + #[test] + fn test_weightedindex_serde() { + let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); + + let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); + let de_weighted_index: WeightedIndex = + bincode::deserialize(&ser_weighted_index).unwrap(); + + assert_eq!( + de_weighted_index.cumulative_weights, + weighted_index.cumulative_weights + ); + assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); + } + + #[test] + fn test_accepting_nan() { + assert_eq!( + WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(), + Error::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new([f32::NAN]).unwrap_err(), + Error::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new([0.5, f32::NAN]).unwrap_err(), + Error::InvalidWeight, + ); + + assert_eq!( + WeightedIndex::new([0.5, 7.0]) + .unwrap() + .update_weights(&[(0, &f32::NAN)]) + .unwrap_err(), + Error::InvalidWeight, + ) + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weightedindex() { + let mut r = crate::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + for _ in 0..5 { + assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0); + assert_eq!( + WeightedIndex::new([0, 0, 0, 0, 10, 0]) + .unwrap() + .sample(&mut r), + 4 + ); + } + + assert_eq!( + WeightedIndex::new(&[10][0..0]).unwrap_err(), + Error::InvalidInput + ); + assert_eq!( + WeightedIndex::new([0]).unwrap_err(), + Error::InsufficientNonZero + ); + assert_eq!( + WeightedIndex::new([10, 20, -1, 30]).unwrap_err(), + Error::InvalidWeight + ); + assert_eq!( + WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(), + Error::InvalidWeight + ); + assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight); + } + + #[test] + fn test_update_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for (weights, update, expected_weights) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } + } + + #[test] + fn test_update_weights_errors() { + let data = [ + ( + &[1i32, 0, 0][..], + &[(0, &0)][..], + Error::InsufficientNonZero, + ), + ( + &[10, 10, 10, 10][..], + &[(1, &-11)][..], + Error::InvalidWeight, // A weight is negative + ), + ( + &[1, 2, 3, 4, 5][..], + &[(1, &5), (0, &5)][..], // Wrong order + Error::InvalidInput, + ), + ( + &[1][..], + &[(1, &1)][..], // Index too large + Error::InvalidInput, + ), + ]; + + for (weights, update, err) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + match distr.update_weights(update) { + Ok(_) => panic!("Expected update_weights to fail, but it succeeded"), + Err(e) => assert_eq!(e, *err), + } + } + } + + #[test] + fn test_weight_at() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for (i, weight) in weights.iter().enumerate() { + assert_eq!(distr.weight(i), Some(*weight)); + } + assert_eq!(distr.weight(weights.len()), None); + } + } + + #[test] + fn test_weights() { + let data = [ + &[1][..], + &[10, 2, 3, 4][..], + &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[u32::MAX][..], + ]; + + for weights in data.iter() { + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.weights().collect::>(), weights.to_vec()); + } + } + + #[test] + fn value_stability() { + fn test_samples( + weights: I, + buf: &mut [usize], + expected: &[usize], + ) where + I: IntoIterator, + I::Item: SampleBorrow, + { + assert_eq!(buf.len(), expected.len()); + let distr = WeightedIndex::new(weights).unwrap(); + let mut rng = crate::test::rng(701); + for r in buf.iter_mut() { + *r = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + let mut buf = [0; 10]; + test_samples( + [1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], + ); + test_samples( + [0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], + ); + test_samples( + [1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], + ); + } + + #[test] + fn weighted_index_distributions_can_be_compared() { + assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2])); + } + + #[test] + fn overflow() { + assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow)); + } +} diff --git a/src/distributions/integer.rs b/src/distributions/integer.rs deleted file mode 100644 index 8a2ce4cf1e6..00000000000 --- a/src/distributions/integer.rs +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The implementations of the `Standard` distribution for integer types. - -use crate::distributions::{Distribution, Standard}; -use crate::Rng; -#[cfg(all(target_arch = "x86", feature = "simd_support"))] -use core::arch::x86::{__m128i, __m256i}; -#[cfg(all(target_arch = "x86_64", feature = "simd_support"))] -use core::arch::x86_64::{__m128i, __m256i}; -#[cfg(not(target_os = "emscripten"))] use core::num::NonZeroU128; -use core::num::{NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize}; -#[cfg(feature = "simd_support")] use packed_simd::*; - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u8 { - rng.next_u32() as u8 - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u16 { - rng.next_u32() as u16 - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u32 { - rng.next_u32() - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u64 { - rng.next_u64() - } -} - -#[cfg(not(target_os = "emscripten"))] -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> u128 { - // Use LE; we explicitly generate one value before the next. - let x = u128::from(rng.next_u64()); - let y = u128::from(rng.next_u64()); - (y << 64) | x - } -} - -impl Distribution for Standard { - #[inline] - #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] - fn sample(&self, rng: &mut R) -> usize { - rng.next_u32() as usize - } - - #[inline] - #[cfg(target_pointer_width = "64")] - fn sample(&self, rng: &mut R) -> usize { - rng.next_u64() as usize - } -} - -macro_rules! impl_int_from_uint { - ($ty:ty, $uty:ty) => { - impl Distribution<$ty> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $ty { - rng.gen::<$uty>() as $ty - } - } - }; -} - -impl_int_from_uint! { i8, u8 } -impl_int_from_uint! { i16, u16 } -impl_int_from_uint! { i32, u32 } -impl_int_from_uint! { i64, u64 } -#[cfg(not(target_os = "emscripten"))] -impl_int_from_uint! { i128, u128 } -impl_int_from_uint! { isize, usize } - -macro_rules! impl_nzint { - ($ty:ty, $new:path) => { - impl Distribution<$ty> for Standard { - fn sample(&self, rng: &mut R) -> $ty { - loop { - if let Some(nz) = $new(rng.gen()) { - break nz; - } - } - } - } - }; -} - -impl_nzint!(NonZeroU8, NonZeroU8::new); -impl_nzint!(NonZeroU16, NonZeroU16::new); -impl_nzint!(NonZeroU32, NonZeroU32::new); -impl_nzint!(NonZeroU64, NonZeroU64::new); -#[cfg(not(target_os = "emscripten"))] -impl_nzint!(NonZeroU128, NonZeroU128::new); -impl_nzint!(NonZeroUsize, NonZeroUsize::new); - -#[cfg(feature = "simd_support")] -macro_rules! simd_impl { - ($(($intrinsic:ident, $vec:ty),)+) => {$( - impl Distribution<$intrinsic> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $intrinsic { - $intrinsic::from_bits(rng.gen::<$vec>()) - } - } - )+}; - - ($bits:expr,) => {}; - ($bits:expr, $ty:ty, $($ty_more:ty,)*) => { - simd_impl!($bits, $($ty_more,)*); - - impl Distribution<$ty> for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> $ty { - let mut vec: $ty = Default::default(); - unsafe { - let ptr = &mut vec; - let b_ptr = &mut *(ptr as *mut $ty as *mut [u8; $bits/8]); - rng.fill_bytes(b_ptr); - } - vec.to_le() - } - } - }; -} - -#[cfg(feature = "simd_support")] -simd_impl!(16, u8x2, i8x2,); -#[cfg(feature = "simd_support")] -simd_impl!(32, u8x4, i8x4, u16x2, i16x2,); -#[cfg(feature = "simd_support")] -simd_impl!(64, u8x8, i8x8, u16x4, i16x4, u32x2, i32x2,); -#[cfg(feature = "simd_support")] -simd_impl!(128, u8x16, i8x16, u16x8, i16x8, u32x4, i32x4, u64x2, i64x2,); -#[cfg(feature = "simd_support")] -simd_impl!(256, u8x32, i8x32, u16x16, i16x16, u32x8, i32x8, u64x4, i64x4,); -#[cfg(feature = "simd_support")] -simd_impl!(512, u8x64, i8x64, u16x32, i16x32, u32x16, i32x16, u64x8, i64x8,); -#[cfg(all( - feature = "simd_support", - any(target_arch = "x86", target_arch = "x86_64") -))] -simd_impl!((__m128i, u8x16), (__m256i, u8x32),); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_integers() { - let mut rng = crate::test::rng(806); - - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - #[cfg(not(target_os = "emscripten"))] - rng.sample::(Standard); - - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - rng.sample::(Standard); - #[cfg(not(target_os = "emscripten"))] - rng.sample::(Standard); - } - - #[test] - fn value_stability() { - fn test_samples(zero: T, expected: &[T]) - where Standard: Distribution { - let mut rng = crate::test::rng(807); - let mut buf = [zero; 3]; - for x in &mut buf { - *x = rng.sample(Standard); - } - assert_eq!(&buf, expected); - } - - test_samples(0u8, &[9, 247, 111]); - test_samples(0u16, &[32265, 42999, 38255]); - test_samples(0u32, &[2220326409, 2575017975, 2018088303]); - test_samples(0u64, &[ - 11059617991457472009, - 16096616328739788143, - 1487364411147516184, - ]); - test_samples(0u128, &[ - 296930161868957086625409848350820761097, - 145644820879247630242265036535529306392, - 111087889832015897993126088499035356354, - ]); - #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] - test_samples(0usize, &[2220326409, 2575017975, 2018088303]); - #[cfg(target_pointer_width = "64")] - test_samples(0usize, &[ - 11059617991457472009, - 16096616328739788143, - 1487364411147516184, - ]); - - test_samples(0i8, &[9, -9, 111]); - // Skip further i* types: they are simple reinterpretation of u* samples - - #[cfg(feature = "simd_support")] - { - // We only test a sub-set of types here and make assumptions about the rest. - - test_samples(u8x2::default(), &[ - u8x2::new(9, 126), - u8x2::new(247, 167), - u8x2::new(111, 149), - ]); - test_samples(u8x4::default(), &[ - u8x4::new(9, 126, 87, 132), - u8x4::new(247, 167, 123, 153), - u8x4::new(111, 149, 73, 120), - ]); - test_samples(u8x8::default(), &[ - u8x8::new(9, 126, 87, 132, 247, 167, 123, 153), - u8x8::new(111, 149, 73, 120, 68, 171, 98, 223), - u8x8::new(24, 121, 1, 50, 13, 46, 164, 20), - ]); - - test_samples(i64x8::default(), &[ - i64x8::new( - -7387126082252079607, - -2350127744969763473, - 1487364411147516184, - 7895421560427121838, - 602190064936008898, - 6022086574635100741, - -5080089175222015595, - -4066367846667249123, - ), - i64x8::new( - 9180885022207963908, - 3095981199532211089, - 6586075293021332726, - 419343203796414657, - 3186951873057035255, - 5287129228749947252, - 444726432079249540, - -1587028029513790706, - ), - i64x8::new( - 6075236523189346388, - 1351763722368165432, - -6192309979959753740, - -7697775502176768592, - -4482022114172078123, - 7522501477800909500, - -1837258847956201231, - -586926753024886735, - ), - ]); - } - } -} diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs deleted file mode 100644 index 652f52a1831..00000000000 --- a/src/distributions/mod.rs +++ /dev/null @@ -1,372 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Generating random samples from probability distributions -//! -//! This module is the home of the [`Distribution`] trait and several of its -//! implementations. It is the workhorse behind some of the convenient -//! functionality of the [`Rng`] trait, e.g. [`Rng::gen`] and of course -//! [`Rng::sample`]. -//! -//! Abstractly, a [probability distribution] describes the probability of -//! occurrence of each value in its sample space. -//! -//! More concretely, an implementation of `Distribution` for type `X` is an -//! algorithm for choosing values from the sample space (a subset of `T`) -//! according to the distribution `X` represents, using an external source of -//! randomness (an RNG supplied to the `sample` function). -//! -//! A type `X` may implement `Distribution` for multiple types `T`. -//! Any type implementing [`Distribution`] is stateless (i.e. immutable), -//! but it may have internal parameters set at construction time (for example, -//! [`Uniform`] allows specification of its sample space as a range within `T`). -//! -//! -//! # The `Standard` distribution -//! -//! The [`Standard`] distribution is important to mention. This is the -//! distribution used by [`Rng::gen`] and represents the "default" way to -//! produce a random value for many different types, including most primitive -//! types, tuples, arrays, and a few derived types. See the documentation of -//! [`Standard`] for more details. -//! -//! Implementing `Distribution` for [`Standard`] for user types `T` makes it -//! possible to generate type `T` with [`Rng::gen`], and by extension also -//! with the [`random`] function. -//! -//! ## Random characters -//! -//! [`Alphanumeric`] is a simple distribution to sample random letters and -//! numbers of the `char` type; in contrast [`Standard`] may sample any valid -//! `char`. -//! -//! -//! # Uniform numeric ranges -//! -//! The [`Uniform`] distribution is more flexible than [`Standard`], but also -//! more specialised: it supports fewer target types, but allows the sample -//! space to be specified as an arbitrary range within its target type `T`. -//! Both [`Standard`] and [`Uniform`] are in some sense uniform distributions. -//! -//! Values may be sampled from this distribution using [`Rng::sample(Range)`] or -//! by creating a distribution object with [`Uniform::new`], -//! [`Uniform::new_inclusive`] or `From`. When the range limits are not -//! known at compile time it is typically faster to reuse an existing -//! `Uniform` object than to call [`Rng::sample(Range)`]. -//! -//! User types `T` may also implement `Distribution` for [`Uniform`], -//! although this is less straightforward than for [`Standard`] (see the -//! documentation in the [`uniform`] module). Doing so enables generation of -//! values of type `T` with [`Rng::sample(Range)`]. -//! -//! ## Open and half-open ranges -//! -//! There are surprisingly many ways to uniformly generate random floats. A -//! range between 0 and 1 is standard, but the exact bounds (open vs closed) -//! and accuracy differ. In addition to the [`Standard`] distribution Rand offers -//! [`Open01`] and [`OpenClosed01`]. See "Floating point implementation" section of -//! [`Standard`] documentation for more details. -//! -//! # Non-uniform sampling -//! -//! Sampling a simple true/false outcome with a given probability has a name: -//! the [`Bernoulli`] distribution (this is used by [`Rng::gen_bool`]). -//! -//! For weighted sampling from a sequence of discrete values, use the -//! [`WeightedIndex`] distribution. -//! -//! This crate no longer includes other non-uniform distributions; instead -//! it is recommended that you use either [`rand_distr`] or [`statrs`]. -//! -//! -//! [probability distribution]: https://en.wikipedia.org/wiki/Probability_distribution -//! [`rand_distr`]: https://crates.io/crates/rand_distr -//! [`statrs`]: https://crates.io/crates/statrs - -//! [`random`]: crate::random -//! [`rand_distr`]: https://crates.io/crates/rand_distr -//! [`statrs`]: https://crates.io/crates/statrs - -use crate::Rng; -use core::iter; - -pub use self::bernoulli::{Bernoulli, BernoulliError}; -pub use self::float::{Open01, OpenClosed01}; -pub use self::other::Alphanumeric; -#[doc(inline)] pub use self::uniform::Uniform; - -#[cfg(feature = "alloc")] -pub use self::weighted_index::{WeightedError, WeightedIndex}; - -mod bernoulli; -pub mod uniform; - -#[deprecated(since = "0.8.0", note = "use rand::distributions::{WeightedIndex, WeightedError} instead")] -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod weighted; -#[cfg(feature = "alloc")] mod weighted_index; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -mod float; -#[doc(hidden)] -pub mod hidden_export { - pub use super::float::IntoFloat; // used by rand_distr -} -mod integer; -mod other; -mod utils; - -/// Types (distributions) that can be used to create a random instance of `T`. -/// -/// It is possible to sample from a distribution through both the -/// `Distribution` and [`Rng`] traits, via `distr.sample(&mut rng)` and -/// `rng.sample(distr)`. They also both offer the [`sample_iter`] method, which -/// produces an iterator that samples from the distribution. -/// -/// All implementations are expected to be immutable; this has the significant -/// advantage of not needing to consider thread safety, and for most -/// distributions efficient state-less sampling algorithms are available. -/// -/// Implementations are typically expected to be portable with reproducible -/// results when used with a PRNG with fixed seed; see the -/// [portability chapter](https://rust-random.github.io/book/portability.html) -/// of The Rust Rand Book. In some cases this does not apply, e.g. the `usize` -/// type requires different sampling on 32-bit and 64-bit machines. -/// -/// [`sample_iter`]: Distribution::method.sample_iter -pub trait Distribution { - /// Generate a random value of `T`, using `rng` as the source of randomness. - fn sample(&self, rng: &mut R) -> T; - - /// Create an iterator that generates random values of `T`, using `rng` as - /// the source of randomness. - /// - /// Note that this function takes `self` by value. This works since - /// `Distribution` is impl'd for `&D` where `D: Distribution`, - /// however borrowing is not automatic hence `distr.sample_iter(...)` may - /// need to be replaced with `(&distr).sample_iter(...)` to borrow or - /// `(&*distr).sample_iter(...)` to reborrow an existing reference. - /// - /// # Example - /// - /// ``` - /// use rand::thread_rng; - /// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard}; - /// - /// let mut rng = thread_rng(); - /// - /// // Vec of 16 x f32: - /// let v: Vec = Standard.sample_iter(&mut rng).take(16).collect(); - /// - /// // String: - /// let s: String = Alphanumeric - /// .sample_iter(&mut rng) - /// .take(7) - /// .map(char::from) - /// .collect(); - /// - /// // Dice-rolling: - /// let die_range = Uniform::new_inclusive(1, 6); - /// let mut roll_die = die_range.sample_iter(&mut rng); - /// while roll_die.next().unwrap() != 6 { - /// println!("Not a 6; rolling again!"); - /// } - /// ``` - fn sample_iter(self, rng: R) -> DistIter - where - R: Rng, - Self: Sized, - { - DistIter { - distr: self, - rng, - phantom: ::core::marker::PhantomData, - } - } -} - -impl<'a, T, D: Distribution> Distribution for &'a D { - fn sample(&self, rng: &mut R) -> T { - (*self).sample(rng) - } -} - - -/// An iterator that generates random values of `T` with distribution `D`, -/// using `R` as the source of randomness. -/// -/// This `struct` is created by the [`sample_iter`] method on [`Distribution`]. -/// See its documentation for more. -/// -/// [`sample_iter`]: Distribution::sample_iter -#[derive(Debug)] -pub struct DistIter { - distr: D, - rng: R, - phantom: ::core::marker::PhantomData, -} - -impl Iterator for DistIter -where - D: Distribution, - R: Rng, -{ - type Item = T; - - #[inline(always)] - fn next(&mut self) -> Option { - // Here, self.rng may be a reference, but we must take &mut anyway. - // Even if sample could take an R: Rng by value, we would need to do this - // since Rng is not copyable and we cannot enforce that this is "reborrowable". - Some(self.distr.sample(&mut self.rng)) - } - - fn size_hint(&self) -> (usize, Option) { - (usize::max_value(), None) - } -} - -impl iter::FusedIterator for DistIter -where - D: Distribution, - R: Rng, -{ -} - -#[cfg(features = "nightly")] -impl iter::TrustedLen for DistIter -where - D: Distribution, - R: Rng, -{ -} - - -/// A generic random value distribution, implemented for many primitive types. -/// Usually generates values with a numerically uniform distribution, and with a -/// range appropriate to the type. -/// -/// ## Provided implementations -/// -/// Assuming the provided `Rng` is well-behaved, these implementations -/// generate values with the following ranges and distributions: -/// -/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed -/// over all values of the type. -/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all -/// code points in the range `0...0x10_FFFF`, except for the range -/// `0xD800...0xDFFF` (the surrogate code points). This includes -/// unassigned/reserved code points. -/// * `bool`: Generates `false` or `true`, each with probability 0.5. -/// * Floating point types (`f32` and `f64`): Uniformly distributed in the -/// half-open range `[0, 1)`. See notes below. -/// * Wrapping integers (`Wrapping`), besides the type identical to their -/// normal integer variants. -/// -/// The `Standard` distribution also supports generation of the following -/// compound types where all component types are supported: -/// -/// * Tuples (up to 12 elements): each element is generated sequentially. -/// * Arrays (up to 32 elements): each element is generated sequentially; -/// see also [`Rng::fill`] which supports arbitrary array length for integer -/// types and tends to be faster for `u32` and smaller types. -/// * `Option` first generates a `bool`, and if true generates and returns -/// `Some(value)` where `value: T`, otherwise returning `None`. -/// -/// ## Custom implementations -/// -/// The [`Standard`] distribution may be implemented for user types as follows: -/// -/// ``` -/// # #![allow(dead_code)] -/// use rand::Rng; -/// use rand::distributions::{Distribution, Standard}; -/// -/// struct MyF32 { -/// x: f32, -/// } -/// -/// impl Distribution for Standard { -/// fn sample(&self, rng: &mut R) -> MyF32 { -/// MyF32 { x: rng.gen() } -/// } -/// } -/// ``` -/// -/// ## Example usage -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::Standard; -/// -/// let val: f32 = StdRng::from_entropy().sample(Standard); -/// println!("f32 from [0, 1): {}", val); -/// ``` -/// -/// # Floating point implementation -/// The floating point implementations for `Standard` generate a random value in -/// the half-open interval `[0, 1)`, i.e. including 0 but not 1. -/// -/// All values that can be generated are of the form `n * ε/2`. For `f32` -/// the 24 most significant random bits of a `u32` are used and for `f64` the -/// 53 most significant bits of a `u64` are used. The conversion uses the -/// multiplicative method: `(rng.gen::<$uty>() >> N) as $ty * (ε/2)`. -/// -/// See also: [`Open01`] which samples from `(0, 1)`, [`OpenClosed01`] which -/// samples from `(0, 1]` and `Rng::gen_range(0..1)` which also samples from -/// `[0, 1)`. Note that `Open01` uses transmute-based methods which yield 1 bit -/// less precision but may perform faster on some architectures (on modern Intel -/// CPUs all methods have approximately equal performance). -/// -/// [`Uniform`]: uniform::Uniform -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Standard; - - -#[cfg(test)] -mod tests { - use super::{Distribution, Uniform}; - use crate::Rng; - - #[test] - fn test_distributions_iter() { - use crate::distributions::Open01; - let mut rng = crate::test::rng(210); - let distr = Open01; - let mut iter = Distribution::::sample_iter(distr, &mut rng); - let mut sum: f32 = 0.; - for _ in 0..100 { - sum += iter.next().unwrap(); - } - assert!(0. < sum && sum < 100.); - } - - #[test] - fn test_make_an_iter() { - fn ten_dice_rolls_other_than_five<'a, R: Rng>( - rng: &'a mut R, - ) -> impl Iterator + 'a { - Uniform::new_inclusive(1, 6) - .sample_iter(rng) - .filter(|x| *x != 5) - .take(10) - } - - let mut rng = crate::test::rng(211); - let mut count = 0; - for val in ten_dice_rolls_other_than_five(&mut rng) { - assert!(val >= 1 && val <= 6 && val != 5); - count += 1; - } - assert_eq!(count, 10); - } -} diff --git a/src/distributions/other.rs b/src/distributions/other.rs deleted file mode 100644 index f62fe59aa73..00000000000 --- a/src/distributions/other.rs +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The implementations of the `Standard` distribution for other built-in types. - -use core::char; -use core::num::Wrapping; - -use crate::distributions::{Distribution, Standard, Uniform}; -use crate::Rng; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -// ----- Sampling distributions ----- - -/// Sample a `u8`, uniformly distributed over ASCII letters and numbers: -/// a-z, A-Z and 0-9. -/// -/// # Example -/// -/// ``` -/// use std::iter; -/// use rand::{Rng, thread_rng}; -/// use rand::distributions::Alphanumeric; -/// -/// let mut rng = thread_rng(); -/// let chars: String = iter::repeat(()) -/// .map(|()| rng.sample(Alphanumeric)) -/// .map(char::from) -/// .take(7) -/// .collect(); -/// println!("Random chars: {}", chars); -/// ``` -/// -/// # Passwords -/// -/// Users sometimes ask whether it is safe to use a string of random characters -/// as a password. In principle, all RNGs in Rand implementing `CryptoRng` are -/// suitable as a source of randomness for generating passwords (if they are -/// properly seeded), but it is more conservative to only use randomness -/// directly from the operating system via the `getrandom` crate, or the -/// corresponding bindings of a crypto library. -/// -/// When generating passwords or keys, it is important to consider the threat -/// model and in some cases the memorability of the password. This is out of -/// scope of the Rand project, and therefore we defer to the following -/// references: -/// -/// - [Wikipedia article on Password Strength](https://en.wikipedia.org/wiki/Password_strength) -/// - [Diceware for generating memorable passwords](https://en.wikipedia.org/wiki/Diceware) -#[derive(Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Alphanumeric; - - -// ----- Implementations of distributions ----- - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> char { - // A valid `char` is either in the interval `[0, 0xD800)` or - // `(0xDFFF, 0x11_0000)`. All `char`s must therefore be in - // `[0, 0x11_0000)` but not in the "gap" `[0xD800, 0xDFFF]` which is - // reserved for surrogates. This is the size of that gap. - const GAP_SIZE: u32 = 0xDFFF - 0xD800 + 1; - - // Uniform::new(0, 0x11_0000 - GAP_SIZE) can also be used but it - // seemed slower. - let range = Uniform::new(GAP_SIZE, 0x11_0000); - - let mut n = range.sample(rng); - if n <= 0xDFFF { - n -= GAP_SIZE; - } - unsafe { char::from_u32_unchecked(n) } - } -} - -impl Distribution for Alphanumeric { - fn sample(&self, rng: &mut R) -> u8 { - const RANGE: u32 = 26 + 26 + 10; - const GEN_ASCII_STR_CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ - abcdefghijklmnopqrstuvwxyz\ - 0123456789"; - // We can pick from 62 characters. This is so close to a power of 2, 64, - // that we can do better than `Uniform`. Use a simple bitshift and - // rejection sampling. We do not use a bitmask, because for small RNGs - // the most significant bits are usually of higher quality. - loop { - let var = rng.next_u32() >> (32 - 6); - if var < RANGE { - return GEN_ASCII_STR_CHARSET[var as usize]; - } - } - } -} - -impl Distribution for Standard { - #[inline] - fn sample(&self, rng: &mut R) -> bool { - // We can compare against an arbitrary bit of an u32 to get a bool. - // Because the least significant bits of a lower quality RNG can have - // simple patterns, we compare against the most significant bit. This is - // easiest done using a sign test. - (rng.next_u32() as i32) < 0 - } -} - -macro_rules! tuple_impl { - // use variables to indicate the arity of the tuple - ($($tyvar:ident),* ) => { - // the trailing commas are for the 1 tuple - impl< $( $tyvar ),* > - Distribution<( $( $tyvar ),* , )> - for Standard - where $( Standard: Distribution<$tyvar> ),* - { - #[inline] - fn sample(&self, _rng: &mut R) -> ( $( $tyvar ),* , ) { - ( - // use the $tyvar's to get the appropriate number of - // repeats (they're not actually needed) - $( - _rng.gen::<$tyvar>() - ),* - , - ) - } - } - } -} - -impl Distribution<()> for Standard { - #[allow(clippy::unused_unit)] - #[inline] - fn sample(&self, _: &mut R) -> () { - () - } -} -tuple_impl! {A} -tuple_impl! {A, B} -tuple_impl! {A, B, C} -tuple_impl! {A, B, C, D} -tuple_impl! {A, B, C, D, E} -tuple_impl! {A, B, C, D, E, F} -tuple_impl! {A, B, C, D, E, F, G} -tuple_impl! {A, B, C, D, E, F, G, H} -tuple_impl! {A, B, C, D, E, F, G, H, I} -tuple_impl! {A, B, C, D, E, F, G, H, I, J} -tuple_impl! {A, B, C, D, E, F, G, H, I, J, K} -tuple_impl! {A, B, C, D, E, F, G, H, I, J, K, L} - -macro_rules! array_impl { - // recursive, given at least one type parameter: - {$n:expr, $t:ident, $($ts:ident,)*} => { - array_impl!{($n - 1), $($ts,)*} - - impl Distribution<[T; $n]> for Standard where Standard: Distribution { - #[inline] - fn sample(&self, _rng: &mut R) -> [T; $n] { - [_rng.gen::<$t>(), $(_rng.gen::<$ts>()),*] - } - } - }; - // empty case: - {$n:expr,} => { - impl Distribution<[T; $n]> for Standard { - fn sample(&self, _rng: &mut R) -> [T; $n] { [] } - } - }; -} - -array_impl! {32, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T, T,} - -impl Distribution> for Standard -where Standard: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> Option { - // UFCS is needed here: https://github.com/rust-lang/rust/issues/24066 - if rng.gen::() { - Some(rng.gen()) - } else { - None - } - } -} - -impl Distribution> for Standard -where Standard: Distribution -{ - #[inline] - fn sample(&self, rng: &mut R) -> Wrapping { - Wrapping(rng.gen()) - } -} - - -#[cfg(test)] -mod tests { - use super::*; - use crate::RngCore; - #[cfg(feature = "alloc")] use alloc::string::String; - - #[test] - fn test_misc() { - let rng: &mut dyn RngCore = &mut crate::test::rng(820); - - rng.sample::(Standard); - rng.sample::(Standard); - } - - #[cfg(feature = "alloc")] - #[test] - fn test_chars() { - use core::iter; - let mut rng = crate::test::rng(805); - - // Test by generating a relatively large number of chars, so we also - // take the rejection sampling path. - let word: String = iter::repeat(()) - .map(|()| rng.gen::()) - .take(1000) - .collect(); - assert!(word.len() != 0); - } - - #[test] - fn test_alphanumeric() { - let mut rng = crate::test::rng(806); - - // Test by generating a relatively large number of chars, so we also - // take the rejection sampling path. - let mut incorrect = false; - for _ in 0..100 { - let c: char = rng.sample(Alphanumeric).into(); - incorrect |= !((c >= '0' && c <= '9') || - (c >= 'A' && c <= 'Z') || - (c >= 'a' && c <= 'z') ); - } - assert!(incorrect == false); - } - - #[test] - fn value_stability() { - fn test_samples>( - distr: &D, zero: T, expected: &[T], - ) { - let mut rng = crate::test::rng(807); - let mut buf = [zero; 5]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(&buf, expected); - } - - test_samples(&Standard, 'a', &[ - '\u{8cdac}', - '\u{a346a}', - '\u{80120}', - '\u{ed692}', - '\u{35888}', - ]); - test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); - test_samples(&Standard, false, &[true, true, false, true, false]); - test_samples(&Standard, None as Option, &[ - Some(true), - None, - Some(false), - None, - Some(false), - ]); - test_samples(&Standard, Wrapping(0i32), &[ - Wrapping(-2074640887), - Wrapping(-1719949321), - Wrapping(2018088303), - Wrapping(-547181756), - Wrapping(838957336), - ]); - - // We test only sub-sets of tuple and array impls - test_samples(&Standard, (), &[(), (), (), (), ()]); - test_samples(&Standard, (false,), &[ - (true,), - (true,), - (false,), - (true,), - (false,), - ]); - test_samples(&Standard, (false, false), &[ - (true, true), - (false, true), - (false, false), - (true, false), - (false, false), - ]); - - test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]); - test_samples(&Standard, [0u8; 3], &[ - [9, 247, 111], - [68, 24, 13], - [174, 19, 194], - [172, 69, 213], - [149, 207, 29], - ]); - } -} diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs deleted file mode 100644 index bbd96948f8c..00000000000 --- a/src/distributions/uniform.rs +++ /dev/null @@ -1,1614 +0,0 @@ -// Copyright 2018-2020 Developers of the Rand project. -// Copyright 2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A distribution uniformly sampling numbers within a given range. -//! -//! [`Uniform`] is the standard distribution to sample uniformly from a range; -//! e.g. `Uniform::new_inclusive(1, 6)` can sample integers from 1 to 6, like a -//! standard die. [`Rng::gen_range`] supports any type supported by -//! [`Uniform`]. -//! -//! This distribution is provided with support for several primitive types -//! (all integer and floating-point types) as well as [`std::time::Duration`], -//! and supports extension to user-defined types via a type-specific *back-end* -//! implementation. -//! -//! The types [`UniformInt`], [`UniformFloat`] and [`UniformDuration`] are the -//! back-ends supporting sampling from primitive integer and floating-point -//! ranges as well as from [`std::time::Duration`]; these types do not normally -//! need to be used directly (unless implementing a derived back-end). -//! -//! # Example usage -//! -//! ``` -//! use rand::{Rng, thread_rng}; -//! use rand::distributions::Uniform; -//! -//! let mut rng = thread_rng(); -//! let side = Uniform::new(-10.0, 10.0); -//! -//! // sample between 1 and 10 points -//! for _ in 0..rng.gen_range(1..=10) { -//! // sample a point from the square with sides -10 - 10 in two dimensions -//! let (x, y) = (rng.sample(side), rng.sample(side)); -//! println!("Point: {}, {}", x, y); -//! } -//! ``` -//! -//! # Extending `Uniform` to support a custom type -//! -//! To extend [`Uniform`] to support your own types, write a back-end which -//! implements the [`UniformSampler`] trait, then implement the [`SampleUniform`] -//! helper trait to "register" your back-end. See the `MyF32` example below. -//! -//! At a minimum, the back-end needs to store any parameters needed for sampling -//! (e.g. the target range) and implement `new`, `new_inclusive` and `sample`. -//! Those methods should include an assert to check the range is valid (i.e. -//! `low < high`). The example below merely wraps another back-end. -//! -//! The `new`, `new_inclusive` and `sample_single` functions use arguments of -//! type SampleBorrow in order to support passing in values by reference or -//! by value. In the implementation of these functions, you can choose to -//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose -//! to copy or clone the value, whatever is appropriate for your type. -//! -//! ``` -//! use rand::prelude::*; -//! use rand::distributions::uniform::{Uniform, SampleUniform, -//! UniformSampler, UniformFloat, SampleBorrow}; -//! -//! struct MyF32(f32); -//! -//! #[derive(Clone, Copy, Debug)] -//! struct UniformMyF32(UniformFloat); -//! -//! impl UniformSampler for UniformMyF32 { -//! type X = MyF32; -//! fn new(low: B1, high: B2) -> Self -//! where B1: SampleBorrow + Sized, -//! B2: SampleBorrow + Sized -//! { -//! UniformMyF32(UniformFloat::::new(low.borrow().0, high.borrow().0)) -//! } -//! fn new_inclusive(low: B1, high: B2) -> Self -//! where B1: SampleBorrow + Sized, -//! B2: SampleBorrow + Sized -//! { -//! UniformSampler::new(low, high) -//! } -//! fn sample(&self, rng: &mut R) -> Self::X { -//! MyF32(self.0.sample(rng)) -//! } -//! } -//! -//! impl SampleUniform for MyF32 { -//! type Sampler = UniformMyF32; -//! } -//! -//! let (low, high) = (MyF32(17.0f32), MyF32(22.0f32)); -//! let uniform = Uniform::new(low, high); -//! let x = uniform.sample(&mut thread_rng()); -//! ``` -//! -//! [`SampleUniform`]: crate::distributions::uniform::SampleUniform -//! [`UniformSampler`]: crate::distributions::uniform::UniformSampler -//! [`UniformInt`]: crate::distributions::uniform::UniformInt -//! [`UniformFloat`]: crate::distributions::uniform::UniformFloat -//! [`UniformDuration`]: crate::distributions::uniform::UniformDuration -//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow - -#[cfg(not(feature = "std"))] use core::time::Duration; -#[cfg(feature = "std")] use std::time::Duration; -use core::ops::{Range, RangeInclusive}; - -use crate::distributions::float::IntoFloat; -use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, WideningMultiply}; -use crate::distributions::Distribution; -use crate::{Rng, RngCore}; - -#[cfg(not(feature = "std"))] -#[allow(unused_imports)] // rustc doesn't detect that this is actually used -use crate::distributions::utils::Float; - -#[cfg(feature = "simd_support")] use packed_simd::*; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// Sample values uniformly between two bounds. -/// -/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform -/// distribution sampling from the given range; these functions may do extra -/// work up front to make sampling of multiple values faster. If only one sample -/// from the range is required, [`Rng::gen_range`] can be more efficient. -/// -/// When sampling from a constant range, many calculations can happen at -/// compile-time and all methods should be fast; for floating-point ranges and -/// the full range of integer types this should have comparable performance to -/// the `Standard` distribution. -/// -/// Steps are taken to avoid bias which might be present in naive -/// implementations; for example `rng.gen::() % 170` samples from the range -/// `[0, 169]` but is twice as likely to select numbers less than 85 than other -/// values. Further, the implementations here give more weight to the high-bits -/// generated by the RNG than the low bits, since with some RNGs the low-bits -/// are of lower quality than the high bits. -/// -/// Implementations must sample in `[low, high)` range for -/// `Uniform::new(low, high)`, i.e., excluding `high`. In particular, care must -/// be taken to ensure that rounding never results values `< low` or `>= high`. -/// -/// # Example -/// -/// ``` -/// use rand::distributions::{Distribution, Uniform}; -/// -/// let between = Uniform::from(10..10000); -/// let mut rng = rand::thread_rng(); -/// let mut sum = 0; -/// for _ in 0..1000 { -/// sum += between.sample(&mut rng); -/// } -/// println!("{}", sum); -/// ``` -/// -/// For a single sample, [`Rng::gen_range`] may be prefered: -/// -/// ``` -/// use rand::Rng; -/// -/// let mut rng = rand::thread_rng(); -/// println!("{}", rng.gen_range(0..10)); -/// ``` -/// -/// [`new`]: Uniform::new -/// [`new_inclusive`]: Uniform::new_inclusive -/// [`Rng::gen_range`]: Rng::gen_range -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Uniform(X::Sampler); - -impl Uniform { - /// Create a new `Uniform` instance which samples uniformly from the half - /// open range `[low, high)` (excluding `high`). Panics if `low >= high`. - pub fn new(low: B1, high: B2) -> Uniform - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - Uniform(X::Sampler::new(low, high)) - } - - /// Create a new `Uniform` instance which samples uniformly from the closed - /// range `[low, high]` (inclusive). Panics if `low > high`. - pub fn new_inclusive(low: B1, high: B2) -> Uniform - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - Uniform(X::Sampler::new_inclusive(low, high)) - } -} - -impl Distribution for Uniform { - fn sample(&self, rng: &mut R) -> X { - self.0.sample(rng) - } -} - -/// Helper trait for creating objects using the correct implementation of -/// [`UniformSampler`] for the sampling type. -/// -/// See the [module documentation] on how to implement [`Uniform`] range -/// sampling for a custom type. -/// -/// [module documentation]: crate::distributions::uniform -pub trait SampleUniform: Sized { - /// The `UniformSampler` implementation supporting type `X`. - type Sampler: UniformSampler; -} - -/// Helper trait handling actual uniform sampling. -/// -/// See the [module documentation] on how to implement [`Uniform`] range -/// sampling for a custom type. -/// -/// Implementation of [`sample_single`] is optional, and is only useful when -/// the implementation can be faster than `Self::new(low, high).sample(rng)`. -/// -/// [module documentation]: crate::distributions::uniform -/// [`sample_single`]: UniformSampler::sample_single -pub trait UniformSampler: Sized { - /// The type sampled by this implementation. - type X; - - /// Construct self, with inclusive lower bound and exclusive upper bound - /// `[low, high)`. - /// - /// Usually users should not call this directly but instead use - /// `Uniform::new`, which asserts that `low < high` before calling this. - fn new(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized; - - /// Construct self, with inclusive bounds `[low, high]`. - /// - /// Usually users should not call this directly but instead use - /// `Uniform::new_inclusive`, which asserts that `low <= high` before - /// calling this. - fn new_inclusive(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized; - - /// Sample a value. - fn sample(&self, rng: &mut R) -> Self::X; - - /// Sample a single value uniformly from a range with inclusive lower bound - /// and exclusive upper bound `[low, high)`. - /// - /// By default this is implemented using - /// `UniformSampler::new(low, high).sample(rng)`. However, for some types - /// more optimal implementations for single usage may be provided via this - /// method (which is the case for integers and floats). - /// Results may not be identical. - /// - /// Note that to use this method in a generic context, the type needs to be - /// retrieved via `SampleUniform::Sampler` as follows: - /// ``` - /// use rand::{thread_rng, distributions::uniform::{SampleUniform, UniformSampler}}; - /// # #[allow(unused)] - /// fn sample_from_range(lb: T, ub: T) -> T { - /// let mut rng = thread_rng(); - /// ::Sampler::sample_single(lb, ub, &mut rng) - /// } - /// ``` - fn sample_single(low: B1, high: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let uniform: Self = UniformSampler::new(low, high); - uniform.sample(rng) - } - - /// Sample a single value uniformly from a range with inclusive lower bound - /// and inclusive upper bound `[low, high]`. - /// - /// By default this is implemented using - /// `UniformSampler::new_inclusive(low, high).sample(rng)`. However, for - /// some types more optimal implementations for single usage may be provided - /// via this method. - /// Results may not be identical. - fn sample_single_inclusive(low: B1, high: B2, rng: &mut R) - -> Self::X - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let uniform: Self = UniformSampler::new_inclusive(low, high); - uniform.sample(rng) - } -} - -impl From> for Uniform { - fn from(r: ::core::ops::Range) -> Uniform { - Uniform::new(r.start, r.end) - } -} - -impl From> for Uniform { - fn from(r: ::core::ops::RangeInclusive) -> Uniform { - Uniform::new_inclusive(r.start(), r.end()) - } -} - - -/// Helper trait similar to [`Borrow`] but implemented -/// only for SampleUniform and references to SampleUniform in -/// order to resolve ambiguity issues. -/// -/// [`Borrow`]: std::borrow::Borrow -pub trait SampleBorrow { - /// Immutably borrows from an owned value. See [`Borrow::borrow`] - /// - /// [`Borrow::borrow`]: std::borrow::Borrow::borrow - fn borrow(&self) -> &Borrowed; -} -impl SampleBorrow for Borrowed -where Borrowed: SampleUniform -{ - #[inline(always)] - fn borrow(&self) -> &Borrowed { - self - } -} -impl<'a, Borrowed> SampleBorrow for &'a Borrowed -where Borrowed: SampleUniform -{ - #[inline(always)] - fn borrow(&self) -> &Borrowed { - *self - } -} - -/// Range that supports generating a single sample efficiently. -/// -/// Any type implementing this trait can be used to specify the sampled range -/// for `Rng::gen_range`. -pub trait SampleRange { - /// Generate a sample from the given range. - fn sample_single(self, rng: &mut R) -> T; - - /// Check whether the range is empty. - fn is_empty(&self) -> bool; -} - -impl SampleRange for Range { - #[inline] - fn sample_single(self, rng: &mut R) -> T { - T::Sampler::sample_single(self.start, self.end, rng) - } - - #[inline] - fn is_empty(&self) -> bool { - !(self.start < self.end) - } -} - -impl SampleRange for RangeInclusive { - #[inline] - fn sample_single(self, rng: &mut R) -> T { - T::Sampler::sample_single_inclusive(self.start(), self.end(), rng) - } - - #[inline] - fn is_empty(&self) -> bool { - !(self.start() <= self.end()) - } -} - - -//////////////////////////////////////////////////////////////////////////////// - -// What follows are all back-ends. - - -/// The back-end implementing [`UniformSampler`] for integer types. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// # Implementation notes -/// -/// For simplicity, we use the same generic struct `UniformInt` for all -/// integer types `X`. This gives us only one field type, `X`; to store unsigned -/// values of this size, we take use the fact that these conversions are no-ops. -/// -/// For a closed range, the number of possible numbers we should generate is -/// `range = (high - low + 1)`. To avoid bias, we must ensure that the size of -/// our sample space, `zone`, is a multiple of `range`; other values must be -/// rejected (by replacing with a new random sample). -/// -/// As a special case, we use `range = 0` to represent the full range of the -/// result type (i.e. for `new_inclusive($ty::MIN, $ty::MAX)`). -/// -/// The optimum `zone` is the largest product of `range` which fits in our -/// (unsigned) target type. We calculate this by calculating how many numbers we -/// must reject: `reject = (MAX + 1) % range = (MAX - range + 1) % range`. Any (large) -/// product of `range` will suffice, thus in `sample_single` we multiply by a -/// power of 2 via bit-shifting (faster but may cause more rejections). -/// -/// The smallest integer PRNGs generate is `u32`. For 8- and 16-bit outputs we -/// use `u32` for our `zone` and samples (because it's not slower and because -/// it reduces the chance of having to reject a sample). In this case we cannot -/// store `zone` in the target type since it is too large, however we know -/// `ints_to_reject < range <= $unsigned::MAX`. -/// -/// An alternative to using a modulus is widening multiply: After a widening -/// multiply by `range`, the result is in the high word. Then comparing the low -/// word against `zone` makes sure our distribution is uniform. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformInt { - low: X, - range: X, - z: X, // either ints_to_reject or zone depending on implementation -} - -macro_rules! uniform_int_impl { - ($ty:ty, $unsigned:ident, $u_large:ident) => { - impl SampleUniform for $ty { - type Sampler = UniformInt<$ty>; - } - - impl UniformSampler for UniformInt<$ty> { - // We play free and fast with unsigned vs signed here - // (when $ty is signed), but that's fine, since the - // contract of this macro is for $ty and $unsigned to be - // "bit-equal", so casting between them is a no-op. - - type X = $ty; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low < high, "Uniform::new called with `low >= high`"); - UniformSampler::new_inclusive(low, high - 1) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!( - low <= high, - "Uniform::new_inclusive called with `low > high`" - ); - let unsigned_max = ::core::$u_large::MAX; - - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned; - let ints_to_reject = if range > 0 { - let range = $u_large::from(range); - (unsigned_max - range + 1) % range - } else { - 0 - }; - - UniformInt { - low, - // These are really $unsigned values, but store as $ty: - range: range as $ty, - z: ints_to_reject as $unsigned as $ty, - } - } - - #[inline] - fn sample(&self, rng: &mut R) -> Self::X { - let range = self.range as $unsigned as $u_large; - if range > 0 { - let unsigned_max = ::core::$u_large::MAX; - let zone = unsigned_max - (self.z as $unsigned as $u_large); - loop { - let v: $u_large = rng.gen(); - let (hi, lo) = v.wmul(range); - if lo <= zone { - return self.low.wrapping_add(hi as $ty); - } - } - } else { - // Sample from the entire integer range. - rng.gen() - } - } - - #[inline] - fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low < high, "UniformSampler::sample_single: low >= high"); - Self::sample_single_inclusive(low, high - 1, rng) - } - - #[inline] - fn sample_single_inclusive(low_b: B1, high_b: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low <= high, "UniformSampler::sample_single_inclusive: low > high"); - let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large; - let zone = if ::core::$unsigned::MAX <= ::core::u16::MAX as $unsigned { - // Using a modulus is faster than the approximation for - // i8 and i16. I suppose we trade the cost of one - // modulus for near-perfect branch prediction. - let unsigned_max: $u_large = ::core::$u_large::MAX; - let ints_to_reject = (unsigned_max - range + 1) % range; - unsigned_max - ints_to_reject - } else { - // conservative but fast approximation. `- 1` is necessary to allow the - // same comparison without bias. - (range << range.leading_zeros()).wrapping_sub(1) - }; - - loop { - let v: $u_large = rng.gen(); - let (hi, lo) = v.wmul(range); - if lo <= zone { - return low.wrapping_add(hi as $ty); - } - } - } - } - }; -} - -uniform_int_impl! { i8, u8, u32 } -uniform_int_impl! { i16, u16, u32 } -uniform_int_impl! { i32, u32, u32 } -uniform_int_impl! { i64, u64, u64 } -#[cfg(not(target_os = "emscripten"))] -uniform_int_impl! { i128, u128, u128 } -uniform_int_impl! { isize, usize, usize } -uniform_int_impl! { u8, u8, u32 } -uniform_int_impl! { u16, u16, u32 } -uniform_int_impl! { u32, u32, u32 } -uniform_int_impl! { u64, u64, u64 } -uniform_int_impl! { usize, usize, usize } -#[cfg(not(target_os = "emscripten"))] -uniform_int_impl! { u128, u128, u128 } - -#[cfg(feature = "simd_support")] -macro_rules! uniform_simd_int_impl { - ($ty:ident, $unsigned:ident, $u_scalar:ident) => { - // The "pick the largest zone that can fit in an `u32`" optimization - // is less useful here. Multiple lanes complicate things, we don't - // know the PRNG's minimal output size, and casting to a larger vector - // is generally a bad idea for SIMD performance. The user can still - // implement it manually. - - // TODO: look into `Uniform::::new(0u32, 100)` functionality - // perhaps `impl SampleUniform for $u_scalar`? - impl SampleUniform for $ty { - type Sampler = UniformInt<$ty>; - } - - impl UniformSampler for UniformInt<$ty> { - type X = $ty; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Self - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low.lt(high).all(), "Uniform::new called with `low >= high`"); - UniformSampler::new_inclusive(low, high - 1) - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low.le(high).all(), - "Uniform::new_inclusive called with `low > high`"); - let unsigned_max = ::core::$u_scalar::MAX; - - // NOTE: these may need to be replaced with explicitly - // wrapping operations if `packed_simd` changes - let range: $unsigned = ((high - low) + 1).cast(); - // `% 0` will panic at runtime. - let not_full_range = range.gt($unsigned::splat(0)); - // replacing 0 with `unsigned_max` allows a faster `select` - // with bitwise OR - let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max)); - // wrapping addition - let ints_to_reject = (unsigned_max - range + 1) % modulo; - // When `range` is 0, `lo` of `v.wmul(range)` will always be - // zero which means only one sample is needed. - let zone = unsigned_max - ints_to_reject; - - UniformInt { - low, - // These are really $unsigned values, but store as $ty: - range: range.cast(), - z: zone.cast(), - } - } - - fn sample(&self, rng: &mut R) -> Self::X { - let range: $unsigned = self.range.cast(); - let zone: $unsigned = self.z.cast(); - - // This might seem very slow, generating a whole new - // SIMD vector for every sample rejection. For most uses - // though, the chance of rejection is small and provides good - // general performance. With multiple lanes, that chance is - // multiplied. To mitigate this, we replace only the lanes of - // the vector which fail, iteratively reducing the chance of - // rejection. The replacement method does however add a little - // overhead. Benchmarking or calculating probabilities might - // reveal contexts where this replacement method is slower. - let mut v: $unsigned = rng.gen(); - loop { - let (hi, lo) = v.wmul(range); - let mask = lo.le(zone); - if mask.all() { - let hi: $ty = hi.cast(); - // wrapping addition - let result = self.low + hi; - // `select` here compiles to a blend operation - // When `range.eq(0).none()` the compare and blend - // operations are avoided. - let v: $ty = v.cast(); - return range.gt($unsigned::splat(0)).select(result, v); - } - // Replace only the failing lanes - v = mask.select(v, rng.gen()); - } - } - } - }; - - // bulk implementation - ($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => { - $( - uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar); - uniform_simd_int_impl!($signed, $unsigned, $u_scalar); - )+ - }; -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u64x2, i64x2), - (u64x4, i64x4), - (u64x8, i64x8), - u64 -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u32x2, i32x2), - (u32x4, i32x4), - (u32x8, i32x8), - (u32x16, i32x16), - u32 -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u16x2, i16x2), - (u16x4, i16x4), - (u16x8, i16x8), - (u16x16, i16x16), - (u16x32, i16x32), - u16 -} - -#[cfg(feature = "simd_support")] -uniform_simd_int_impl! { - (u8x2, i8x2), - (u8x4, i8x4), - (u8x8, i8x8), - (u8x16, i8x16), - (u8x32, i8x32), - (u8x64, i8x64), - u8 -} - -impl SampleUniform for char { - type Sampler = UniformChar; -} - -/// The back-end implementing [`UniformSampler`] for `char`. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// This differs from integer range sampling since the range `0xD800..=0xDFFF` -/// are used for surrogate pairs in UCS and UTF-16, and consequently are not -/// valid Unicode code points. We must therefore avoid sampling values in this -/// range. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformChar { - sampler: UniformInt, -} - -/// UTF-16 surrogate range start -const CHAR_SURROGATE_START: u32 = 0xD800; -/// UTF-16 surrogate range size -const CHAR_SURROGATE_LEN: u32 = 0xE000 - CHAR_SURROGATE_START; - -/// Convert `char` to compressed `u32` -fn char_to_comp_u32(c: char) -> u32 { - match c as u32 { - c if c >= CHAR_SURROGATE_START => c - CHAR_SURROGATE_LEN, - c => c, - } -} - -impl UniformSampler for UniformChar { - type X = char; - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = char_to_comp_u32(*low_b.borrow()); - let high = char_to_comp_u32(*high_b.borrow()); - let sampler = UniformInt::::new(low, high); - UniformChar { sampler } - } - - #[inline] // if the range is constant, this helps LLVM to do the - // calculations at compile-time. - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = char_to_comp_u32(*low_b.borrow()); - let high = char_to_comp_u32(*high_b.borrow()); - let sampler = UniformInt::::new_inclusive(low, high); - UniformChar { sampler } - } - - fn sample(&self, rng: &mut R) -> Self::X { - let mut x = self.sampler.sample(rng); - if x >= CHAR_SURROGATE_START { - x += CHAR_SURROGATE_LEN; - } - // SAFETY: x must not be in surrogate range or greater than char::MAX. - // This relies on range constructors which accept char arguments. - // Validity of input char values is assumed. - unsafe { core::char::from_u32_unchecked(x) } - } -} - -/// The back-end implementing [`UniformSampler`] for floating-point types. -/// -/// Unless you are implementing [`UniformSampler`] for your own type, this type -/// should not be used directly, use [`Uniform`] instead. -/// -/// # Implementation notes -/// -/// Instead of generating a float in the `[0, 1)` range using [`Standard`], the -/// `UniformFloat` implementation converts the output of an PRNG itself. This -/// way one or two steps can be optimized out. -/// -/// The floats are first converted to a value in the `[1, 2)` interval using a -/// transmute-based method, and then mapped to the expected range with a -/// multiply and addition. Values produced this way have what equals 23 bits of -/// random digits for an `f32`, and 52 for an `f64`. -/// -/// [`new`]: UniformSampler::new -/// [`new_inclusive`]: UniformSampler::new_inclusive -/// [`Standard`]: crate::distributions::Standard -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformFloat { - low: X, - scale: X, -} - -macro_rules! uniform_float_impl { - ($ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { - impl SampleUniform for $ty { - type Sampler = UniformFloat<$ty>; - } - - impl UniformSampler for UniformFloat<$ty> { - type X = $ty; - - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low.all_lt(high), "Uniform::new called with `low >= high`"); - assert!( - low.all_finite() && high.all_finite(), - "Uniform::new called with non-finite boundaries" - ); - let max_rand = <$ty>::splat( - (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - - let mut scale = high - low; - - loop { - let mask = (scale * max_rand + low).ge_mask(high); - if mask.none() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - UniformFloat { low, scale } - } - - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!( - low.all_le(high), - "Uniform::new_inclusive called with `low > high`" - ); - assert!( - low.all_finite() && high.all_finite(), - "Uniform::new_inclusive called with non-finite boundaries" - ); - let max_rand = <$ty>::splat( - (::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); - - let mut scale = (high - low) / max_rand; - - loop { - let mask = (scale * max_rand + low).gt_mask(high); - if mask.none() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - UniformFloat { low, scale } - } - - fn sample(&self, rng: &mut R) -> Self::X { - // Generate a value in the range [1, 2) - let value1_2 = (rng.gen::<$uty>() >> $bits_to_discard).into_float_with_exponent(0); - - // Get a value in the range [0, 1) in order to avoid - // overflowing into infinity when multiplying with scale - let value0_1 = value1_2 - 1.0; - - // We don't use `f64::mul_add`, because it is not available with - // `no_std`. Furthermore, it is slower for some targets (but - // faster for others). However, the order of multiplication and - // addition is important, because on some platforms (e.g. ARM) - // it will be optimized to a single (non-FMA) instruction. - value0_1 * self.scale + self.low - } - - #[inline] - fn sample_single(low_b: B1, high_b: B2, rng: &mut R) -> Self::X - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!( - low.all_lt(high), - "UniformSampler::sample_single: low >= high" - ); - let mut scale = high - low; - - loop { - // Generate a value in the range [1, 2) - let value1_2 = - (rng.gen::<$uty>() >> $bits_to_discard).into_float_with_exponent(0); - - // Get a value in the range [0, 1) in order to avoid - // overflowing into infinity when multiplying with scale - let value0_1 = value1_2 - 1.0; - - // Doing multiply before addition allows some architectures - // to use a single instruction. - let res = value0_1 * scale + low; - - debug_assert!(low.all_le(res) || !scale.all_finite()); - if res.all_lt(high) { - return res; - } - - // This handles a number of edge cases. - // * `low` or `high` is NaN. In this case `scale` and - // `res` are going to end up as NaN. - // * `low` is negative infinity and `high` is finite. - // `scale` is going to be infinite and `res` will be - // NaN. - // * `high` is positive infinity and `low` is finite. - // `scale` is going to be infinite and `res` will - // be infinite or NaN (if value0_1 is 0). - // * `low` is negative infinity and `high` is positive - // infinity. `scale` will be infinite and `res` will - // be NaN. - // * `low` and `high` are finite, but `high - low` - // overflows to infinite. `scale` will be infinite - // and `res` will be infinite or NaN (if value0_1 is 0). - // So if `high` or `low` are non-finite, we are guaranteed - // to fail the `res < high` check above and end up here. - // - // While we technically should check for non-finite `low` - // and `high` before entering the loop, by doing the checks - // here instead, we allow the common case to avoid these - // checks. But we are still guaranteed that if `low` or - // `high` are non-finite we'll end up here and can do the - // appropriate checks. - // - // Likewise `high - low` overflowing to infinity is also - // rare, so handle it here after the common case. - let mask = !scale.finite_mask(); - if mask.any() { - assert!( - low.all_finite() && high.all_finite(), - "Uniform::sample_single: low and high must be finite" - ); - scale = scale.decrease_masked(mask); - } - } - } - } - }; -} - -uniform_float_impl! { f32, u32, f32, u32, 32 - 23 } -uniform_float_impl! { f64, u64, f64, u64, 64 - 52 } - -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x2, u32x2, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x4, u32x4, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x8, u32x8, f32, u32, 32 - 23 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f32x16, u32x16, f32, u32, 32 - 23 } - -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x2, u64x2, f64, u64, 64 - 52 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x4, u64x4, f64, u64, 64 - 52 } -#[cfg(feature = "simd_support")] -uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 } - - -/// The back-end implementing [`UniformSampler`] for `Duration`. -/// -/// Unless you are implementing [`UniformSampler`] for your own types, this type -/// should not be used directly, use [`Uniform`] instead. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct UniformDuration { - mode: UniformDurationMode, - offset: u32, -} - -#[derive(Debug, Copy, Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum UniformDurationMode { - Small { - secs: u64, - nanos: Uniform, - }, - Medium { - nanos: Uniform, - }, - Large { - max_secs: u64, - max_nanos: u32, - secs: Uniform, - }, -} - -impl SampleUniform for Duration { - type Sampler = UniformDuration; -} - -impl UniformSampler for UniformDuration { - type X = Duration; - - #[inline] - fn new(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!(low < high, "Uniform::new called with `low >= high`"); - UniformDuration::new_inclusive(low, high - Duration::new(0, 1)) - } - - #[inline] - fn new_inclusive(low_b: B1, high_b: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - assert!( - low <= high, - "Uniform::new_inclusive called with `low > high`" - ); - - let low_s = low.as_secs(); - let low_n = low.subsec_nanos(); - let mut high_s = high.as_secs(); - let mut high_n = high.subsec_nanos(); - - if high_n < low_n { - high_s -= 1; - high_n += 1_000_000_000; - } - - let mode = if low_s == high_s { - UniformDurationMode::Small { - secs: low_s, - nanos: Uniform::new_inclusive(low_n, high_n), - } - } else { - let max = high_s - .checked_mul(1_000_000_000) - .and_then(|n| n.checked_add(u64::from(high_n))); - - if let Some(higher_bound) = max { - let lower_bound = low_s * 1_000_000_000 + u64::from(low_n); - UniformDurationMode::Medium { - nanos: Uniform::new_inclusive(lower_bound, higher_bound), - } - } else { - // An offset is applied to simplify generation of nanoseconds - let max_nanos = high_n - low_n; - UniformDurationMode::Large { - max_secs: high_s, - max_nanos, - secs: Uniform::new_inclusive(low_s, high_s), - } - } - }; - UniformDuration { - mode, - offset: low_n, - } - } - - #[inline] - fn sample(&self, rng: &mut R) -> Duration { - match self.mode { - UniformDurationMode::Small { secs, nanos } => { - let n = nanos.sample(rng); - Duration::new(secs, n) - } - UniformDurationMode::Medium { nanos } => { - let nanos = nanos.sample(rng); - Duration::new(nanos / 1_000_000_000, (nanos % 1_000_000_000) as u32) - } - UniformDurationMode::Large { - max_secs, - max_nanos, - secs, - } => { - // constant folding means this is at least as fast as `Rng::sample(Range)` - let nano_range = Uniform::new(0, 1_000_000_000); - loop { - let s = secs.sample(rng); - let n = nano_range.sample(rng); - if !(s == max_secs && n > max_nanos) { - let sum = n + self.offset; - break Duration::new(s, sum); - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::rngs::mock::StepRng; - - #[test] - #[cfg(feature = "serde1")] - fn test_serialization_uniform_duration() { - let distr = UniformDuration::new(std::time::Duration::from_secs(10), std::time::Duration::from_secs(60)); - let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); - assert_eq!( - distr.offset, de_distr.offset - ); - match (distr.mode, de_distr.mode) { - (UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => { - assert_eq!(a_secs, secs); - - assert_eq!(a_nanos.0.low, nanos.0.low); - assert_eq!(a_nanos.0.range, nanos.0.range); - assert_eq!(a_nanos.0.z, nanos.0.z); - } - (UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => { - assert_eq!(a_nanos.0.low, nanos.0.low); - assert_eq!(a_nanos.0.range, nanos.0.range); - assert_eq!(a_nanos.0.z, nanos.0.z); - } - (UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => { - assert_eq!(a_max_secs, max_secs); - assert_eq!(a_max_nanos, max_nanos); - - assert_eq!(a_secs.0.low, secs.0.low); - assert_eq!(a_secs.0.range, secs.0.range); - assert_eq!(a_secs.0.z, secs.0.z); - } - _ => panic!("`UniformDurationMode` was not serialized/deserialized correctly") - } - } - - #[test] - #[cfg(feature = "serde1")] - fn test_uniform_serialization() { - let unit_box: Uniform = Uniform::new(-1, 1); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - - assert_eq!(unit_box.0.low, de_unit_box.0.low); - assert_eq!(unit_box.0.range, de_unit_box.0.range); - assert_eq!(unit_box.0.z, de_unit_box.0.z); - - let unit_box: Uniform = Uniform::new(-1., 1.); - let de_unit_box: Uniform = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); - - assert_eq!(unit_box.0.low, de_unit_box.0.low); - assert_eq!(unit_box.0.scale, de_unit_box.0.scale); - } - - #[should_panic] - #[test] - fn test_uniform_bad_limits_equal_int() { - Uniform::new(10, 10); - } - - #[test] - fn test_uniform_good_limits_equal_int() { - let mut rng = crate::test::rng(804); - let dist = Uniform::new_inclusive(10, 10); - for _ in 0..20 { - assert_eq!(rng.sample(dist), 10); - } - } - - #[should_panic] - #[test] - fn test_uniform_bad_limits_flipped_int() { - Uniform::new(10, 5); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_integers() { - #[cfg(not(target_os = "emscripten"))] use core::{i128, u128}; - use core::{i16, i32, i64, i8, isize}; - use core::{u16, u32, u64, u8, usize}; - - let mut rng = crate::test::rng(251); - macro_rules! t { - ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ - for &(low, high) in $v.iter() { - let my_uniform = Uniform::new(low, high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $lt(v, high)); - } - - let my_uniform = Uniform::new_inclusive(low, high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $le(v, high)); - } - - let my_uniform = Uniform::new(&low, high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $lt(v, high)); - } - - let my_uniform = Uniform::new_inclusive(&low, &high); - for _ in 0..1000 { - let v: $ty = rng.sample(my_uniform); - assert!($le(low, v) && $le(v, high)); - } - - for _ in 0..1000 { - let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng); - assert!($le(low, v) && $lt(v, high)); - } - } - }}; - - // scalar bulk - ($($ty:ident),*) => {{ - $(t!( - $ty, - [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)], - |x, y| x <= y, - |x, y| x < y - );)* - }}; - - // simd bulk - ($($ty:ident),* => $scalar:ident) => {{ - $(t!( - $ty, - [ - ($ty::splat(0), $ty::splat(10)), - ($ty::splat(10), $ty::splat(127)), - ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)), - ], - |x: $ty, y| x.le(y).all(), - |x: $ty, y| x.lt(y).all() - );)* - }}; - } - t!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize); - #[cfg(not(target_os = "emscripten"))] - t!(i128, u128); - - #[cfg(feature = "simd_support")] - { - t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8); - t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8); - t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16); - t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16); - t!(u32x2, u32x4, u32x8, u32x16 => u32); - t!(i32x2, i32x4, i32x8, i32x16 => i32); - t!(u64x2, u64x4, u64x8 => u64); - t!(i64x2, i64x4, i64x8 => i64); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_char() { - let mut rng = crate::test::rng(891); - let mut max = core::char::from_u32(0).unwrap(); - for _ in 0..100 { - let c = rng.gen_range('A'..='Z'); - assert!('A' <= c && c <= 'Z'); - max = max.max(c); - } - assert_eq!(max, 'Z'); - let d = Uniform::new( - core::char::from_u32(0xD7F0).unwrap(), - core::char::from_u32(0xE010).unwrap(), - ); - for _ in 0..100 { - let c = d.sample(&mut rng); - assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_floats() { - let mut rng = crate::test::rng(252); - let mut zero_rng = StepRng::new(0, 0); - let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); - macro_rules! t { - ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ - let v: &[($f_scalar, $f_scalar)] = &[ - (0.0, 100.0), - (-1e35, -1e25), - (1e-35, 1e-25), - (-1e35, 1e35), - (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)), - (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), - (-<$f_scalar>::from_bits(5), 0.0), - (-<$f_scalar>::from_bits(7), -0.0), - (10.0, ::core::$f_scalar::MAX), - (-100.0, ::core::$f_scalar::MAX), - (-::core::$f_scalar::MAX / 5.0, ::core::$f_scalar::MAX), - (-::core::$f_scalar::MAX, ::core::$f_scalar::MAX / 5.0), - (-::core::$f_scalar::MAX * 0.8, ::core::$f_scalar::MAX * 0.7), - (-::core::$f_scalar::MAX, ::core::$f_scalar::MAX), - ]; - for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::lanes() { - let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); - let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - let my_uniform = Uniform::new(low, high); - let my_incl_uniform = Uniform::new_inclusive(low, high); - for _ in 0..100 { - let v = rng.sample(my_uniform).extract(lane); - assert!(low_scalar <= v && v < high_scalar); - let v = rng.sample(my_incl_uniform).extract(lane); - assert!(low_scalar <= v && v <= high_scalar); - let v = <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut rng).extract(lane); - assert!(low_scalar <= v && v < high_scalar); - } - - assert_eq!( - rng.sample(Uniform::new_inclusive(low, low)).extract(lane), - low_scalar - ); - - assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); - assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); - assert_eq!(<$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut zero_rng) - .extract(lane), low_scalar); - assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); - assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); - - // Don't run this test for really tiny differences between high and low - // since for those rounding might result in selecting high for a very - // long time. - if (high_scalar - low_scalar) > 0.0001 { - let mut lowering_max_rng = StepRng::new( - 0xffff_ffff_ffff_ffff, - (-1i64 << $bits_shifted) as u64, - ); - assert!( - <$ty as SampleUniform>::Sampler - ::sample_single(low, high, &mut lowering_max_rng) - .extract(lane) < high_scalar - ); - } - } - } - - assert_eq!( - rng.sample(Uniform::new_inclusive( - ::core::$f_scalar::MAX, - ::core::$f_scalar::MAX - )), - ::core::$f_scalar::MAX - ); - assert_eq!( - rng.sample(Uniform::new_inclusive( - -::core::$f_scalar::MAX, - -::core::$f_scalar::MAX - )), - -::core::$f_scalar::MAX - ); - }}; - } - - t!(f32, f32, 32 - 23); - t!(f64, f64, 64 - 52); - #[cfg(feature = "simd_support")] - { - t!(f32x2, f32, 32 - 23); - t!(f32x4, f32, 32 - 23); - t!(f32x8, f32, 32 - 23); - t!(f32x16, f32, 32 - 23); - t!(f64x2, f64, 64 - 52); - t!(f64x4, f64, 64 - 52); - t!(f64x8, f64, 64 - 52); - } - } - - #[test] - #[cfg(all( - feature = "std", - not(target_arch = "wasm32"), - not(target_arch = "asmjs") - ))] - fn test_float_assertions() { - use super::SampleUniform; - use std::panic::catch_unwind; - fn range(low: T, high: T) { - let mut rng = crate::test::rng(253); - T::Sampler::sample_single(low, high, &mut rng); - } - - macro_rules! t { - ($ty:ident, $f_scalar:ident) => {{ - let v: &[($f_scalar, $f_scalar)] = &[ - (::std::$f_scalar::NAN, 0.0), - (1.0, ::std::$f_scalar::NAN), - (::std::$f_scalar::NAN, ::std::$f_scalar::NAN), - (1.0, 0.5), - (::std::$f_scalar::MAX, -::std::$f_scalar::MAX), - (::std::$f_scalar::INFINITY, ::std::$f_scalar::INFINITY), - ( - ::std::$f_scalar::NEG_INFINITY, - ::std::$f_scalar::NEG_INFINITY, - ), - (::std::$f_scalar::NEG_INFINITY, 5.0), - (5.0, ::std::$f_scalar::INFINITY), - (::std::$f_scalar::NAN, ::std::$f_scalar::INFINITY), - (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::NAN), - (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY), - ]; - for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::lanes() { - let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); - let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - assert!(catch_unwind(|| range(low, high)).is_err()); - assert!(catch_unwind(|| Uniform::new(low, high)).is_err()); - assert!(catch_unwind(|| Uniform::new_inclusive(low, high)).is_err()); - assert!(catch_unwind(|| range(low, low)).is_err()); - assert!(catch_unwind(|| Uniform::new(low, low)).is_err()); - } - } - }}; - } - - t!(f32, f32); - t!(f64, f64); - #[cfg(feature = "simd_support")] - { - t!(f32x2, f32); - t!(f32x4, f32); - t!(f32x8, f32); - t!(f32x16, f32); - t!(f64x2, f64); - t!(f64x4, f64); - t!(f64x8, f64); - } - } - - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_durations() { - #[cfg(not(feature = "std"))] use core::time::Duration; - #[cfg(feature = "std")] use std::time::Duration; - - let mut rng = crate::test::rng(253); - - let v = &[ - (Duration::new(10, 50000), Duration::new(100, 1234)), - (Duration::new(0, 100), Duration::new(1, 50)), - ( - Duration::new(0, 0), - Duration::new(u64::max_value(), 999_999_999), - ), - ]; - for &(low, high) in v.iter() { - let my_uniform = Uniform::new(low, high); - for _ in 0..1000 { - let v = rng.sample(my_uniform); - assert!(low <= v && v < high); - } - } - } - - #[test] - fn test_custom_uniform() { - use crate::distributions::uniform::{ - SampleBorrow, SampleUniform, UniformFloat, UniformSampler, - }; - #[derive(Clone, Copy, PartialEq, PartialOrd)] - struct MyF32 { - x: f32, - } - #[derive(Clone, Copy, Debug)] - struct UniformMyF32(UniformFloat); - impl UniformSampler for UniformMyF32 { - type X = MyF32; - - fn new(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - UniformMyF32(UniformFloat::::new(low.borrow().x, high.borrow().x)) - } - - fn new_inclusive(low: B1, high: B2) -> Self - where - B1: SampleBorrow + Sized, - B2: SampleBorrow + Sized, - { - UniformSampler::new(low, high) - } - - fn sample(&self, rng: &mut R) -> Self::X { - MyF32 { - x: self.0.sample(rng), - } - } - } - impl SampleUniform for MyF32 { - type Sampler = UniformMyF32; - } - - let (low, high) = (MyF32 { x: 17.0f32 }, MyF32 { x: 22.0f32 }); - let uniform = Uniform::new(low, high); - let mut rng = crate::test::rng(804); - for _ in 0..100 { - let x: MyF32 = rng.sample(uniform); - assert!(low <= x && x < high); - } - } - - #[test] - fn test_uniform_from_std_range() { - let r = Uniform::from(2u32..7); - assert_eq!(r.0.low, 2); - assert_eq!(r.0.range, 5); - let r = Uniform::from(2.0f64..7.0); - assert_eq!(r.0.low, 2.0); - assert_eq!(r.0.scale, 5.0); - } - - #[test] - fn test_uniform_from_std_range_inclusive() { - let r = Uniform::from(2u32..=6); - assert_eq!(r.0.low, 2); - assert_eq!(r.0.range, 5); - let r = Uniform::from(2.0f64..=7.0); - assert_eq!(r.0.low, 2.0); - assert!(r.0.scale > 5.0); - assert!(r.0.scale < 5.0 + 1e-14); - } - - #[test] - fn value_stability() { - fn test_samples( - lb: T, ub: T, expected_single: &[T], expected_multiple: &[T], - ) where Uniform: Distribution { - let mut rng = crate::test::rng(897); - let mut buf = [lb; 3]; - - for x in &mut buf { - *x = T::Sampler::sample_single(lb, ub, &mut rng); - } - assert_eq!(&buf, expected_single); - - let distr = Uniform::new(lb, ub); - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(&buf, expected_multiple); - } - - // We test on a sub-set of types; possibly we should do more. - // TODO: SIMD types - - test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]); - test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]); - - test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[ - 0.008194133, - 0.00398172, - 0.007428536, - ]); - test_samples( - -1e10f64, - 1e10f64, - &[-4673848682.871551, 6388267422.932352, 4857075081.198343], - &[1173375212.1808167, 1917642852.109581, 2365076174.3153973], - ); - - test_samples( - Duration::new(2, 0), - Duration::new(4, 0), - &[ - Duration::new(2, 532615131), - Duration::new(3, 638826742), - Duration::new(3, 485707508), - ], - &[ - Duration::new(3, 117337521), - Duration::new(3, 191764285), - Duration::new(3, 236507617), - ], - ); - } -} diff --git a/src/distributions/weighted.rs b/src/distributions/weighted.rs deleted file mode 100644 index 6dd9273a506..00000000000 --- a/src/distributions/weighted.rs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted index sampling -//! -//! This module is deprecated. Use [`crate::distributions::WeightedIndex`] and -//! [`crate::distributions::WeightedError`] instead. - -pub use super::{WeightedIndex, WeightedError}; - -#[allow(missing_docs)] -#[deprecated(since = "0.8.0", note = "moved to rand_distr crate")] -pub mod alias_method { - // This module exists to provide a deprecation warning which minimises - // compile errors, but still fails to compile if ever used. - use core::marker::PhantomData; - use alloc::vec::Vec; - use super::WeightedError; - - #[derive(Debug)] - pub struct WeightedIndex { - _phantom: PhantomData, - } - impl WeightedIndex { - pub fn new(_weights: Vec) -> Result { - Err(WeightedError::NoItem) - } - } - - pub trait Weight {} - macro_rules! impl_weight { - () => {}; - ($T:ident, $($more:ident,)*) => { - impl Weight for $T {} - impl_weight!($($more,)*); - }; - } - impl_weight!(f64, f32,); - impl_weight!(u8, u16, u32, u64, usize,); - impl_weight!(i8, i16, i32, i64, isize,); - #[cfg(not(target_os = "emscripten"))] - impl_weight!(u128, i128,); -} diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs deleted file mode 100644 index 07ba53ec027..00000000000 --- a/src/distributions/weighted_index.rs +++ /dev/null @@ -1,453 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted index sampling - -use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; -use crate::distributions::Distribution; -use crate::Rng; -use core::cmp::PartialOrd; -use core::fmt; - -// Note that this whole module is only imported if feature="alloc" is enabled. -use alloc::vec::Vec; - -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; - -/// A distribution using weighted sampling of discrete items -/// -/// Sampling a `WeightedIndex` distribution returns the index of a randomly -/// selected element from the iterator used when the `WeightedIndex` was -/// created. The chance of a given element being picked is proportional to the -/// value of the element. The weights can use any type `X` for which an -/// implementation of [`Uniform`] exists. -/// -/// # Performance -/// -/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. As an alternative, -/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost. -/// -/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its -/// size is the sum of the size of those objects, possibly plus some alignment. -/// -/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` -/// weights of type `X`, where `N` is the number of weights. However, since -/// `Vec` doesn't guarantee a particular growth strategy, additional memory -/// might be allocated but not used. Since the `WeightedIndex` object also -/// contains, this might cause additional allocations, though for primitive -/// types, [`Uniform`] doesn't allocate any memory. -/// -/// Sampling from `WeightedIndex` will result in a single call to -/// `Uniform::sample` (method of the [`Distribution`] trait), which typically -/// will request a single value from the underlying [`RngCore`], though the -/// exact number depends on the implementation of `Uniform::sample`. -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; -/// -/// let choices = ['a', 'b', 'c']; -/// let weights = [2, 1, 1]; -/// let dist = WeightedIndex::new(&weights).unwrap(); -/// let mut rng = thread_rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`Uniform`]: crate::distributions::Uniform -/// [`RngCore`]: crate::RngCore -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub struct WeightedIndex { - cumulative_weights: Vec, - total_weight: X, - weight_distribution: X::Sampler, -} - -impl WeightedIndex { - /// Creates a new a `WeightedIndex` [`Distribution`] using the values - /// in `weights`. The weights can use any type `X` for which an - /// implementation of [`Uniform`] exists. - /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. - /// - /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightedError> - where - I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, - { - let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); - - let zero = ::default(); - if !(total_weight >= zero) { - return Err(WeightedError::InvalidWeight); - } - - let mut weights = Vec::::with_capacity(iter.size_hint().0); - for w in iter { - // Note that `!(w >= x)` is not equivalent to `w < x` for partially - // ordered types due to NaNs which are equal to nothing. - if !(w.borrow() >= &zero) { - return Err(WeightedError::InvalidWeight); - } - weights.push(total_weight.clone()); - total_weight += w.borrow(); - } - - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); - } - let distr = X::Sampler::new(zero, total_weight.clone()); - - Ok(WeightedIndex { - cumulative_weights: weights, - total_weight, - weight_distribution: distr, - }) - } - - /// Update a subset of weights, without changing the number of weights. - /// - /// `new_weights` must be sorted by the index. - /// - /// Using this method instead of `new` might be more efficient if only a small number of - /// weights is modified. No allocations are performed, unless the weight type `X` uses - /// allocation internally. - /// - /// In case of error, `self` is not modified. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> - + for<'a> ::core::ops::SubAssign<&'a X> - + Clone - + Default { - if new_weights.is_empty() { - return Ok(()); - } - - let zero = ::default(); - - let mut total_weight = self.total_weight.clone(); - - // Check for errors first, so we don't modify `self` in case something - // goes wrong. - let mut prev_i = None; - for &(i, w) in new_weights { - if let Some(old_i) = prev_i { - if old_i >= i { - return Err(WeightedError::InvalidWeight); - } - } - if !(*w >= zero) { - return Err(WeightedError::InvalidWeight); - } - if i > self.cumulative_weights.len() { - return Err(WeightedError::TooMany); - } - - let mut old_w = if i < self.cumulative_weights.len() { - self.cumulative_weights[i].clone() - } else { - self.total_weight.clone() - }; - if i > 0 { - old_w -= &self.cumulative_weights[i - 1]; - } - - total_weight -= &old_w; - total_weight += w; - prev_i = Some(i); - } - if total_weight <= zero { - return Err(WeightedError::AllWeightsZero); - } - - // Update the weights. Because we checked all the preconditions in the - // previous loop, this should never panic. - let mut iter = new_weights.iter(); - - let mut prev_weight = zero.clone(); - let mut next_new_weight = iter.next(); - let &(first_new_index, _) = next_new_weight.unwrap(); - let mut cumulative_weight = if first_new_index > 0 { - self.cumulative_weights[first_new_index - 1].clone() - } else { - zero.clone() - }; - for i in first_new_index..self.cumulative_weights.len() { - match next_new_weight { - Some(&(j, w)) if i == j => { - cumulative_weight += w; - next_new_weight = iter.next(); - } - _ => { - let mut tmp = self.cumulative_weights[i].clone(); - tmp -= &prev_weight; // We know this is positive. - cumulative_weight += &tmp; - } - } - prev_weight = cumulative_weight.clone(); - core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); - } - - self.total_weight = total_weight; - self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()); - - Ok(()) - } -} - -impl Distribution for WeightedIndex -where X: SampleUniform + PartialOrd -{ - fn sample(&self, rng: &mut R) -> usize { - use ::core::cmp::Ordering; - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights - .binary_search_by(|w| { - if *w <= chosen_weight { - Ordering::Less - } else { - Ordering::Greater - } - }) - .unwrap_err() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[cfg(feature = "serde1")] - #[test] - fn test_weightedindex_serde1() { - let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); - - let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); - let de_weighted_index: WeightedIndex = - bincode::deserialize(&ser_weighted_index).unwrap(); - - assert_eq!( - de_weighted_index.cumulative_weights, - weighted_index.cumulative_weights - ); - assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); - } - - #[test] - fn test_accepting_nan(){ - assert_eq!( - WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), - WeightedError::InvalidWeight, - ); - assert_eq!( - WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, - ); - assert_eq!( - WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, - ); - - assert_eq!( - WeightedIndex::new(&[0.5, 7.0]) - .unwrap() - .update_weights(&[(0, &core::f32::NAN)]) - .unwrap_err(), - WeightedError::InvalidWeight, - ) - } - - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weightedindex() { - let mut r = crate::test::rng(700); - const N_REPS: u32 = 5000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // WeightedIndex from vec - let mut chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from slice - chosen = [0i32; 14]; - let distr = WeightedIndex::new(&weights[..]).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from iterator - chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.iter()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); - assert_eq!( - WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) - .unwrap() - .sample(&mut r), - 4 - ); - } - - assert_eq!( - WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightedError::NoItem - ); - assert_eq!( - WeightedIndex::new(&[0]).unwrap_err(), - WeightedError::AllWeightsZero - ); - assert_eq!( - WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10]).unwrap_err(), - WeightedError::InvalidWeight - ); - } - - #[test] - fn test_update_weights() { - let data = [ - ( - &[10u32, 2, 3, 4][..], - &[(1, &100), (2, &4)][..], // positive change - &[10, 100, 4, 4][..], - ), - ( - &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], - &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element - &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], - ), - ]; - - for (weights, update, expected_weights) in data.iter() { - let total_weight = weights.iter().sum::(); - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, total_weight); - - distr.update_weights(update).unwrap(); - let expected_total_weight = expected_weights.iter().sum::(); - let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, expected_total_weight); - assert_eq!(distr.total_weight, expected_distr.total_weight); - assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); - } - } - - #[test] - fn value_stability() { - fn test_samples( - weights: I, buf: &mut [usize], expected: &[usize], - ) where - I: IntoIterator, - I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, - { - assert_eq!(buf.len(), expected.len()); - let distr = WeightedIndex::new(weights).unwrap(); - let mut rng = crate::test::rng(701); - for r in buf.iter_mut() { - *r = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - let mut buf = [0; 10]; - test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, - ]); - test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, - ]); - test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, - ]); - } -} - -/// Error type returned from `WeightedIndex::new`. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided weight collection contains no items. - NoItem, - - /// A weight is either less than zero, greater than the supported maximum, - /// NaN, or otherwise invalid. - InvalidWeight, - - /// All items in the provided weight collection are zero. - AllWeightsZero, - - /// Too many weights are provided (length greater than `u32::MAX`) - TooMany, -} - -#[cfg(feature = "std")] -impl ::std::error::Error for WeightedError {} - -impl fmt::Display for WeightedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - WeightedError::NoItem => write!(f, "No weights provided."), - WeightedError::InvalidWeight => write!(f, "A weight is invalid."), - WeightedError::AllWeightsZero => write!(f, "All weights are zero."), - WeightedError::TooMany => write!(f, "Too many weights (hit u32::MAX)"), - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 8bf7a9df126..e1a9ef4ddc1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,25 +14,24 @@ //! //! # Quick Start //! -//! To get you started quickly, the easiest and highest-level way to get -//! a random value is to use [`random()`]; alternatively you can use -//! [`thread_rng()`]. The [`Rng`] trait provides a useful API on all RNGs, while -//! the [`distributions`] and [`seq`] modules provide further -//! functionality on top of RNGs. -//! //! ``` +//! // The prelude import enables methods we use below, specifically +//! // Rng::random, Rng::sample, SliceRandom::shuffle and IndexedRandom::choose. //! use rand::prelude::*; //! -//! if rand::random() { // generates a boolean -//! // Try printing a random unicode code point (probably a bad idea)! -//! println!("char: {}", rand::random::()); -//! } +//! // Get an RNG: +//! let mut rng = rand::rng(); //! -//! let mut rng = rand::thread_rng(); -//! let y: f64 = rng.gen(); // generates a float between 0 and 1 +//! // Try printing a random unicode code point (probably a bad idea)! +//! println!("char: '{}'", rng.random::()); +//! // Try printing a random alphanumeric value instead! +//! println!("alpha: '{}'", rng.sample(rand::distr::Alphanumeric) as char); //! +//! // Generate and shuffle a sequence: //! let mut nums: Vec = (1..100).collect(); //! nums.shuffle(&mut rng); +//! // And take a random pick (yes, we didn't need to shuffle first!): +//! let _ = nums.choose(&mut rng); //! ``` //! //! # The Book @@ -49,16 +48,22 @@ #![deny(missing_debug_implementations)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] #![no_std] -#![cfg_attr(feature = "simd_support", feature(stdsimd))] -#![cfg_attr(feature = "nightly", feature(slice_partition_at_index))] -#![cfg_attr(doc_cfg, feature(doc_cfg))] +#![cfg_attr(feature = "simd_support", feature(portable_simd))] +#![cfg_attr( + all(feature = "simd_support", target_feature = "avx512bw"), + feature(stdarch_x86_avx512) +)] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] #![allow( clippy::float_cmp, clippy::neg_cmp_op_on_partial_ord, + clippy::nonminimal_bool )] -#[cfg(feature = "std")] extern crate std; -#[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(feature = "std")] +extern crate std; #[allow(unused)] macro_rules! trace { ($($x:tt)*) => ( @@ -92,55 +97,40 @@ macro_rules! error { ($($x:tt)*) => ( ) } // Re-exports from rand_core -pub use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +pub use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; // Public modules -pub mod distributions; +pub mod distr; pub mod prelude; mod rng; pub mod rngs; pub mod seq; // Public exports -#[cfg(all(feature = "std", feature = "std_rng"))] -pub use crate::rngs::thread::thread_rng; +#[cfg(feature = "thread_rng")] +pub use crate::rngs::thread::rng; + +/// Access the thread-local generator +/// +/// Use [`rand::rng()`](rng()) instead. +#[cfg(feature = "thread_rng")] +#[deprecated(since = "0.9.0", note = "renamed to `rng`")] +#[inline] +pub fn thread_rng() -> crate::rngs::ThreadRng { + rng() +} + pub use rng::{Fill, Rng}; -#[cfg(all(feature = "std", feature = "std_rng"))] -use crate::distributions::{Distribution, Standard}; +#[cfg(feature = "thread_rng")] +use crate::distr::{Distribution, StandardUniform}; -/// Generates a random value using the thread-local random number generator. -/// -/// This is simply a shortcut for `thread_rng().gen()`. See [`thread_rng`] for -/// documentation of the entropy source and [`Standard`] for documentation of -/// distributions and type-specific generation. -/// -/// # Provided implementations +/// Generate a random value using the thread-local random number generator. /// -/// The following types have provided implementations that -/// generate values with the following ranges and distributions: +/// This function is shorthand for [rng()].[random()](Rng::random): /// -/// * Integers (`i32`, `u32`, `isize`, `usize`, etc.): Uniformly distributed -/// over all values of the type. -/// * `char`: Uniformly distributed over all Unicode scalar values, i.e. all -/// code points in the range `0...0x10_FFFF`, except for the range -/// `0xD800...0xDFFF` (the surrogate code points). This includes -/// unassigned/reserved code points. -/// * `bool`: Generates `false` or `true`, each with probability 0.5. -/// * Floating point types (`f32` and `f64`): Uniformly distributed in the -/// half-open range `[0, 1)`. See notes below. -/// * Wrapping integers (`Wrapping`), besides the type identical to their -/// normal integer variants. -/// -/// Also supported is the generation of the following -/// compound types where all component types are supported: -/// -/// * Tuples (up to 12 elements): each element is generated sequentially. -/// * Arrays (up to 32 elements): each element is generated sequentially; -/// see also [`Rng::fill`] which supports arbitrary array length for integer -/// types and tends to be faster for `u32` and smaller types. -/// * `Option` first generates a `bool`, and if true generates and returns -/// `Some(value)` where `value: T`, otherwise returning `None`. +/// - See [`ThreadRng`] for documentation of the generator and security +/// - See [`StandardUniform`] for documentation of supported types and distributions /// /// # Examples /// @@ -156,34 +146,151 @@ use crate::distributions::{Distribution, Standard}; /// } /// ``` /// -/// If you're calling `random()` in a loop, caching the generator as in the -/// following example can increase performance. +/// If you're calling `random()` repeatedly, consider using a local `rng` +/// handle to save an initialization-check on each usage: /// /// ``` -/// use rand::Rng; +/// use rand::Rng; // provides the `random` method +/// +/// let mut rng = rand::rng(); // a local handle to the generator /// /// let mut v = vec![1, 2, 3]; /// /// for x in v.iter_mut() { -/// *x = rand::random() +/// *x = rng.random(); /// } +/// ``` /// -/// // can be made faster by caching thread_rng +/// [`StandardUniform`]: distr::StandardUniform +/// [`ThreadRng`]: rngs::ThreadRng +#[cfg(feature = "thread_rng")] +#[inline] +pub fn random() -> T +where + StandardUniform: Distribution, +{ + rng().random() +} + +/// Return an iterator over [`random()`] variates /// -/// let mut rng = rand::thread_rng(); +/// This function is shorthand for +/// [rng()].[random_iter](Rng::random_iter)(). +/// +/// # Example +/// +/// ``` +/// let v: Vec = rand::random_iter().take(5).collect(); +/// println!("{v:?}"); +/// ``` +#[cfg(feature = "thread_rng")] +#[inline] +pub fn random_iter() -> distr::Iter +where + StandardUniform: Distribution, +{ + rng().random_iter() +} + +/// Generate a random value in the given range using the thread-local random number generator. +/// +/// This function is shorthand for +/// [rng()].[random_range](Rng::random_range)(range). +/// +/// # Example /// -/// for x in v.iter_mut() { -/// *x = rng.gen(); -/// } /// ``` +/// let y: f32 = rand::random_range(0.0..=1e9); +/// println!("{}", y); /// -/// [`Standard`]: distributions::Standard -#[cfg(all(feature = "std", feature = "std_rng"))] -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] +/// let words: Vec<&str> = "Mary had a little lamb".split(' ').collect(); +/// println!("{}", words[rand::random_range(..words.len())]); +/// ``` +/// Note that the first example can also be achieved (without `collect`'ing +/// to a `Vec`) using [`seq::IteratorRandom::choose`]. +#[cfg(feature = "thread_rng")] #[inline] -pub fn random() -> T -where Standard: Distribution { - thread_rng().gen() +pub fn random_range(range: R) -> T +where + T: distr::uniform::SampleUniform, + R: distr::uniform::SampleRange, +{ + rng().random_range(range) +} + +/// Return a bool with a probability `p` of being true. +/// +/// This function is shorthand for +/// [rng()].[random_bool](Rng::random_bool)(p). +/// +/// # Example +/// +/// ``` +/// println!("{}", rand::random_bool(1.0 / 3.0)); +/// ``` +/// +/// # Panics +/// +/// If `p < 0` or `p > 1`. +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn random_bool(p: f64) -> bool { + rng().random_bool(p) +} + +/// Return a bool with a probability of `numerator/denominator` of being +/// true. +/// +/// That is, `random_ratio(2, 3)` has chance of 2 in 3, or about 67%, of +/// returning true. If `numerator == denominator`, then the returned value +/// is guaranteed to be `true`. If `numerator == 0`, then the returned +/// value is guaranteed to be `false`. +/// +/// See also the [`Bernoulli`] distribution, which may be faster if +/// sampling from the same `numerator` and `denominator` repeatedly. +/// +/// This function is shorthand for +/// [rng()].[random_ratio](Rng::random_ratio)(numerator, denominator). +/// +/// # Panics +/// +/// If `denominator == 0` or `numerator > denominator`. +/// +/// # Example +/// +/// ``` +/// println!("{}", rand::random_ratio(2, 3)); +/// ``` +/// +/// [`Bernoulli`]: distr::Bernoulli +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn random_ratio(numerator: u32, denominator: u32) -> bool { + rng().random_ratio(numerator, denominator) +} + +/// Fill any type implementing [`Fill`] with random data +/// +/// This function is shorthand for +/// [rng()].[fill](Rng::fill)(dest). +/// +/// # Example +/// +/// ``` +/// let mut arr = [0i8; 20]; +/// rand::fill(&mut arr[..]); +/// ``` +/// +/// Note that you can instead use [`random()`] to generate an array of random +/// data, though this is slower for small elements (smaller than the RNG word +/// size). +#[cfg(feature = "thread_rng")] +#[inline] +#[track_caller] +pub fn fill(dest: &mut T) { + dest.fill(&mut rng()) } #[cfg(test)] @@ -199,17 +306,23 @@ mod test { } #[test] - #[cfg(all(feature = "std", feature = "std_rng"))] + #[cfg(feature = "thread_rng")] fn test_random() { - // not sure how to test this aside from just getting some values - let _n: usize = random(); + let _n: u64 = random(); let _f: f32 = random(); - let _o: Option> = random(); + #[allow(clippy::type_complexity)] let _many: ( (), - (usize, isize, Option<(u32, (bool,))>), + [(u32, bool); 3], (u8, i8, u16, i16, u32, i32, u64, i64), (f32, (f64, (f64,))), ) = random(); } + + #[test] + #[cfg(feature = "thread_rng")] + fn test_range() { + let _n: usize = random_range(42..=43); + let _f: f32 = random_range(42.0..43.0); + } } diff --git a/src/prelude.rs b/src/prelude.rs index 51c457e3f9e..b0f563ad5fc 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -14,21 +14,22 @@ //! //! ``` //! use rand::prelude::*; -//! # let mut r = StdRng::from_rng(thread_rng()).unwrap(); -//! # let _: f32 = r.gen(); +//! # let mut r = StdRng::from_rng(&mut rand::rng()); +//! # let _: f32 = r.random(); //! ``` -#[doc(no_inline)] pub use crate::distributions::Distribution; +#[doc(no_inline)] +pub use crate::distr::Distribution; #[cfg(feature = "small_rng")] #[doc(no_inline)] pub use crate::rngs::SmallRng; #[cfg(feature = "std_rng")] -#[doc(no_inline)] pub use crate::rngs::StdRng; #[doc(no_inline)] -#[cfg(all(feature = "std", feature = "std_rng"))] +pub use crate::rngs::StdRng; +#[doc(no_inline)] +#[cfg(feature = "thread_rng")] pub use crate::rngs::ThreadRng; -#[doc(no_inline)] pub use crate::seq::{IteratorRandom, SliceRandom}; #[doc(no_inline)] -#[cfg(all(feature = "std", feature = "std_rng"))] -pub use crate::{random, thread_rng}; -#[doc(no_inline)] pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; +pub use crate::seq::{IndexedMutRandom, IndexedRandom, IteratorRandom, SliceRandom}; +#[doc(no_inline)] +pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; diff --git a/src/rng.rs b/src/rng.rs index bb977a54379..258c87de273 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -9,16 +9,20 @@ //! [`Rng`] trait -use rand_core::{Error, RngCore}; -use crate::distributions::uniform::{SampleRange, SampleUniform}; -use crate::distributions::{self, Distribution, Standard}; +use crate::distr::uniform::{SampleRange, SampleUniform}; +use crate::distr::{self, Distribution, StandardUniform}; use core::num::Wrapping; -use core::{mem, slice}; +use rand_core::RngCore; +use zerocopy::IntoBytes; -/// An automatically-implemented extension trait on [`RngCore`] providing high-level -/// generic methods for sampling values and other convenience methods. +/// User-level interface for RNGs /// -/// This is the primary trait to use when generating random values. +/// [`RngCore`] is the `dyn`-safe implementation-level interface for Random +/// (Number) Generators. This trait, `Rng`, provides a user-level interface on +/// RNGs. It is implemented automatically for any `R: RngCore`. +/// +/// This trait must usually be brought into scope via `use rand::Rng;` or +/// `use rand::prelude::*;`. /// /// # Generic usage /// @@ -28,69 +32,97 @@ use core::{mem, slice}; /// - Since `Rng: RngCore` and every `RngCore` implements `Rng`, it makes no /// difference whether we use `R: Rng` or `R: RngCore`. /// - The `+ ?Sized` un-bounding allows functions to be called directly on -/// type-erased references; i.e. `foo(r)` where `r: &mut RngCore`. Without +/// type-erased references; i.e. `foo(r)` where `r: &mut dyn RngCore`. Without /// this it would be necessary to write `foo(&mut r)`. /// /// An alternative pattern is possible: `fn foo(rng: R)`. This has some /// trade-offs. It allows the argument to be consumed directly without a `&mut` -/// (which is how `from_rng(thread_rng())` works); also it still works directly +/// (which is how `from_rng(rand::rng())` works); also it still works directly /// on references (including type-erased references). Unfortunately within the /// function `foo` it is not known whether `rng` is a reference type or not, /// hence many uses of `rng` require an extra reference, either explicitly -/// (`distr.sample(&mut rng)`) or implicitly (`rng.gen()`); one may hope the +/// (`distr.sample(&mut rng)`) or implicitly (`rng.random()`); one may hope the /// optimiser can remove redundant references later. /// /// Example: /// /// ``` -/// # use rand::thread_rng; /// use rand::Rng; /// /// fn foo(rng: &mut R) -> f32 { -/// rng.gen() +/// rng.random() /// } /// -/// # let v = foo(&mut thread_rng()); +/// # let v = foo(&mut rand::rng()); /// ``` pub trait Rng: RngCore { - /// Return a random value supporting the [`Standard`] distribution. + /// Return a random value via the [`StandardUniform`] distribution. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); - /// let x: u32 = rng.gen(); + /// let mut rng = rand::rng(); + /// let x: u32 = rng.random(); /// println!("{}", x); - /// println!("{:?}", rng.gen::<(f64, bool)>()); + /// println!("{:?}", rng.random::<(f64, bool)>()); /// ``` /// /// # Arrays and tuples /// - /// The `rng.gen()` method is able to generate arrays (up to 32 elements) + /// The `rng.random()` method is able to generate arrays /// and tuples (up to 12 elements), so long as all element types can be /// generated. /// /// For arrays of integers, especially for those with small element types - /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`]. + /// (< 64 bit), it will likely be faster to instead use [`Rng::fill`], + /// though note that generated values will differ. /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); - /// let tuple: (u8, i32, char) = rng.gen(); // arbitrary tuple support + /// let mut rng = rand::rng(); + /// let tuple: (u8, i32, char) = rng.random(); // arbitrary tuple support /// - /// let arr1: [f32; 32] = rng.gen(); // array construction + /// let arr1: [f32; 32] = rng.random(); // array construction /// let mut arr2 = [0u8; 128]; /// rng.fill(&mut arr2); // array fill /// ``` /// - /// [`Standard`]: distributions::Standard + /// [`StandardUniform`]: distr::StandardUniform #[inline] - fn gen(&mut self) -> T - where Standard: Distribution { - Standard.sample(self) + fn random(&mut self) -> T + where + StandardUniform: Distribution, + { + StandardUniform.sample(self) + } + + /// Return an iterator over [`random`](Self::random) variates + /// + /// This is a just a wrapper over [`Rng::sample_iter`] using + /// [`distr::StandardUniform`]. + /// + /// Note: this method consumes its argument. Use + /// `(&mut rng).random_iter()` to avoid consuming the RNG. + /// + /// # Example + /// + /// ``` + /// use rand::{rngs::mock::StepRng, Rng}; + /// + /// let rng = StepRng::new(1, 1); + /// let v: Vec = rng.random_iter().take(5).collect(); + /// assert_eq!(&v, &[1, 2, 3, 4, 5]); + /// ``` + #[inline] + fn random_iter(self) -> distr::Iter + where + Self: Sized, + StandardUniform: Distribution, + { + StandardUniform.sample_iter(self) } /// Generate a random value in the given range. @@ -99,38 +131,105 @@ pub trait Rng: RngCore { /// made from the given range. See also the [`Uniform`] distribution /// type which may be faster if sampling from the same range repeatedly. /// - /// Only `gen_range(low..high)` and `gen_range(low..=high)` are supported. + /// All types support `low..high_exclusive` and `low..=high` range syntax. + /// Unsigned integer types also support `..high_exclusive` and `..=high` syntax. /// /// # Panics /// - /// Panics if the range is empty. + /// Panics if the range is empty, or if `high - low` overflows for floats. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Exclusive range - /// let n: u32 = rng.gen_range(0..10); + /// let n: u32 = rng.random_range(..10); /// println!("{}", n); - /// let m: f64 = rng.gen_range(-40.0..1.3e5); + /// let m: f64 = rng.random_range(-40.0..1.3e5); /// println!("{}", m); /// /// // Inclusive range - /// let n: u32 = rng.gen_range(0..=10); + /// let n: u32 = rng.random_range(..=10); /// println!("{}", n); /// ``` /// - /// [`Uniform`]: distributions::uniform::Uniform - fn gen_range(&mut self, range: R) -> T + /// [`Uniform`]: distr::uniform::Uniform + #[track_caller] + fn random_range(&mut self, range: R) -> T where T: SampleUniform, - R: SampleRange + R: SampleRange, { assert!(!range.is_empty(), "cannot sample empty range"); - range.sample_single(self) + range.sample_single(self).unwrap() + } + + /// Return a bool with a probability `p` of being true. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same probability repeatedly. + /// + /// # Example + /// + /// ``` + /// use rand::Rng; + /// + /// let mut rng = rand::rng(); + /// println!("{}", rng.random_bool(1.0 / 3.0)); + /// ``` + /// + /// # Panics + /// + /// If `p < 0` or `p > 1`. + /// + /// [`Bernoulli`]: distr::Bernoulli + #[inline] + #[track_caller] + fn random_bool(&mut self, p: f64) -> bool { + match distr::Bernoulli::new(p) { + Ok(d) => self.sample(d), + Err(_) => panic!("p={:?} is outside range [0.0, 1.0]", p), + } + } + + /// Return a bool with a probability of `numerator/denominator` of being + /// true. + /// + /// That is, `random_ratio(2, 3)` has chance of 2 in 3, or about 67%, of + /// returning true. If `numerator == denominator`, then the returned value + /// is guaranteed to be `true`. If `numerator == 0`, then the returned + /// value is guaranteed to be `false`. + /// + /// See also the [`Bernoulli`] distribution, which may be faster if + /// sampling from the same `numerator` and `denominator` repeatedly. + /// + /// # Panics + /// + /// If `denominator == 0` or `numerator > denominator`. + /// + /// # Example + /// + /// ``` + /// use rand::Rng; + /// + /// let mut rng = rand::rng(); + /// println!("{}", rng.random_ratio(2, 3)); + /// ``` + /// + /// [`Bernoulli`]: distr::Bernoulli + #[inline] + #[track_caller] + fn random_ratio(&mut self, numerator: u32, denominator: u32) -> bool { + match distr::Bernoulli::from_ratio(numerator, denominator) { + Ok(d) => self.sample(d), + Err(_) => panic!( + "p={}/{} is outside range [0.0, 1.0]", + numerator, denominator + ), + } } /// Sample a new value, using the given distribution. @@ -138,14 +237,14 @@ pub trait Rng: RngCore { /// ### Example /// /// ``` - /// use rand::{thread_rng, Rng}; - /// use rand::distributions::Uniform; + /// use rand::Rng; + /// use rand::distr::Uniform; /// - /// let mut rng = thread_rng(); - /// let x = rng.sample(Uniform::new(10u32, 15)); + /// let mut rng = rand::rng(); + /// let x = rng.sample(Uniform::new(10u32, 15).unwrap()); /// // Type annotation requires two types, the type and distribution; the /// // distribution can be inferred. - /// let y = rng.sample::(Uniform::new(10, 15)); + /// let y = rng.sample::(Uniform::new(10, 15).unwrap()); /// ``` fn sample>(&mut self, distr: D) -> T { distr.sample(self) @@ -153,22 +252,19 @@ pub trait Rng: RngCore { /// Create an iterator that generates values using the given distribution. /// - /// Note that this function takes its arguments by value. This works since - /// `(&mut R): Rng where R: Rng` and - /// `(&D): Distribution where D: Distribution`, - /// however borrowing is not automatic hence `rng.sample_iter(...)` may - /// need to be replaced with `(&mut rng).sample_iter(...)`. + /// Note: this method consumes its arguments. Use + /// `(&mut rng).sample_iter(..)` to avoid consuming the RNG. /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; - /// use rand::distributions::{Alphanumeric, Uniform, Standard}; + /// use rand::Rng; + /// use rand::distr::{Alphanumeric, Uniform, StandardUniform}; /// - /// let mut rng = thread_rng(); + /// let mut rng = rand::rng(); /// /// // Vec of 16 x f32: - /// let v: Vec = (&mut rng).sample_iter(Standard).take(16).collect(); + /// let v: Vec = (&mut rng).sample_iter(StandardUniform).take(16).collect(); /// /// // String: /// let s: String = (&mut rng).sample_iter(Alphanumeric) @@ -177,17 +273,17 @@ pub trait Rng: RngCore { /// .collect(); /// /// // Combined values - /// println!("{:?}", (&mut rng).sample_iter(Standard).take(5) + /// println!("{:?}", (&mut rng).sample_iter(StandardUniform).take(5) /// .collect::>()); /// /// // Dice-rolling: - /// let die_range = Uniform::new_inclusive(1, 6); + /// let die_range = Uniform::new_inclusive(1, 6).unwrap(); /// let mut roll_die = (&mut rng).sample_iter(die_range); /// while roll_die.next().unwrap() != 6 { /// println!("Not a 6; rolling again!"); /// } /// ``` - fn sample_iter(self, distr: D) -> distributions::DistIter + fn sample_iter(self, distr: D) -> distr::Iter where D: Distribution, Self: Sized, @@ -197,106 +293,64 @@ pub trait Rng: RngCore { /// Fill any type implementing [`Fill`] with random data /// + /// This method is implemented for types which may be safely reinterpreted + /// as an (aligned) `[u8]` slice then filled with random data. It is often + /// faster than using [`Rng::random`] but not value-equivalent. + /// /// The distribution is expected to be uniform with portable results, but /// this cannot be guaranteed for third-party implementations. /// - /// This is identical to [`try_fill`] except that it panics on error. - /// /// # Example /// /// ``` - /// use rand::{thread_rng, Rng}; + /// use rand::Rng; /// /// let mut arr = [0i8; 20]; - /// thread_rng().fill(&mut arr[..]); + /// rand::rng().fill(&mut arr[..]); /// ``` /// /// [`fill_bytes`]: RngCore::fill_bytes - /// [`try_fill`]: Rng::try_fill + #[track_caller] fn fill(&mut self, dest: &mut T) { - dest.try_fill(self).unwrap_or_else(|_| panic!("Rng::fill failed")) + dest.fill(self) } - /// Fill any type implementing [`Fill`] with random data - /// - /// The distribution is expected to be uniform with portable results, but - /// this cannot be guaranteed for third-party implementations. - /// - /// This is identical to [`fill`] except that it forwards errors. - /// - /// # Example - /// - /// ``` - /// # use rand::Error; - /// use rand::{thread_rng, Rng}; - /// - /// # fn try_inner() -> Result<(), Error> { - /// let mut arr = [0u64; 4]; - /// thread_rng().try_fill(&mut arr[..])?; - /// # Ok(()) - /// # } - /// - /// # try_inner().unwrap() - /// ``` - /// - /// [`try_fill_bytes`]: RngCore::try_fill_bytes - /// [`fill`]: Rng::fill - fn try_fill(&mut self, dest: &mut T) -> Result<(), Error> { - dest.try_fill(self) + /// Alias for [`Rng::random`]. + #[inline] + #[deprecated( + since = "0.9.0", + note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024." + )] + fn r#gen(&mut self) -> T + where + StandardUniform: Distribution, + { + self.random() } - /// Return a bool with a probability `p` of being true. - /// - /// See also the [`Bernoulli`] distribution, which may be faster if - /// sampling from the same probability repeatedly. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// println!("{}", rng.gen_bool(1.0 / 3.0)); - /// ``` - /// - /// # Panics - /// - /// If `p < 0` or `p > 1`. - /// - /// [`Bernoulli`]: distributions::Bernoulli + /// Alias for [`Rng::random_range`]. #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_range`")] + fn gen_range(&mut self, range: R) -> T + where + T: SampleUniform, + R: SampleRange, + { + self.random_range(range) + } + + /// Alias for [`Rng::random_bool`]. + #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_bool`")] fn gen_bool(&mut self, p: f64) -> bool { - let d = distributions::Bernoulli::new(p).unwrap(); - self.sample(d) + self.random_bool(p) } - /// Return a bool with a probability of `numerator/denominator` of being - /// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of - /// returning true. If `numerator == denominator`, then the returned value - /// is guaranteed to be `true`. If `numerator == 0`, then the returned - /// value is guaranteed to be `false`. - /// - /// See also the [`Bernoulli`] distribution, which may be faster if - /// sampling from the same `numerator` and `denominator` repeatedly. - /// - /// # Panics - /// - /// If `denominator == 0` or `numerator > denominator`. - /// - /// # Example - /// - /// ``` - /// use rand::{thread_rng, Rng}; - /// - /// let mut rng = thread_rng(); - /// println!("{}", rng.gen_ratio(2, 3)); - /// ``` - /// - /// [`Bernoulli`]: distributions::Bernoulli + /// Alias for [`Rng::random_ratio`]. #[inline] + #[deprecated(since = "0.9.0", note = "Renamed to `random_ratio`")] fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool { - let d = distributions::Bernoulli::from_ratio(numerator, denominator).unwrap(); - self.sample(d) + self.random_ratio(numerator, denominator) } } @@ -311,18 +365,17 @@ impl Rng for R {} /// [Chapter on Portability](https://rust-random.github.io/book/portability.html)). pub trait Fill { /// Fill self with random data - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error>; + fn fill(&mut self, rng: &mut R); } macro_rules! impl_fill_each { () => {}; ($t:ty) => { impl Fill for [$t] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { for elt in self.iter_mut() { - *elt = rng.gen(); + *elt = rng.random(); } - Ok(()) } } }; @@ -335,8 +388,8 @@ macro_rules! impl_fill_each { impl_fill_each!(bool, char, f32, f64,); impl Fill for [u8] { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - rng.try_fill_bytes(self) + fn fill(&mut self, rng: &mut R) { + rng.fill_bytes(self) } } @@ -345,37 +398,25 @@ macro_rules! impl_fill { ($t:ty) => { impl Fill for [$t] { #[inline(never)] // in micro benchmarks, this improves performance - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { - slice::from_raw_parts_mut(self.as_mut_ptr() - as *mut u8, - self.len() * mem::size_of::<$t>() - ) - })?; + rng.fill_bytes(self.as_mut_bytes()); for x in self { *x = x.to_le(); } } - Ok(()) } } impl Fill for [Wrapping<$t>] { #[inline(never)] - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { + fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.try_fill_bytes(unsafe { - slice::from_raw_parts_mut(self.as_mut_ptr() - as *mut u8, - self.len() * mem::size_of::<$t>() - ) - })?; + rng.fill_bytes(self.as_mut_bytes()); for x in self { *x = Wrapping(x.0.to_le()); } } - Ok(()) } } }; @@ -387,42 +428,25 @@ macro_rules! impl_fill { } } -impl_fill!(u16, u32, u64, usize,); -#[cfg(not(target_os = "emscripten"))] -impl_fill!(u128); -impl_fill!(i8, i16, i32, i64, isize,); -#[cfg(not(target_os = "emscripten"))] -impl_fill!(i128); - -macro_rules! impl_fill_arrays { - ($n:expr,) => {}; - ($n:expr, $N:ident) => { - impl Fill for [T; $n] where [T]: Fill { - fn try_fill(&mut self, rng: &mut R) -> Result<(), Error> { - self[..].try_fill(rng) - } - } - }; - ($n:expr, $N:ident, $($NN:ident,)*) => { - impl_fill_arrays!($n, $N); - impl_fill_arrays!($n - 1, $($NN,)*); - }; - (!div $n:expr,) => {}; - (!div $n:expr, $N:ident, $($NN:ident,)*) => { - impl_fill_arrays!($n, $N); - impl_fill_arrays!(!div $n / 2, $($NN,)*); - }; +impl_fill!(u16, u32, u64, u128,); +impl_fill!(i8, i16, i32, i64, i128,); + +impl Fill for [T; N] +where + [T]: Fill, +{ + fn fill(&mut self, rng: &mut R) { + <[T] as Fill>::fill(self, rng) + } } -#[rustfmt::skip] -impl_fill_arrays!(32, N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,N,); -impl_fill_arrays!(!div 4096, N,N,N,N,N,N,N,); #[cfg(test)] mod test { use super::*; - use crate::test::rng; use crate::rngs::mock::StepRng; - #[cfg(feature = "alloc")] use alloc::boxed::Box; + use crate::test::rng; + #[cfg(feature = "alloc")] + use alloc::boxed::Box; #[test] fn test_fill_bytes_default() { @@ -470,8 +494,8 @@ mod test { // Check equivalence for generated floats let mut array = [0f32; 2]; rng.fill(&mut array); - let gen: [f32; 2] = rng.gen(); - assert_eq!(array, gen); + let arr2: [f32; 2] = rng.random(); + assert_eq!(array, arr2); } #[test] @@ -483,85 +507,98 @@ mod test { } #[test] - fn test_gen_range_int() { + fn test_random_range_int() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-4711..17); - assert!(a >= -4711 && a < 17); - let a = r.gen_range(-3i8..42); - assert!(a >= -3i8 && a < 42i8); - let a: u16 = r.gen_range(10..99); - assert!(a >= 10u16 && a < 99u16); - let a = r.gen_range(-100i32..2000); - assert!(a >= -100i32 && a < 2000i32); - let a: u32 = r.gen_range(12..=24); - assert!(a >= 12u32 && a <= 24u32); - - assert_eq!(r.gen_range(0u32..1), 0u32); - assert_eq!(r.gen_range(-12i64..-11), -12i64); - assert_eq!(r.gen_range(3_000_000..3_000_001), 3_000_000); + let a = r.random_range(-4711..17); + assert!((-4711..17).contains(&a)); + let a: i8 = r.random_range(-3..42); + assert!((-3..42).contains(&a)); + let a: u16 = r.random_range(10..99); + assert!((10..99).contains(&a)); + let a: i32 = r.random_range(-100..2000); + assert!((-100..2000).contains(&a)); + let a: u32 = r.random_range(12..=24); + assert!((12..=24).contains(&a)); + + assert_eq!(r.random_range(..1u32), 0u32); + assert_eq!(r.random_range(-12i64..-11), -12i64); + assert_eq!(r.random_range(3_000_000..3_000_001), 3_000_000); } } #[test] - fn test_gen_range_float() { + fn test_random_range_float() { let mut r = rng(101); for _ in 0..1000 { - let a = r.gen_range(-4.5..1.7); - assert!(a >= -4.5 && a < 1.7); - let a = r.gen_range(-1.1..=-0.3); - assert!(a >= -1.1 && a <= -0.3); - - assert_eq!(r.gen_range(0.0f32..=0.0), 0.); - assert_eq!(r.gen_range(-11.0..=-11.0), -11.); - assert_eq!(r.gen_range(3_000_000.0..=3_000_000.0), 3_000_000.); + let a = r.random_range(-4.5..1.7); + assert!((-4.5..1.7).contains(&a)); + let a = r.random_range(-1.1..=-0.3); + assert!((-1.1..=-0.3).contains(&a)); + + assert_eq!(r.random_range(0.0f32..=0.0), 0.); + assert_eq!(r.random_range(-11.0..=-11.0), -11.); + assert_eq!(r.random_range(3_000_000.0..=3_000_000.0), 3_000_000.); } } #[test] #[should_panic] - fn test_gen_range_panic_int() { + #[allow(clippy::reversed_empty_ranges)] + fn test_random_range_panic_int() { let mut r = rng(102); - r.gen_range(5..-2); + r.random_range(5..-2); } #[test] #[should_panic] - fn test_gen_range_panic_usize() { + #[allow(clippy::reversed_empty_ranges)] + fn test_random_range_panic_usize() { let mut r = rng(103); - r.gen_range(5..2); + r.random_range(5..2); } #[test] - fn test_gen_bool() { + #[allow(clippy::bool_assert_comparison)] + fn test_random_bool() { let mut r = rng(105); for _ in 0..5 { - assert_eq!(r.gen_bool(0.0), false); - assert_eq!(r.gen_bool(1.0), true); + assert_eq!(r.random_bool(0.0), false); + assert_eq!(r.random_bool(1.0), true); + } + } + + #[test] + fn test_rng_mut_ref() { + fn use_rng(mut r: impl Rng) { + let _ = r.next_u32(); } + + let mut rng = rng(109); + use_rng(&mut rng); } #[test] fn test_rng_trait_object() { - use crate::distributions::{Distribution, Standard}; + use crate::distr::{Distribution, StandardUniform}; let mut rng = rng(109); let mut r = &mut rng as &mut dyn RngCore; r.next_u32(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); - let _c: u8 = Standard.sample(&mut r); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + let _c: u8 = StandardUniform.sample(&mut r); } #[test] #[cfg(feature = "alloc")] fn test_rng_boxed_trait() { - use crate::distributions::{Distribution, Standard}; + use crate::distr::{Distribution, StandardUniform}; let rng = rng(110); let mut r = Box::new(rng) as Box; r.next_u32(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); - let _c: u8 = Standard.sample(&mut r); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + let _c: u8 = StandardUniform.sample(&mut r); } #[test] @@ -574,7 +611,7 @@ mod test { let mut sum: u32 = 0; let mut rng = rng(111); for _ in 0..N { - if rng.gen_ratio(NUM, DENOM) { + if rng.random_ratio(NUM, DENOM) { sum += 1; } } diff --git a/src/rngs/adapter/mod.rs b/src/rngs/adapter/mod.rs deleted file mode 100644 index 22b7158d40f..00000000000 --- a/src/rngs/adapter/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Wrappers / adapters forming RNGs - -mod read; -mod reseeding; - -pub use self::read::{ReadError, ReadRng}; -pub use self::reseeding::ReseedingRng; diff --git a/src/rngs/adapter/read.rs b/src/rngs/adapter/read.rs deleted file mode 100644 index 63b0dd0c0f0..00000000000 --- a/src/rngs/adapter/read.rs +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! A wrapper around any Read to treat it as an RNG. - -use std::fmt; -use std::io::Read; - -use rand_core::{impls, Error, RngCore}; - - -/// An RNG that reads random bytes straight from any type supporting -/// [`std::io::Read`], for example files. -/// -/// This will work best with an infinite reader, but that is not required. -/// -/// This can be used with `/dev/urandom` on Unix but it is recommended to use -/// [`OsRng`] instead. -/// -/// # Panics -/// -/// `ReadRng` uses [`std::io::Read::read_exact`], which retries on interrupts. -/// All other errors from the underlying reader, including when it does not -/// have enough data, will only be reported through [`try_fill_bytes`]. -/// The other [`RngCore`] methods will panic in case of an error. -/// -/// # Example -/// -/// ``` -/// use rand::Rng; -/// use rand::rngs::adapter::ReadRng; -/// -/// let data = vec![1, 2, 3, 4, 5, 6, 7, 8]; -/// let mut rng = ReadRng::new(&data[..]); -/// println!("{:x}", rng.gen::()); -/// ``` -/// -/// [`OsRng`]: crate::rngs::OsRng -/// [`try_fill_bytes`]: RngCore::try_fill_bytes -#[derive(Debug)] -pub struct ReadRng { - reader: R, -} - -impl ReadRng { - /// Create a new `ReadRng` from a `Read`. - pub fn new(r: R) -> ReadRng { - ReadRng { reader: r } - } -} - -impl RngCore for ReadRng { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) - } - - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_fill(self) - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.try_fill_bytes(dest).unwrap_or_else(|err| { - panic!( - "reading random bytes from Read implementation failed; error: {}", - err - ) - }); - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - if dest.is_empty() { - return Ok(()); - } - // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. - self.reader - .read_exact(dest) - .map_err(|e| Error::new(ReadError(e))) - } -} - -/// `ReadRng` error type -#[derive(Debug)] -pub struct ReadError(std::io::Error); - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ReadError: {}", self.0) - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - Some(&self.0) - } -} - - -#[cfg(test)] -mod test { - use std::println; - - use super::ReadRng; - use crate::RngCore; - - #[test] - fn test_reader_rng_u64() { - // transmute from the target to avoid endianness concerns. - #[rustfmt::skip] - let v = [0u8, 0, 0, 0, 0, 0, 0, 1, - 0, 4, 0, 0, 3, 0, 0, 2, - 5, 0, 0, 0, 0, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u64(), 1 << 56); - assert_eq!(rng.next_u64(), (2 << 56) + (3 << 32) + (4 << 8)); - assert_eq!(rng.next_u64(), 5); - } - - #[test] - fn test_reader_rng_u32() { - let v = [0u8, 0, 0, 1, 0, 0, 2, 0, 3, 0, 0, 0]; - let mut rng = ReadRng::new(&v[..]); - - assert_eq!(rng.next_u32(), 1 << 24); - assert_eq!(rng.next_u32(), 2 << 16); - assert_eq!(rng.next_u32(), 3); - } - - #[test] - fn test_reader_rng_fill_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 8]; - - let mut rng = ReadRng::new(&v[..]); - rng.fill_bytes(&mut w); - - assert!(v == w); - } - - #[test] - fn test_reader_rng_insufficient_bytes() { - let v = [1u8, 2, 3, 4, 5, 6, 7, 8]; - let mut w = [0u8; 9]; - - let mut rng = ReadRng::new(&v[..]); - - let result = rng.try_fill_bytes(&mut w); - assert!(result.is_err()); - println!("Error: {}", result.unwrap_err()); - } -} diff --git a/src/rngs/mock.rs b/src/rngs/mock.rs index a1745a490dd..b6da66a8565 100644 --- a/src/rngs/mock.rs +++ b/src/rngs/mock.rs @@ -8,27 +8,38 @@ //! Mock random number generator -use rand_core::{impls, Error, RngCore}; +use rand_core::{impls, RngCore}; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -/// A simple implementation of `RngCore` for testing purposes. +/// A mock generator yielding very predictable output /// /// This generates an arithmetic sequence (i.e. adds a constant each step) /// over a `u64` number, using wrapping arithmetic. If the increment is 0 /// the generator yields a constant. /// +/// Other integer types (64-bit and smaller) are produced via cast from `u64`. +/// +/// Other types are produced via their implementation of [`Rng`](crate::Rng) or +/// [`Distribution`](crate::distr::Distribution). +/// Output values may not be intuitive and may change in future releases but +/// are considered +/// [portable](https://rust-random.github.io/book/portability.html). +/// (`bool` output is true when bit `1u64 << 31` is set.) +/// +/// # Example +/// /// ``` /// use rand::Rng; /// use rand::rngs::mock::StepRng; /// /// let mut my_rng = StepRng::new(2, 1); -/// let sample: [u64; 3] = my_rng.gen(); +/// let sample: [u64; 3] = my_rng.random(); /// assert_eq!(sample, [2, 3, 4]); /// ``` #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct StepRng { v: u64, a: u64, @@ -53,35 +64,40 @@ impl RngCore for StepRng { #[inline] fn next_u64(&mut self) -> u64 { - let result = self.v; + let res = self.v; self.v = self.v.wrapping_add(self.a); - result - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - impls::fill_bytes_via_next(self, dest); + res } #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + impls::fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { + #[cfg(any(feature = "alloc", feature = "serde"))] + use super::StepRng; + #[test] - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] fn test_serialization_step_rng() { - use super::StepRng; - let some_rng = StepRng::new(42, 7); let de_some_rng: StepRng = bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap(); assert_eq!(some_rng.v, de_some_rng.v); assert_eq!(some_rng.a, de_some_rng.a); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_bool() { + use crate::{distr::StandardUniform, Rng}; + // If this result ever changes, update doc on StepRng! + let rng = StepRng::new(0, 1 << 31); + let result: alloc::vec::Vec = rng.sample_iter(StandardUniform).take(6).collect(); + assert_eq!(&result, &[false, true, false, true, false, true]); } } diff --git a/src/rngs/mod.rs b/src/rngs/mod.rs index ac3c2c595da..cb7ed57f33e 100644 --- a/src/rngs/mod.rs +++ b/src/rngs/mod.rs @@ -8,112 +8,102 @@ //! Random number generators and adapters //! -//! ## Background: Random number generators (RNGs) +//! This crate provides a small selection of non-[portable] generators. +//! See also [Types of generators] and [Our RNGs] in the book. //! -//! Computers cannot produce random numbers from nowhere. We classify -//! random number generators as follows: +//! ## Generators //! -//! - "True" random number generators (TRNGs) use hard-to-predict data sources -//! (e.g. the high-resolution parts of event timings and sensor jitter) to -//! harvest random bit-sequences, apply algorithms to remove bias and -//! estimate available entropy, then combine these bits into a byte-sequence -//! or an entropy pool. This job is usually done by the operating system or -//! a hardware generator (HRNG). -//! - "Pseudo"-random number generators (PRNGs) use algorithms to transform a -//! seed into a sequence of pseudo-random numbers. These generators can be -//! fast and produce well-distributed unpredictable random numbers (or not). -//! They are usually deterministic: given algorithm and seed, the output -//! sequence can be reproduced. They have finite period and eventually loop; -//! with many algorithms this period is fixed and can be proven sufficiently -//! long, while others are chaotic and the period depends on the seed. -//! - "Cryptographically secure" pseudo-random number generators (CSPRNGs) -//! are the sub-set of PRNGs which are secure. Security of the generator -//! relies both on hiding the internal state and using a strong algorithm. +//! This crate provides a small selection of non-[portable] random number generators: //! -//! ## Traits and functionality -//! -//! All RNGs implement the [`RngCore`] trait, as a consequence of which the -//! [`Rng`] extension trait is automatically implemented. Secure RNGs may -//! additionally implement the [`CryptoRng`] trait. -//! -//! All PRNGs require a seed to produce their random number sequence. The -//! [`SeedableRng`] trait provides three ways of constructing PRNGs: -//! -//! - `from_seed` accepts a type specific to the PRNG -//! - `from_rng` allows a PRNG to be seeded from any other RNG -//! - `seed_from_u64` allows any PRNG to be seeded from a `u64` insecurely -//! - `from_entropy` securely seeds a PRNG from fresh entropy -//! -//! Use the [`rand_core`] crate when implementing your own RNGs. -//! -//! ## Our generators -//! -//! This crate provides several random number generators: -//! -//! - [`OsRng`] is an interface to the operating system's random number -//! source. Typically the operating system uses a CSPRNG with entropy -//! provided by a TRNG and some type of on-going re-seeding. -//! - [`ThreadRng`], provided by the [`thread_rng`] function, is a handle to a -//! thread-local CSPRNG with periodic seeding from [`OsRng`]. Because this +//! - [`OsRng`] is a stateless interface over the operating system's random number +//! source. This is typically secure with some form of periodic re-seeding. +//! - [`ThreadRng`], provided by [`crate::rng()`], is a handle to a +//! thread-local generator with periodic seeding from [`OsRng`]. Because this //! is local, it is typically much faster than [`OsRng`]. It should be -//! secure, though the paranoid may prefer [`OsRng`]. +//! secure, but see documentation on [`ThreadRng`]. //! - [`StdRng`] is a CSPRNG chosen for good performance and trust of security //! (based on reviews, maturity and usage). The current algorithm is ChaCha12, //! which is well established and rigorously analysed. -//! [`StdRng`] provides the algorithm used by [`ThreadRng`] but without -//! periodic reseeding. -//! - [`SmallRng`] is an **insecure** PRNG designed to be fast, simple, require -//! little memory, and have good output quality. +//! [`StdRng`] is the deterministic generator used by [`ThreadRng`] but +//! without the periodic reseeding or thread-local management. +//! - [`SmallRng`] is a relatively simple, insecure generator designed to be +//! fast, use little memory, and pass various statistical tests of +//! randomness quality. //! //! The algorithms selected for [`StdRng`] and [`SmallRng`] may change in any -//! release and may be platform-dependent, therefore they should be considered -//! **not reproducible**. +//! release and may be platform-dependent, therefore they are not +//! [reproducible][portable]. //! -//! ## Additional generators +//! ### Additional generators //! -//! **TRNGs**: The [`rdrand`] crate provides an interface to the RDRAND and -//! RDSEED instructions available in modern Intel and AMD CPUs. -//! The [`rand_jitter`] crate provides a user-space implementation of -//! entropy harvesting from CPU timer jitter, but is very slow and has -//! [security issues](https://github.com/rust-random/rand/issues/699). +//! - The [`rdrand`] crate provides an interface to the RDRAND and RDSEED +//! instructions available in modern Intel and AMD CPUs. +//! - The [`rand_jitter`] crate provides a user-space implementation of +//! entropy harvesting from CPU timer jitter, but is very slow and has +//! [security issues](https://github.com/rust-random/rand/issues/699). +//! - The [`rand_chacha`] crate provides [portable] implementations of +//! generators derived from the [ChaCha] family of stream ciphers +//! - The [`rand_pcg`] crate provides [portable] implementations of a subset +//! of the [PCG] family of small, insecure generators +//! - The [`rand_xoshiro`] crate provides [portable] implementations of the +//! [xoshiro] family of small, insecure generators //! -//! **PRNGs**: Several companion crates are available, providing individual or -//! families of PRNG algorithms. These provide the implementations behind -//! [`StdRng`] and [`SmallRng`] but can also be used directly, indeed *should* -//! be used directly when **reproducibility** matters. -//! Some suggestions are: [`rand_chacha`], [`rand_pcg`], [`rand_xoshiro`]. -//! A full list can be found by searching for crates with the [`rng` tag]. +//! For more, search [crates with the `rng` tag]. //! +//! ## Traits and functionality +//! +//! All generators implement [`RngCore`] and thus also [`Rng`][crate::Rng]. +//! See also the [Random Values] chapter in the book. +//! +//! Secure RNGs may additionally implement the [`CryptoRng`] trait. +//! +//! Use the [`rand_core`] crate when implementing your own RNGs. +//! +//! [portable]: https://rust-random.github.io/book/crate-reprod.html +//! [Types of generators]: https://rust-random.github.io/book/guide-gen.html +//! [Our RNGs]: https://rust-random.github.io/book/guide-rngs.html +//! [Random Values]: https://rust-random.github.io/book/guide-values.html //! [`Rng`]: crate::Rng //! [`RngCore`]: crate::RngCore //! [`CryptoRng`]: crate::CryptoRng //! [`SeedableRng`]: crate::SeedableRng -//! [`thread_rng`]: crate::thread_rng //! [`rdrand`]: https://crates.io/crates/rdrand //! [`rand_jitter`]: https://crates.io/crates/rand_jitter //! [`rand_chacha`]: https://crates.io/crates/rand_chacha //! [`rand_pcg`]: https://crates.io/crates/rand_pcg //! [`rand_xoshiro`]: https://crates.io/crates/rand_xoshiro -//! [`rng` tag]: https://crates.io/keywords/rng +//! [crates with the `rng` tag]: https://crates.io/keywords/rng +//! [chacha]: https://cr.yp.to/chacha.html +//! [PCG]: https://www.pcg-random.org/ +//! [xoshiro]: https://prng.di.unimi.it/ -#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] -#[cfg(feature = "std")] pub mod adapter; +mod reseeding; +pub use reseeding::ReseedingRng; pub mod mock; // Public so we don't export `StepRng` directly, making it a bit // more clear it is intended for testing. +#[cfg(feature = "small_rng")] +mod small; +#[cfg(all( + feature = "small_rng", + any(target_pointer_width = "32", target_pointer_width = "16") +))] +mod xoshiro128plusplus; #[cfg(all(feature = "small_rng", target_pointer_width = "64"))] mod xoshiro256plusplus; -#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))] -mod xoshiro128plusplus; -#[cfg(feature = "small_rng")] mod small; -#[cfg(feature = "std_rng")] mod std; -#[cfg(all(feature = "std", feature = "std_rng"))] pub(crate) mod thread; +#[cfg(feature = "std_rng")] +mod std; +#[cfg(feature = "thread_rng")] +pub(crate) mod thread; -#[cfg(feature = "small_rng")] pub use self::small::SmallRng; -#[cfg(feature = "std_rng")] pub use self::std::StdRng; -#[cfg(all(feature = "std", feature = "std_rng"))] pub use self::thread::ThreadRng; +#[cfg(feature = "small_rng")] +pub use self::small::SmallRng; +#[cfg(feature = "std_rng")] +pub use self::std::StdRng; +#[cfg(feature = "thread_rng")] +pub use self::thread::ThreadRng; -#[cfg_attr(doc_cfg, doc(cfg(feature = "getrandom")))] -#[cfg(feature = "getrandom")] pub use rand_core::OsRng; +#[cfg(feature = "os_rng")] +pub use rand_core::OsRng; diff --git a/src/rngs/adapter/reseeding.rs b/src/rngs/reseeding.rs similarity index 55% rename from src/rngs/adapter/reseeding.rs rename to src/rngs/reseeding.rs index 1977cb31906..570d04eeba4 100644 --- a/src/rngs/adapter/reseeding.rs +++ b/src/rngs/reseeding.rs @@ -10,10 +10,10 @@ //! A wrapper around another PRNG that reseeds it after it //! generates a certain number of random bytes. -use core::mem::size_of; +use core::mem::size_of_val; -use rand_core::block::{BlockRng, BlockRngCore}; -use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::block::{BlockRng, BlockRngCore, CryptoBlockRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; /// A wrapper around any PRNG that implements [`BlockRngCore`], that adds the /// ability to reseed it. @@ -22,10 +22,6 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// /// - On a manual call to [`reseed()`]. /// - After `clone()`, the clone will be reseeded on first use. -/// - After a process is forked, the RNG in the child process is reseeded within -/// the next few generated values, depending on the block size of the -/// underlying PRNG. For ChaCha and Hc128 this is a maximum of -/// 15 `u32` values before reseeding. /// - After the PRNG has generated a configurable number of random bytes. /// /// # When should reseeding after a fixed number of generated bytes be used? @@ -61,15 +57,14 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; /// use rand_chacha::ChaCha20Core; // Internal part of ChaChaRng that /// // implements BlockRngCore /// use rand::rngs::OsRng; -/// use rand::rngs::adapter::ReseedingRng; +/// use rand::rngs::ReseedingRng; /// -/// let prng = ChaCha20Core::from_entropy(); -/// let mut reseeding_rng = ReseedingRng::new(prng, 0, OsRng); +/// let mut reseeding_rng = ReseedingRng::::new(0, OsRng).unwrap(); /// -/// println!("{}", reseeding_rng.gen::()); +/// println!("{}", reseeding_rng.random::()); /// /// let mut cloned_rng = reseeding_rng.clone(); -/// assert!(reseeding_rng.gen::() != cloned_rng.gen::()); +/// assert!(reseeding_rng.random::() != cloned_rng.random::()); /// ``` /// /// [`BlockRngCore`]: rand_core::block::BlockRngCore @@ -79,12 +74,12 @@ use rand_core::{CryptoRng, Error, RngCore, SeedableRng}; pub struct ReseedingRng(BlockRng>) where R: BlockRngCore + SeedableRng, - Rsdr: RngCore; + Rsdr: TryRngCore; impl ReseedingRng where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { /// Create a new `ReseedingRng` from an existing PRNG, combined with a RNG /// to use as reseeder. @@ -92,22 +87,27 @@ where /// `threshold` sets the number of generated bytes after which to reseed the /// PRNG. Set it to zero to never reseed based on the number of generated /// values. - pub fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { - ReseedingRng(BlockRng::new(ReseedingCore::new(rng, threshold, reseeder))) + pub fn new(threshold: u64, reseeder: Rsdr) -> Result { + Ok(ReseedingRng(BlockRng::new(ReseedingCore::new( + threshold, reseeder, + )?))) } - /// Reseed the internal PRNG. - pub fn reseed(&mut self) -> Result<(), Error> { + /// Immediately reseed the generator + /// + /// This discards any remaining random data in the cache. + pub fn reseed(&mut self) -> Result<(), Rsdr::Error> { + self.0.reset(); self.0.core.reseed() } } // TODO: this should be implemented for any type where the inner type // implements RngCore, but we can't specify that because ReseedingCore is private -impl RngCore for ReseedingRng +impl RngCore for ReseedingRng where R: BlockRngCore + SeedableRng, - ::Results: AsRef<[u32]> + AsMut<[u32]>, + Rsdr: TryRngCore, { #[inline(always)] fn next_u32(&mut self) -> u32 { @@ -122,16 +122,12 @@ where fn fill_bytes(&mut self, dest: &mut [u8]) { self.0.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) - } } impl Clone for ReseedingRng where R: BlockRngCore + SeedableRng + Clone, - Rsdr: RngCore + Clone, + Rsdr: TryRngCore + Clone, { fn clone(&self) -> ReseedingRng { // Recreating `BlockRng` seems easier than cloning it and resetting @@ -142,8 +138,8 @@ where impl CryptoRng for ReseedingRng where - R: BlockRngCore + SeedableRng + CryptoRng, - Rsdr: RngCore + CryptoRng, + R: BlockRngCore + SeedableRng + CryptoBlockRng, + Rsdr: TryCryptoRng, { } @@ -153,26 +149,24 @@ struct ReseedingCore { reseeder: Rsdr, threshold: i64, bytes_until_reseed: i64, - fork_counter: usize, } impl BlockRngCore for ReseedingCore where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { type Item = ::Item; type Results = ::Results; fn generate(&mut self, results: &mut Self::Results) { - let global_fork_counter = fork::get_fork_counter(); - if self.bytes_until_reseed <= 0 || self.is_forked(global_fork_counter) { + if self.bytes_until_reseed <= 0 { // We get better performance by not calling only `reseed` here // and continuing with the rest of the function, but by directly // returning from a non-inlined function. - return self.reseed_and_generate(results, global_fork_counter); + return self.reseed_and_generate(results); } - let num_bytes = results.as_ref().len() * size_of::(); + let num_bytes = size_of_val(results.as_ref()); self.bytes_until_reseed -= num_bytes as i64; self.inner.generate(results); } @@ -181,74 +175,53 @@ where impl ReseedingCore where R: BlockRngCore + SeedableRng, - Rsdr: RngCore, + Rsdr: TryRngCore, { /// Create a new `ReseedingCore`. - fn new(rng: R, threshold: u64, reseeder: Rsdr) -> Self { - use ::core::i64::MAX; - fork::register_fork_handler(); - + /// + /// `threshold` is the maximum number of bytes produced by + /// [`BlockRngCore::generate`] before attempting reseeding. + fn new(threshold: u64, mut reseeder: Rsdr) -> Result { // Because generating more values than `i64::MAX` takes centuries on // current hardware, we just clamp to that value. // Also we set a threshold of 0, which indicates no limit, to that // value. let threshold = if threshold == 0 { - MAX - } else if threshold <= MAX as u64 { + i64::MAX + } else if threshold <= i64::MAX as u64 { threshold as i64 } else { - MAX + i64::MAX }; - ReseedingCore { - inner: rng, + let inner = R::try_from_rng(&mut reseeder)?; + + Ok(ReseedingCore { + inner, reseeder, - threshold: threshold as i64, - bytes_until_reseed: threshold as i64, - fork_counter: 0, - } + threshold, + bytes_until_reseed: threshold, + }) } /// Reseed the internal PRNG. - fn reseed(&mut self) -> Result<(), Error> { - R::from_rng(&mut self.reseeder).map(|result| { + fn reseed(&mut self) -> Result<(), Rsdr::Error> { + R::try_from_rng(&mut self.reseeder).map(|result| { self.bytes_until_reseed = self.threshold; self.inner = result }) } - fn is_forked(&self, global_fork_counter: usize) -> bool { - // In theory, on 32-bit platforms, it is possible for - // `global_fork_counter` to wrap around after ~4e9 forks. - // - // This check will detect a fork in the normal case where - // `fork_counter < global_fork_counter`, and also when the difference - // between both is greater than `isize::MAX` (wrapped around). - // - // It will still fail to detect a fork if there have been more than - // `isize::MAX` forks, without any reseed in between. Seems unlikely - // enough. - (self.fork_counter.wrapping_sub(global_fork_counter) as isize) < 0 - } - #[inline(never)] - fn reseed_and_generate( - &mut self, results: &mut ::Results, global_fork_counter: usize, - ) { - #![allow(clippy::if_same_then_else)] // false positive - if self.is_forked(global_fork_counter) { - info!("Fork detected, reseeding RNG"); - } else { - trace!("Reseeding RNG (periodic reseed)"); - } + fn reseed_and_generate(&mut self, results: &mut ::Results) { + trace!("Reseeding RNG (periodic reseed)"); - let num_bytes = results.as_ref().len() * size_of::<::Item>(); + let num_bytes = size_of_val(results.as_ref()); if let Err(e) = self.reseed() { warn!("Reseeding RNG failed: {}", e); let _ = e; } - self.fork_counter = global_fork_counter; self.bytes_until_reseed = self.threshold - num_bytes as i64; self.inner.generate(results); @@ -258,7 +231,7 @@ where impl Clone for ReseedingCore where R: BlockRngCore + SeedableRng + Clone, - Rsdr: RngCore + Clone, + Rsdr: TryRngCore + Clone, { fn clone(&self) -> ReseedingCore { ReseedingCore { @@ -266,79 +239,31 @@ where reseeder: self.reseeder.clone(), threshold: self.threshold, bytes_until_reseed: 0, // reseed clone on first use - fork_counter: self.fork_counter, } } } -impl CryptoRng for ReseedingCore +impl CryptoBlockRng for ReseedingCore where - R: BlockRngCore + SeedableRng + CryptoRng, - Rsdr: RngCore + CryptoRng, + R: BlockRngCore + SeedableRng + CryptoBlockRng, + Rsdr: TryCryptoRng, { } - -#[cfg(all(unix, not(target_os = "emscripten")))] -mod fork { - use core::sync::atomic::{AtomicUsize, Ordering}; - use std::sync::Once; - - // Fork protection - // - // We implement fork protection on Unix using `pthread_atfork`. - // When the process is forked, we increment `RESEEDING_RNG_FORK_COUNTER`. - // Every `ReseedingRng` stores the last known value of the static in - // `fork_counter`. If the cached `fork_counter` is less than - // `RESEEDING_RNG_FORK_COUNTER`, it is time to reseed this RNG. - // - // If reseeding fails, we don't deal with this by setting a delay, but just - // don't update `fork_counter`, so a reseed is attempted as soon as - // possible. - - static RESEEDING_RNG_FORK_COUNTER: AtomicUsize = AtomicUsize::new(0); - - pub fn get_fork_counter() -> usize { - RESEEDING_RNG_FORK_COUNTER.load(Ordering::Relaxed) - } - - extern "C" fn fork_handler() { - // Note: fetch_add is defined to wrap on overflow - // (which is what we want). - RESEEDING_RNG_FORK_COUNTER.fetch_add(1, Ordering::Relaxed); - } - - pub fn register_fork_handler() { - static REGISTER: Once = Once::new(); - REGISTER.call_once(|| unsafe { - libc::pthread_atfork(None, None, Some(fork_handler)); - }); - } -} - -#[cfg(not(all(unix, not(target_os = "emscripten"))))] -mod fork { - pub fn get_fork_counter() -> usize { - 0 - } - pub fn register_fork_handler() {} -} - - #[cfg(feature = "std_rng")] #[cfg(test)] mod test { - use super::ReseedingRng; use crate::rngs::mock::StepRng; use crate::rngs::std::Core; - use crate::{Rng, SeedableRng}; + use crate::Rng; + + use super::ReseedingRng; #[test] fn test_reseeding() { - let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); + let zero = StepRng::new(0, 0); let thresh = 1; // reseed every time the buffer is exhausted - let mut reseeding = ReseedingRng::new(rng, thresh, zero); + let mut reseeding = ReseedingRng::::new(thresh, zero).unwrap(); // RNG buffer size is [u32; 64] // Debug is only implemented up to length 32 so use two arrays @@ -354,17 +279,17 @@ mod test { } #[test] + #[allow(clippy::redundant_clone)] fn test_clone_reseeding() { - let mut zero = StepRng::new(0, 0); - let rng = Core::from_rng(&mut zero).unwrap(); - let mut rng1 = ReseedingRng::new(rng, 32 * 4, zero); + let zero = StepRng::new(0, 0); + let mut rng1 = ReseedingRng::::new(32 * 4, zero).unwrap(); - let first: u32 = rng1.gen(); + let first: u32 = rng1.random(); for _ in 0..10 { - let _ = rng1.gen::(); + let _ = rng1.random::(); } let mut rng2 = rng1.clone(); - assert_eq!(first, rng2.gen::()); + assert_eq!(first, rng2.random::()); } } diff --git a/src/rngs/small.rs b/src/rngs/small.rs index fb0e0d119b6..67e0d0544f4 100644 --- a/src/rngs/small.rs +++ b/src/rngs/small.rs @@ -8,110 +8,113 @@ //! A small fast RNG -use rand_core::{Error, RngCore, SeedableRng}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] +type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; #[cfg(target_pointer_width = "64")] type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus; -#[cfg(not(target_pointer_width = "64"))] -type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus; -/// A small-state, fast non-crypto PRNG +/// A small-state, fast, non-crypto, non-portable PRNG /// -/// `SmallRng` may be a good choice when a PRNG with small state, cheap -/// initialization, good statistical quality and good performance are required. -/// Note that depending on the application, [`StdRng`] may be faster on many -/// modern platforms while providing higher-quality randomness. Furthermore, -/// `SmallRng` is **not** a good choice when: -/// - Security against prediction is important. Use [`StdRng`] instead. -/// - Seeds with many zeros are provided. In such cases, it takes `SmallRng` -/// about 10 samples to produce 0 and 1 bits with equal probability. Either -/// provide seeds with an approximately equal number of 0 and 1 (for example -/// by using [`SeedableRng::from_entropy`] or [`SeedableRng::seed_from_u64`]), -/// or use [`StdRng`] instead. +/// This is the "standard small" RNG, a generator with the following properties: /// -/// The algorithm is deterministic but should not be considered reproducible -/// due to dependence on platform and possible replacement in future -/// library versions. For a reproducible generator, use a named PRNG from an -/// external crate, e.g. [rand_xoshiro] or [rand_chacha]. -/// Refer also to [The Book](https://rust-random.github.io/book/guide-rngs.html). +/// - Non-[portable]: any future library version may replace the algorithm +/// and results may be platform-dependent. +/// (For a small portable generator, use the [rand_pcg] or [rand_xoshiro] crate.) +/// - Non-cryptographic: output is easy to predict (insecure) +/// - [Quality]: statistically good quality +/// - Fast: the RNG is fast for both bulk generation and single values, with +/// consistent cost of method calls +/// - Fast initialization +/// - Small state: little memory usage (current state size is 16-32 bytes +/// depending on platform) /// -/// The PRNG algorithm in `SmallRng` is chosen to be efficient on the current -/// platform, without consideration for cryptography or security. The size of -/// its state is much smaller than [`StdRng`]. The current algorithm is +/// The current algorithm is /// `Xoshiro256PlusPlus` on 64-bit platforms and `Xoshiro128PlusPlus` on 32-bit /// platforms. Both are also implemented by the [rand_xoshiro] crate. /// -/// # Examples +/// ## Seeding (construction) /// -/// Initializing `SmallRng` with a random seed can be done using [`SeedableRng::from_entropy`]: +/// This generator implements the [`SeedableRng`] trait. All methods are +/// suitable for seeding, but note that, even with a fixed seed, output is not +/// [portable]. Some suggestions: /// -/// ``` -/// use rand::{Rng, SeedableRng}; -/// use rand::rngs::SmallRng; +/// 1. To automatically seed with a unique seed, use [`SeedableRng::from_rng`]: +/// ``` +/// use rand::SeedableRng; +/// use rand::rngs::SmallRng; +/// let rng = SmallRng::from_rng(&mut rand::rng()); +/// # let _: SmallRng = rng; +/// ``` +/// or [`SeedableRng::from_os_rng`]: +/// ``` +/// # use rand::SeedableRng; +/// # use rand::rngs::SmallRng; +/// let rng = SmallRng::from_os_rng(); +/// # let _: SmallRng = rng; +/// ``` +/// 2. To use a deterministic integral seed, use `seed_from_u64`. This uses a +/// hash function internally to yield a (typically) good seed from any +/// input. +/// ``` +/// # use rand::{SeedableRng, rngs::SmallRng}; +/// let rng = SmallRng::seed_from_u64(1); +/// # let _: SmallRng = rng; +/// ``` +/// 3. To seed deterministically from text or other input, use [`rand_seeder`]. /// -/// // Create small, cheap to initialize and fast RNG with a random seed. -/// // The randomness is supplied by the operating system. -/// let mut small_rng = SmallRng::from_entropy(); -/// # let v: u32 = small_rng.gen(); -/// ``` +/// See also [Seeding RNGs] in the book. /// -/// When initializing a lot of `SmallRng`'s, using [`thread_rng`] can be more -/// efficient: +/// ## Generation /// -/// ``` -/// use rand::{SeedableRng, thread_rng}; -/// use rand::rngs::SmallRng; -/// -/// // Create a big, expensive to initialize and slower, but unpredictable RNG. -/// // This is cached and done only once per thread. -/// let mut thread_rng = thread_rng(); -/// // Create small, cheap to initialize and fast RNGs with random seeds. -/// // One can generally assume this won't fail. -/// let rngs: Vec = (0..10) -/// .map(|_| SmallRng::from_rng(&mut thread_rng).unwrap()) -/// .collect(); -/// ``` +/// The generators implements [`RngCore`] and thus also [`Rng`][crate::Rng]. +/// See also the [Random Values] chapter in the book. /// +/// [portable]: https://rust-random.github.io/book/crate-reprod.html +/// [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +/// [Random Values]: https://rust-random.github.io/book/guide-values.html +/// [Quality]: https://rust-random.github.io/book/guide-rngs.html#quality /// [`StdRng`]: crate::rngs::StdRng -/// [`thread_rng`]: crate::thread_rng -/// [rand_chacha]: https://crates.io/crates/rand_chacha +/// [rand_pcg]: https://crates.io/crates/rand_pcg /// [rand_xoshiro]: https://crates.io/crates/rand_xoshiro -#[cfg_attr(doc_cfg, doc(cfg(feature = "small_rng")))] +/// [`rand_chacha::ChaCha8Rng`]: https://docs.rs/rand_chacha/latest/rand_chacha/struct.ChaCha8Rng.html +/// [`rand_seeder`]: https://docs.rs/rand_seeder/latest/rand_seeder/ #[derive(Clone, Debug, PartialEq, Eq)] pub struct SmallRng(Rng); -impl RngCore for SmallRng { - #[inline(always)] - fn next_u32(&mut self) -> u32 { - self.0.next_u32() - } +impl SeedableRng for SmallRng { + // Fix to 256 bits. Changing this is a breaking change! + type Seed = [u8; 32]; #[inline(always)] - fn next_u64(&mut self) -> u64 { - self.0.next_u64() + fn from_seed(seed: Self::Seed) -> Self { + // This is for compatibility with 32-bit platforms where Rng::Seed has a different seed size + // With MSRV >= 1.77: let seed = *seed.first_chunk().unwrap() + const LEN: usize = core::mem::size_of::<::Seed>(); + let seed = (&seed[..LEN]).try_into().unwrap(); + SmallRng(Rng::from_seed(seed)) } #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); + fn seed_from_u64(state: u64) -> Self { + SmallRng(Rng::seed_from_u64(state)) } +} +impl RngCore for SmallRng { #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) + fn next_u32(&mut self) -> u32 { + self.0.next_u32() } -} - -impl SeedableRng for SmallRng { - type Seed = ::Seed; #[inline(always)] - fn from_seed(seed: Self::Seed) -> Self { - SmallRng(Rng::from_seed(seed)) + fn next_u64(&mut self) -> u64 { + self.0.next_u64() } #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(SmallRng) + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest) } } diff --git a/src/rngs/std.rs b/src/rngs/std.rs index 80f84336980..6e1658e7453 100644 --- a/src/rngs/std.rs +++ b/src/rngs/std.rs @@ -8,32 +8,64 @@ //! The standard RNG -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use rand_core::{CryptoRng, RngCore, SeedableRng}; -#[cfg(all(any(test, feature = "std"), not(target_os = "emscripten")))] +#[cfg(any(test, feature = "os_rng"))] pub(crate) use rand_chacha::ChaCha12Core as Core; -#[cfg(all(any(test, feature = "std"), target_os = "emscripten"))] -pub(crate) use rand_hc::Hc128Core as Core; -#[cfg(not(target_os = "emscripten"))] use rand_chacha::ChaCha12Rng as Rng; -#[cfg(target_os = "emscripten")] use rand_hc::Hc128Rng as Rng; +use rand_chacha::ChaCha12Rng as Rng; -/// The standard RNG. The PRNG algorithm in `StdRng` is chosen to be efficient -/// on the current platform, to be statistically strong and unpredictable -/// (meaning a cryptographically secure PRNG). +/// A strong, fast (amortized), non-portable RNG +/// +/// This is the "standard" RNG, a generator with the following properties: +/// +/// - Non-[portable]: any future library version may replace the algorithm +/// and results may be platform-dependent. +/// (For a portable version, use the [rand_chacha] crate directly.) +/// - [CSPRNG]: statistically good quality of randomness and [unpredictable] +/// - Fast ([amortized](https://en.wikipedia.org/wiki/Amortized_analysis)): +/// the RNG is fast for bulk generation, but the cost of method calls is not +/// consistent due to usage of an output buffer. /// /// The current algorithm used is the ChaCha block cipher with 12 rounds. Please -/// see this relevant [rand issue] for the discussion. This may change as new +/// see this relevant [rand issue] for the discussion. This may change as new /// evidence of cipher security and performance becomes available. /// -/// The algorithm is deterministic but should not be considered reproducible -/// due to dependence on configuration and possible replacement in future -/// library versions. For a secure reproducible generator, we recommend use of -/// the [rand_chacha] crate directly. +/// ## Seeding (construction) +/// +/// This generator implements the [`SeedableRng`] trait. Any method may be used, +/// but note that `seed_from_u64` is not suitable for usage where security is +/// important. Also note that, even with a fixed seed, output is not [portable]. +/// +/// Using a fresh seed **direct from the OS** is the most secure option: +/// ``` +/// # use rand::{SeedableRng, rngs::StdRng}; +/// let rng = StdRng::from_os_rng(); +/// # let _: StdRng = rng; +/// ``` +/// +/// Seeding via [`rand::rng()`](crate::rng()) may be faster: +/// ``` +/// # use rand::{SeedableRng, rngs::StdRng}; +/// let rng = StdRng::from_rng(&mut rand::rng()); +/// # let _: StdRng = rng; +/// ``` +/// +/// Any [`SeedableRng`] method may be used, but note that `seed_from_u64` is not +/// suitable where security is required. See also [Seeding RNGs] in the book. +/// +/// ## Generation +/// +/// The generators implements [`RngCore`] and thus also [`Rng`][crate::Rng]. +/// See also the [Random Values] chapter in the book. /// +/// [portable]: https://rust-random.github.io/book/crate-reprod.html +/// [Seeding RNGs]: https://rust-random.github.io/book/guide-seeding.html +/// [unpredictable]: https://rust-random.github.io/book/guide-rngs.html#security +/// [Random Values]: https://rust-random.github.io/book/guide-values.html +/// [CSPRNG]: https://rust-random.github.io/book/guide-gen.html#cryptographically-secure-pseudo-random-number-generator /// [rand_chacha]: https://crates.io/crates/rand_chacha /// [rand issue]: https://github.com/rust-random/rand/issues/932 -#[cfg_attr(doc_cfg, doc(cfg(feature = "std_rng")))] #[derive(Clone, Debug, PartialEq, Eq)] pub struct StdRng(Rng); @@ -49,33 +81,23 @@ impl RngCore for StdRng { } #[inline(always)] - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.0.fill_bytes(dest); - } - - #[inline(always)] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.0.try_fill_bytes(dest) + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.fill_bytes(dst) } } impl SeedableRng for StdRng { - type Seed = ::Seed; + // Fix to 256 bits. Changing this is a breaking change! + type Seed = [u8; 32]; #[inline(always)] fn from_seed(seed: Self::Seed) -> Self { StdRng(Rng::from_seed(seed)) } - - #[inline(always)] - fn from_rng(rng: R) -> Result { - Rng::from_rng(rng).map(StdRng) - } } impl CryptoRng for StdRng {} - #[cfg(test)] mod test { use crate::rngs::StdRng; @@ -94,7 +116,7 @@ mod test { let mut rng0 = StdRng::from_seed(seed); let x0 = rng0.next_u64(); - let mut rng1 = StdRng::from_rng(rng0).unwrap(); + let mut rng1 = StdRng::from_rng(&mut rng0); let x1 = rng1.next_u64(); assert_eq!([x0, x1], target); diff --git a/src/rngs/thread.rs b/src/rngs/thread.rs index 552851f1ec3..7e5203214a4 100644 --- a/src/rngs/thread.rs +++ b/src/rngs/thread.rs @@ -9,13 +9,15 @@ //! Thread-local random number generator use core::cell::UnsafeCell; +use std::fmt; use std::rc::Rc; use std::thread_local; +use rand_core::{CryptoRng, RngCore}; + use super::std::Core; -use crate::rngs::adapter::ReseedingRng; use crate::rngs::OsRng; -use crate::{CryptoRng, Error, RngCore, SeedableRng}; +use crate::rngs::ReseedingRng; // Rationale for using `UnsafeCell` in `ThreadRng`: // @@ -31,67 +33,135 @@ use crate::{CryptoRng, Error, RngCore, SeedableRng}; // `ThreadRng` internally, which is nonsensical anyway. We should also never run // `ThreadRng` in destructors of its implementation, which is also nonsensical. - // Number of generated bytes after which to reseed `ThreadRng`. -// According to benchmarks, reseeding has a noticable impact with thresholds +// According to benchmarks, reseeding has a noticeable impact with thresholds // of 32 kB and less. We choose 64 kB to avoid significant overhead. const THREAD_RNG_RESEED_THRESHOLD: u64 = 1024 * 64; /// A reference to the thread-local generator /// -/// An instance can be obtained via [`thread_rng`] or via `ThreadRng::default()`. -/// This handle is safe to use everywhere (including thread-local destructors) -/// but cannot be passed between threads (is not `Send` or `Sync`). +/// This type is a reference to a lazily-initialized thread-local generator. +/// An instance can be obtained via [`rand::rng()`][crate::rng()] or via +/// [`ThreadRng::default()`]. +/// The handle cannot be passed between threads (is not `Send` or `Sync`). +/// +/// # Security +/// +/// Security must be considered relative to a threat model and validation +/// requirements. The Rand project can provide no guarantee of fitness for +/// purpose. The design criteria for `ThreadRng` are as follows: +/// +/// - Automatic seeding via [`OsRng`] and periodically thereafter (see +/// ([`ReseedingRng`] documentation). Limitation: there is no automatic +/// reseeding on process fork (see [below](#fork)). +/// - A rigorusly analyzed, unpredictable (cryptographic) pseudo-random generator +/// (see [the book on security](https://rust-random.github.io/book/guide-rngs.html#security)). +/// The currently selected algorithm is ChaCha (12-rounds). +/// See also [`StdRng`] documentation. +/// - Not to leak internal state through [`Debug`] or serialization +/// implementations. +/// - No further protections exist to in-memory state. In particular, the +/// implementation is not required to zero memory on exit (of the process or +/// thread). (This may change in the future.) +/// - Be fast enough for general-purpose usage. Note in particular that +/// `ThreadRng` is designed to be a "fast, reasonably secure generator" +/// (where "reasonably secure" implies the above criteria). /// -/// `ThreadRng` uses the same PRNG as [`StdRng`] for security and performance -/// and is automatically seeded from [`OsRng`]. +/// We leave it to the user to determine whether this generator meets their +/// security requirements. For an alternative, see [`OsRng`]. /// -/// Unlike `StdRng`, `ThreadRng` uses the [`ReseedingRng`] wrapper to reseed -/// the PRNG from fresh entropy every 64 kiB of random data as well as after a -/// fork on Unix (though not quite immediately; see documentation of -/// [`ReseedingRng`]). -/// Note that the reseeding is done as an extra precaution against side-channel -/// attacks and mis-use (e.g. if somehow weak entropy were supplied initially). -/// The PRNG algorithms used are assumed to be secure. +/// # Fork /// -/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng +/// `ThreadRng` is not automatically reseeded on fork. It is recommended to +/// explicitly call [`ThreadRng::reseed`] immediately after a fork, for example: +/// ```ignore +/// fn do_fork() { +/// let pid = unsafe { libc::fork() }; +/// if pid == 0 { +/// // Reseed ThreadRng in child processes: +/// rand::rng().reseed(); +/// } +/// } +/// ``` +/// +/// Methods on `ThreadRng` are not reentrant-safe and thus should not be called +/// from an interrupt (e.g. a fork handler) unless it can be guaranteed that no +/// other method on the same `ThreadRng` is currently executing. +/// +/// [`ReseedingRng`]: crate::rngs::ReseedingRng /// [`StdRng`]: crate::rngs::StdRng -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct ThreadRng { - // Rc is explictly !Send and !Sync + // Rc is explicitly !Send and !Sync rng: Rc>>, } +impl ThreadRng { + /// Immediately reseed the generator + /// + /// This discards any remaining random data in the cache. + pub fn reseed(&mut self) -> Result<(), rand_core::OsError> { + // SAFETY: We must make sure to stop using `rng` before anyone else + // creates another mutable reference + let rng = unsafe { &mut *self.rng.get() }; + rng.reseed() + } +} + +/// Debug implementation does not leak internal state +impl fmt::Debug for ThreadRng { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "ThreadRng {{ .. }}") + } +} + thread_local!( - // We require Rc<..> to avoid premature freeing when thread_rng is used + // We require Rc<..> to avoid premature freeing when ThreadRng is used // within thread-local destructors. See #968. static THREAD_RNG_KEY: Rc>> = { - let r = Core::from_rng(OsRng).unwrap_or_else(|err| - panic!("could not initialize thread_rng: {}", err)); - let rng = ReseedingRng::new(r, - THREAD_RNG_RESEED_THRESHOLD, - OsRng); + let rng = ReseedingRng::new(THREAD_RNG_RESEED_THRESHOLD, + OsRng).unwrap_or_else(|err| + panic!("could not initialize ThreadRng: {}", err)); Rc::new(UnsafeCell::new(rng)) } ); -/// Retrieve the lazily-initialized thread-local random number generator, -/// seeded by the system. Intended to be used in method chaining style, -/// e.g. `thread_rng().gen::()`, or cached locally, e.g. -/// `let mut rng = thread_rng();`. Invoked by the `Default` trait, making -/// `ThreadRng::default()` equivalent. +/// Access a fast, pre-initialized generator +/// +/// This is a handle to the local [`ThreadRng`]. +/// +/// See also [`crate::rngs`] for alternatives. /// -/// For more information see [`ThreadRng`]. -#[cfg_attr(doc_cfg, doc(cfg(all(feature = "std", feature = "std_rng"))))] -pub fn thread_rng() -> ThreadRng { +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// +/// # fn main() { +/// +/// let mut numbers = [1, 2, 3, 4, 5]; +/// numbers.shuffle(&mut rand::rng()); +/// println!("Numbers: {numbers:?}"); +/// +/// // Using a local binding avoids an initialization-check on each usage: +/// let mut rng = rand::rng(); +/// +/// println!("True or false: {}", rng.random::()); +/// println!("A simulated die roll: {}", rng.random_range(1..=6)); +/// # } +/// ``` +/// +/// # Security +/// +/// Refer to [`ThreadRng#Security`]. +pub fn rng() -> ThreadRng { let rng = THREAD_RNG_KEY.with(|t| t.clone()); ThreadRng { rng } } impl Default for ThreadRng { fn default() -> ThreadRng { - crate::prelude::thread_rng() + rng() } } @@ -112,31 +182,31 @@ impl RngCore for ThreadRng { rng.next_u64() } + #[inline(always)] fn fill_bytes(&mut self, dest: &mut [u8]) { // SAFETY: We must make sure to stop using `rng` before anyone else // creates another mutable reference let rng = unsafe { &mut *self.rng.get() }; rng.fill_bytes(dest) } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - // SAFETY: We must make sure to stop using `rng` before anyone else - // creates another mutable reference - let rng = unsafe { &mut *self.rng.get() }; - rng.try_fill_bytes(dest) - } } impl CryptoRng for ThreadRng {} - #[cfg(test)] mod test { #[test] fn test_thread_rng() { use crate::Rng; - let mut r = crate::thread_rng(); - r.gen::(); - assert_eq!(r.gen_range(0..1), 0); + let mut r = crate::rng(); + r.random::(); + assert_eq!(r.random_range(0..1), 0); + } + + #[test] + fn test_debug_output() { + // We don't care about the exact output here, but it must not include + // private CSPRNG state or the cache stored by BlockRng! + assert_eq!(std::format!("{:?}", crate::rng()), "ThreadRng { .. }"); } } diff --git a/src/rngs/xoshiro128plusplus.rs b/src/rngs/xoshiro128plusplus.rs index ece98fafd6a..69fe7ca9202 100644 --- a/src/rngs/xoshiro128plusplus.rs +++ b/src/rngs/xoshiro128plusplus.rs @@ -6,10 +6,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; -use rand_core::impls::{next_u64_via_u32, fill_bytes_via_next}; +use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32}; use rand_core::le::read_u32_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// A xoshiro128++ random number generator. /// @@ -20,7 +21,7 @@ use rand_core::{SeedableRng, RngCore, Error}; /// reference source code](http://xoshiro.di.unimi.it/xoshiro128plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Xoshiro128PlusPlus { s: [u32; 4], } @@ -32,36 +33,43 @@ impl SeedableRng for Xoshiro128PlusPlus { /// mapped to a different seed. #[inline] fn from_seed(seed: [u8; 16]) -> Xoshiro128PlusPlus { - if seed.iter().all(|&x| x == 0) { - return Self::seed_from_u64(0); - } let mut state = [0; 4]; read_u32_into(&seed, &mut state); + // Check for zero on aligned integers for better code generation. + // Furtermore, seed_from_u64(0) will expand to a constant when optimized. + if state.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } Xoshiro128PlusPlus { s: state } } /// Create a new `Xoshiro128PlusPlus` from a `u64` seed. /// /// This uses the SplitMix64 generator internally. + #[inline] fn seed_from_u64(mut state: u64) -> Self { const PHI: u64 = 0x9e3779b97f4a7c15; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(8) { + let mut s = [0; 4]; + for i in s.chunks_exact_mut(2) { state = state.wrapping_add(PHI); let mut z = state; z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); z = z ^ (z >> 31); - chunk.copy_from_slice(&z.to_le_bytes()); + i[0] = z as u32; + i[1] = (z >> 32) as u32; } - Self::from_seed(seed) + // By using a non-zero PHI we are guaranteed to generate a non-zero state + // Thus preventing a recursion between from_seed and seed_from_u64. + debug_assert_ne!(s, [0; 4]); + Xoshiro128PlusPlus { s } } } impl RngCore for Xoshiro128PlusPlus { #[inline] fn next_u32(&mut self) -> u32 { - let result_starstar = self.s[0] + let res = self.s[0] .wrapping_add(self.s[3]) .rotate_left(7) .wrapping_add(self.s[0]); @@ -77,7 +85,7 @@ impl RngCore for Xoshiro128PlusPlus { self.s[3] = self.s[3].rotate_left(11); - result_starstar + res } #[inline] @@ -86,30 +94,39 @@ impl RngCore for Xoshiro128PlusPlus { } #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_via_next(self, dest); - } - - #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { - use super::*; + use super::Xoshiro128PlusPlus; + use rand_core::{RngCore, SeedableRng}; #[test] fn reference() { - let mut rng = Xoshiro128PlusPlus::from_seed( - [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); + let mut rng = + Xoshiro128PlusPlus::from_seed([1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro128plusplus.c let expected = [ - 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, - 3867114732, 1355841295, 495546011, 621204420, + 641, 1573767, 3222811527, 3517856514, 836907274, 4247214768, 3867114732, 1355841295, + 495546011, 621204420, + ]; + for &e in &expected { + assert_eq!(rng.next_u32(), e); + } + } + + #[test] + fn stable_seed_from_u64() { + // We don't guarantee value-stability for SmallRng but this + // could influence keeping stability whenever possible (e.g. after optimizations). + let mut rng = Xoshiro128PlusPlus::seed_from_u64(0); + let expected = [ + 1179900579, 1938959192, 3089844957, 3657088315, 1015453891, 479942911, 3433842246, + 669252886, 3985671746, 2737205563, ]; for &e in &expected { assert_eq!(rng.next_u32(), e); diff --git a/src/rngs/xoshiro256plusplus.rs b/src/rngs/xoshiro256plusplus.rs index cd373c30669..7b39c6109a7 100644 --- a/src/rngs/xoshiro256plusplus.rs +++ b/src/rngs/xoshiro256plusplus.rs @@ -6,21 +6,22 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -#[cfg(feature="serde1")] use serde::{Serialize, Deserialize}; use rand_core::impls::fill_bytes_via_next; use rand_core::le::read_u64_into; -use rand_core::{SeedableRng, RngCore, Error}; +use rand_core::{RngCore, SeedableRng}; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; -/// A xoshiro256** random number generator. +/// A xoshiro256++ random number generator. /// -/// The xoshiro256** algorithm is not suitable for cryptographic purposes, but +/// The xoshiro256++ algorithm is not suitable for cryptographic purposes, but /// is very fast and has excellent statistical properties. /// /// The algorithm used here is translated from [the `xoshiro256plusplus.c` /// reference source code](http://xoshiro.di.unimi.it/xoshiro256plusplus.c) by /// David Blackman and Sebastiano Vigna. #[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature="serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Xoshiro256PlusPlus { s: [u64; 4], } @@ -32,29 +33,35 @@ impl SeedableRng for Xoshiro256PlusPlus { /// mapped to a different seed. #[inline] fn from_seed(seed: [u8; 32]) -> Xoshiro256PlusPlus { - if seed.iter().all(|&x| x == 0) { - return Self::seed_from_u64(0); - } let mut state = [0; 4]; read_u64_into(&seed, &mut state); + // Check for zero on aligned integers for better code generation. + // Furtermore, seed_from_u64(0) will expand to a constant when optimized. + if state.iter().all(|&x| x == 0) { + return Self::seed_from_u64(0); + } Xoshiro256PlusPlus { s: state } } /// Create a new `Xoshiro256PlusPlus` from a `u64` seed. /// /// This uses the SplitMix64 generator internally. + #[inline] fn seed_from_u64(mut state: u64) -> Self { const PHI: u64 = 0x9e3779b97f4a7c15; - let mut seed = Self::Seed::default(); - for chunk in seed.as_mut().chunks_mut(8) { + let mut s = [0; 4]; + for i in s.iter_mut() { state = state.wrapping_add(PHI); let mut z = state; z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); z = z ^ (z >> 31); - chunk.copy_from_slice(&z.to_le_bytes()); + *i = z; } - Self::from_seed(seed) + // By using a non-zero PHI we are guaranteed to generate a non-zero state + // Thus preventing a recursion between from_seed and seed_from_u64. + debug_assert_ne!(s, [0; 4]); + Xoshiro256PlusPlus { s } } } @@ -63,12 +70,13 @@ impl RngCore for Xoshiro256PlusPlus { fn next_u32(&mut self) -> u32 { // The lowest bits have some linear dependencies, so we use the // upper bits instead. - (self.next_u64() >> 32) as u32 + let val = self.next_u64(); + (val >> 32) as u32 } #[inline] fn next_u64(&mut self) -> u64 { - let result_plusplus = self.s[0] + let res = self.s[0] .wrapping_add(self.s[3]) .rotate_left(23) .wrapping_add(self.s[0]); @@ -84,36 +92,61 @@ impl RngCore for Xoshiro256PlusPlus { self.s[3] = self.s[3].rotate_left(45); - result_plusplus - } - - #[inline] - fn fill_bytes(&mut self, dest: &mut [u8]) { - fill_bytes_via_next(self, dest); + res } #[inline] - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> { - self.fill_bytes(dest); - Ok(()) + fn fill_bytes(&mut self, dst: &mut [u8]) { + fill_bytes_via_next(self, dst) } } #[cfg(test)] mod tests { - use super::*; + use super::Xoshiro256PlusPlus; + use rand_core::{RngCore, SeedableRng}; #[test] fn reference() { - let mut rng = Xoshiro256PlusPlus::from_seed( - [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, - 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0]); + let mut rng = Xoshiro256PlusPlus::from_seed([ + 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, + 0, 0, 0, + ]); // These values were produced with the reference implementation: // http://xoshiro.di.unimi.it/xoshiro256plusplus.c let expected = [ - 41943041, 58720359, 3588806011781223, 3591011842654386, - 9228616714210784205, 9973669472204895162, 14011001112246962877, - 12406186145184390807, 15849039046786891736, 10450023813501588000, + 41943041, + 58720359, + 3588806011781223, + 3591011842654386, + 9228616714210784205, + 9973669472204895162, + 14011001112246962877, + 12406186145184390807, + 15849039046786891736, + 10450023813501588000, + ]; + for &e in &expected { + assert_eq!(rng.next_u64(), e); + } + } + + #[test] + fn stable_seed_from_u64() { + // We don't guarantee value-stability for SmallRng but this + // could influence keeping stability whenever possible (e.g. after optimizations). + let mut rng = Xoshiro256PlusPlus::seed_from_u64(0); + let expected = [ + 5987356902031041503, + 7051070477665621255, + 6633766593972829180, + 211316841551650330, + 9136120204379184874, + 379361710973160858, + 15813423377499357806, + 15596884590815070553, + 5439680534584881407, + 1369371744833522710, ]; for &e in &expected { assert_eq!(rng.next_u64(), e); diff --git a/src/seq/coin_flipper.rs b/src/seq/coin_flipper.rs new file mode 100644 index 00000000000..7e8f53116ce --- /dev/null +++ b/src/seq/coin_flipper.rs @@ -0,0 +1,160 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::RngCore; + +pub(crate) struct CoinFlipper { + pub rng: R, + chunk: u32, // TODO(opt): this should depend on RNG word size + chunk_remaining: u32, +} + +impl CoinFlipper { + pub fn new(rng: R) -> Self { + Self { + rng, + chunk: 0, + chunk_remaining: 0, + } + } + + #[inline] + /// Returns true with a probability of 1 / d + /// Uses an expected two bits of randomness + /// Panics if d == 0 + pub fn random_ratio_one_over(&mut self, d: usize) -> bool { + debug_assert_ne!(d, 0); + // This uses the same logic as `random_ratio` but is optimized for the case that + // the starting numerator is one (which it always is for `Sequence::Choose()`) + + // In this case (but not `random_ratio`), this way of calculating c is always accurate + let c = (usize::BITS - 1 - d.leading_zeros()).min(32); + + if self.flip_c_heads(c) { + let numerator = 1 << c; + self.random_ratio(numerator, d) + } else { + false + } + } + + #[inline] + /// Returns true with a probability of n / d + /// Uses an expected two bits of randomness + fn random_ratio(&mut self, mut n: usize, d: usize) -> bool { + // Explanation: + // We are trying to return true with a probability of n / d + // If n >= d, we can just return true + // Otherwise there are two possibilities 2n < d and 2n >= d + // In either case we flip a coin. + // If 2n < d + // If it comes up tails, return false + // If it comes up heads, double n and start again + // This is fair because (0.5 * 0) + (0.5 * 2n / d) = n / d and 2n is less than d + // (if 2n was greater than d we would effectively round it down to 1 + // by returning true) + // If 2n >= d + // If it comes up tails, set n to 2n - d and start again + // If it comes up heads, return true + // This is fair because (0.5 * 1) + (0.5 * (2n - d) / d) = n / d + // Note that if 2n = d and the coin comes up tails, n will be set to 0 + // before restarting which is equivalent to returning false. + + // As a performance optimization we can flip multiple coins at once + // This is efficient because we can use the `lzcnt` intrinsic + // We can check up to 32 flips at once but we only receive one bit of information + // - all heads or at least one tail. + + // Let c be the number of coins to flip. 1 <= c <= 32 + // If 2n < d, n * 2^c < d + // If the result is all heads, then set n to n * 2^c + // If there was at least one tail, return false + // If 2n >= d, the order of results matters so we flip one coin at a time so c = 1 + // Ideally, c will be as high as possible within these constraints + + while n < d { + // Find a good value for c by counting leading zeros + // This will either give the highest possible c, or 1 less than that + let c = n + .leading_zeros() + .saturating_sub(d.leading_zeros() + 1) + .clamp(1, 32); + + if self.flip_c_heads(c) { + // All heads + // Set n to n * 2^c + // If 2n >= d, the while loop will exit and we will return `true` + // If n * 2^c > `usize::MAX` we always return `true` anyway + n = n.saturating_mul(2_usize.pow(c)); + } else { + // At least one tail + if c == 1 { + // Calculate 2n - d. + // We need to use wrapping as 2n might be greater than `usize::MAX` + let next_n = n.wrapping_add(n).wrapping_sub(d); + if next_n == 0 || next_n > n { + // This will happen if 2n < d + return false; + } + n = next_n; + } else { + // c > 1 so 2n < d so we can return false + return false; + } + } + } + true + } + + /// If the next `c` bits of randomness all represent heads, consume them, return true + /// Otherwise return false and consume the number of heads plus one. + /// Generates new bits of randomness when necessary (in 32 bit chunks) + /// Has a 1 in 2 to the `c` chance of returning true + /// `c` must be less than or equal to 32 + fn flip_c_heads(&mut self, mut c: u32) -> bool { + debug_assert!(c <= 32); + // Note that zeros on the left of the chunk represent heads. + // It needs to be this way round because zeros are filled in when left shifting + loop { + let zeros = self.chunk.leading_zeros(); + + if zeros < c { + // The happy path - we found a 1 and can return false + // Note that because a 1 bit was detected, + // We cannot have run out of random bits so we don't need to check + + // First consume all of the bits read + // Using shl seems to give worse performance for size-hinted iterators + self.chunk = self.chunk.wrapping_shl(zeros + 1); + + self.chunk_remaining = self.chunk_remaining.saturating_sub(zeros + 1); + return false; + } else { + // The number of zeros is larger than `c` + // There are two possibilities + if let Some(new_remaining) = self.chunk_remaining.checked_sub(c) { + // Those zeroes were all part of our random chunk, + // throw away `c` bits of randomness and return true + self.chunk_remaining = new_remaining; + self.chunk <<= c; + return true; + } else { + // Some of those zeroes were part of the random chunk + // and some were part of the space behind it + // We need to take into account only the zeroes that were random + c -= self.chunk_remaining; + + // Generate a new chunk + self.chunk = self.rng.next_u32(); + self.chunk_remaining = 32; + // Go back to start of loop + } + } + } + } +} diff --git a/src/seq/increasing_uniform.rs b/src/seq/increasing_uniform.rs new file mode 100644 index 00000000000..10dd48a652a --- /dev/null +++ b/src/seq/increasing_uniform.rs @@ -0,0 +1,108 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use crate::{Rng, RngCore}; + +/// Similar to a Uniform distribution, +/// but after returning a number in the range [0,n], n is increased by 1. +pub(crate) struct IncreasingUniform { + pub rng: R, + n: u32, + // Chunk is a random number in [0, (n + 1) * (n + 2) *..* (n + chunk_remaining) ) + chunk: u32, + chunk_remaining: u8, +} + +impl IncreasingUniform { + /// Create a dice roller. + /// The next item returned will be a random number in the range [0,n] + pub fn new(rng: R, n: u32) -> Self { + // If n = 0, the first number returned will always be 0 + // so we don't need to generate a random number + let chunk_remaining = if n == 0 { 1 } else { 0 }; + Self { + rng, + n, + chunk: 0, + chunk_remaining, + } + } + + /// Returns a number in [0,n] and increments n by 1. + /// Generates new random bits as needed + /// Panics if `n >= u32::MAX` + #[inline] + pub fn next_index(&mut self) -> usize { + let next_n = self.n + 1; + + // There's room for further optimisation here: + // random_range uses rejection sampling (or other method; see #1196) to avoid bias. + // When the initial sample is biased for range 0..bound + // it may still be viable to use for a smaller bound + // (especially if small biases are considered acceptable). + + let next_chunk_remaining = self.chunk_remaining.checked_sub(1).unwrap_or_else(|| { + // If the chunk is empty, generate a new chunk + let (bound, remaining) = calculate_bound_u32(next_n); + // bound = (n + 1) * (n + 2) *..* (n + remaining) + self.chunk = self.rng.random_range(..bound); + // Chunk is a random number in + // [0, (n + 1) * (n + 2) *..* (n + remaining) ) + + remaining - 1 + }); + + let result = if next_chunk_remaining == 0 { + // `chunk` is a random number in the range [0..n+1) + // Because `chunk_remaining` is about to be set to zero + // we do not need to clear the chunk here + self.chunk as usize + } else { + // `chunk` is a random number in a range that is a multiple of n+1 + // so r will be a random number in [0..n+1) + let r = self.chunk % next_n; + self.chunk /= next_n; + r as usize + }; + + self.chunk_remaining = next_chunk_remaining; + self.n = next_n; + result + } +} + +#[inline] +/// Calculates `bound`, `count` such that bound (m)*(m+1)*..*(m + remaining - 1) +fn calculate_bound_u32(m: u32) -> (u32, u8) { + debug_assert!(m > 0); + #[inline] + const fn inner(m: u32) -> (u32, u8) { + let mut product = m; + let mut current = m + 1; + + loop { + if let Some(p) = u32::checked_mul(product, current) { + product = p; + current += 1; + } else { + // Count has a maximum value of 13 for when min is 1 or 2 + let count = (current - m) as u8; + return (product, count); + } + } + } + + const RESULT2: (u32, u8) = inner(2); + if m == 2 { + // Making this value a constant instead of recalculating it + // gives a significant (~50%) performance boost for small shuffles + return RESULT2; + } + + inner(m) +} diff --git a/src/seq/index.rs b/src/seq/index.rs index c09e5804229..852bdac76c4 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -7,50 +7,56 @@ // except according to those terms. //! Low-level API for sampling indices - -#[cfg(feature = "alloc")] use core::slice; - -#[cfg(feature = "alloc")] use alloc::vec::{self, Vec}; +use alloc::vec::{self, Vec}; +use core::slice; +use core::{hash::Hash, ops::AddAssign}; // BTreeMap is not as fast in tests, but better than nothing. -#[cfg(all(feature = "alloc", not(feature = "std")))] -use alloc::collections::BTreeSet; -#[cfg(feature = "std")] use std::collections::HashSet; - -#[cfg(feature = "alloc")] -use crate::distributions::{uniform::SampleUniform, Distribution, Uniform, WeightedError}; +#[cfg(feature = "std")] +use super::WeightError; +use crate::distr::uniform::SampleUniform; +use crate::distr::{Distribution, Uniform}; use crate::Rng; +#[cfg(not(feature = "std"))] +use alloc::collections::BTreeSet; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +#[cfg(feature = "std")] +use std::collections::HashSet; -#[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))] +compile_error!("unsupported pointer width"); /// A vector of indices. /// /// Multiple internal representations are possible. #[derive(Clone, Debug)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum IndexVec { #[doc(hidden)] U32(Vec), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(Vec), + U64(Vec), } impl IndexVec { /// Returns the number of indices #[inline] pub fn len(&self) -> usize { - match *self { - IndexVec::U32(ref v) => v.len(), - IndexVec::USize(ref v) => v.len(), + match self { + IndexVec::U32(v) => v.len(), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.len(), } } /// Returns `true` if the length is 0. #[inline] pub fn is_empty(&self) -> bool { - match *self { - IndexVec::U32(ref v) => v.is_empty(), - IndexVec::USize(ref v) => v.is_empty(), + match self { + IndexVec::U32(v) => v.is_empty(), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.is_empty(), } } @@ -60,9 +66,10 @@ impl IndexVec { /// restrictions.) #[inline] pub fn index(&self, index: usize) -> usize { - match *self { - IndexVec::U32(ref v) => v[index] as usize, - IndexVec::USize(ref v) => v[index], + match self { + IndexVec::U32(v) => v[index] as usize, + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v[index] as usize, } } @@ -71,30 +78,33 @@ impl IndexVec { pub fn into_vec(self) -> Vec { match self { IndexVec::U32(v) => v.into_iter().map(|i| i as usize).collect(), - IndexVec::USize(v) => v, + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => v.into_iter().map(|i| i as usize).collect(), } } /// Iterate over the indices as a sequence of `usize` values #[inline] pub fn iter(&self) -> IndexVecIter<'_> { - match *self { - IndexVec::U32(ref v) => IndexVecIter::U32(v.iter()), - IndexVec::USize(ref v) => IndexVecIter::USize(v.iter()), + match self { + IndexVec::U32(v) => IndexVecIter::U32(v.iter()), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => IndexVecIter::U64(v.iter()), } } } impl IntoIterator for IndexVec { - type Item = usize; type IntoIter = IndexVecIntoIter; + type Item = usize; /// Convert into an iterator over the indices as a sequence of `usize` values #[inline] fn into_iter(self) -> IndexVecIntoIter { match self { IndexVec::U32(v) => IndexVecIntoIter::U32(v.into_iter()), - IndexVec::USize(v) => IndexVecIntoIter::USize(v.into_iter()), + #[cfg(target_pointer_width = "64")] + IndexVec::U64(v) => IndexVecIntoIter::U64(v.into_iter()), } } } @@ -103,13 +113,16 @@ impl PartialEq for IndexVec { fn eq(&self, other: &IndexVec) -> bool { use self::IndexVec::*; match (self, other) { - (&U32(ref v1), &U32(ref v2)) => v1 == v2, - (&USize(ref v1), &USize(ref v2)) => v1 == v2, - (&U32(ref v1), &USize(ref v2)) => { - (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as usize == *y)) + (U32(v1), U32(v2)) => v1 == v2, + #[cfg(target_pointer_width = "64")] + (U64(v1), U64(v2)) => v1 == v2, + #[cfg(target_pointer_width = "64")] + (U32(v1), U64(v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x as u64 == *y)) } - (&USize(ref v1), &U32(ref v2)) => { - (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as usize)) + #[cfg(target_pointer_width = "64")] + (U64(v1), U32(v2)) => { + (v1.len() == v2.len()) && (v1.iter().zip(v2.iter()).all(|(x, y)| *x == *y as u64)) } } } @@ -122,10 +135,11 @@ impl From> for IndexVec { } } -impl From> for IndexVec { +#[cfg(target_pointer_width = "64")] +impl From> for IndexVec { #[inline] - fn from(v: Vec) -> Self { - IndexVec::USize(v) + fn from(v: Vec) -> Self { + IndexVec::U64(v) } } @@ -134,40 +148,44 @@ impl From> for IndexVec { pub enum IndexVecIter<'a> { #[doc(hidden)] U32(slice::Iter<'a, u32>), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(slice::Iter<'a, usize>), + U64(slice::Iter<'a, u64>), } -impl<'a> Iterator for IndexVecIter<'a> { +impl Iterator for IndexVecIter<'_> { type Item = usize; #[inline] fn next(&mut self) -> Option { use self::IndexVecIter::*; - match *self { - U32(ref mut iter) => iter.next().map(|i| *i as usize), - USize(ref mut iter) => iter.next().cloned(), + match self { + U32(iter) => iter.next().map(|i| *i as usize), + #[cfg(target_pointer_width = "64")] + U64(iter) => iter.next().map(|i| *i as usize), } } #[inline] fn size_hint(&self) -> (usize, Option) { - match *self { - IndexVecIter::U32(ref v) => v.size_hint(), - IndexVecIter::USize(ref v) => v.size_hint(), + match self { + IndexVecIter::U32(v) => v.size_hint(), + #[cfg(target_pointer_width = "64")] + IndexVecIter::U64(v) => v.size_hint(), } } } -impl<'a> ExactSizeIterator for IndexVecIter<'a> {} +impl ExactSizeIterator for IndexVecIter<'_> {} /// Return type of `IndexVec::into_iter`. #[derive(Clone, Debug)] pub enum IndexVecIntoIter { #[doc(hidden)] U32(vec::IntoIter), + #[cfg(target_pointer_width = "64")] #[doc(hidden)] - USize(vec::IntoIter), + U64(vec::IntoIter), } impl Iterator for IndexVecIntoIter { @@ -176,25 +194,26 @@ impl Iterator for IndexVecIntoIter { #[inline] fn next(&mut self) -> Option { use self::IndexVecIntoIter::*; - match *self { - U32(ref mut v) => v.next().map(|i| i as usize), - USize(ref mut v) => v.next(), + match self { + U32(v) => v.next().map(|i| i as usize), + #[cfg(target_pointer_width = "64")] + U64(v) => v.next().map(|i| i as usize), } } #[inline] fn size_hint(&self) -> (usize, Option) { use self::IndexVecIntoIter::*; - match *self { - U32(ref v) => v.size_hint(), - USize(ref v) => v.size_hint(), + match self { + U32(v) => v.size_hint(), + #[cfg(target_pointer_width = "64")] + U64(v) => v.size_hint(), } } } impl ExactSizeIterator for IndexVecIntoIter {} - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in random order (fully shuffled). /// @@ -217,15 +236,22 @@ impl ExactSizeIterator for IndexVecIntoIter {} /// to adapt the internal `sample_floyd` implementation. /// /// Panics if `amount > length`. +#[track_caller] pub fn sample(rng: &mut R, length: usize, amount: usize) -> IndexVec -where R: Rng + ?Sized { +where + R: Rng + ?Sized, +{ if amount > length { panic!("`amount` of samples must be less than or equal to `length`"); } - if length > (::core::u32::MAX as usize) { + if length > (u32::MAX as usize) { + #[cfg(target_pointer_width = "32")] + unreachable!(); + // We never want to use inplace here, but could use floyd's alg // Lazy version: always use the cache alg. - return sample_rejection(rng, length, amount); + #[cfg(target_pointer_width = "64")] + return sample_rejection(rng, length as u64, amount as u64); } let amount = amount as u32; let length = length as u32; @@ -236,7 +262,7 @@ where R: Rng + ?Sized { if amount < 163 { const C: [[f32; 2]; 2] = [[1.6, 8.0 / 45.0], [10.0, 70.0 / 9.0]]; - let j = if length < 500_000 { 0 } else { 1 }; + let j = usize::from(length >= 500_000); let amount_fp = amount as f32; let m4 = C[0][j] * amount_fp; // Short-cut: when amount < 12, floyd's is always faster @@ -247,7 +273,7 @@ where R: Rng + ?Sized { } } else { const C: [f32; 2] = [270.0, 330.0 / 9.0]; - let j = if length < 500_000 { 0 } else { 1 }; + let j = usize::from(length >= 500_000); if (length as f32) < C[j] * (amount as f32) { sample_inplace(rng, length, amount) } else { @@ -256,54 +282,71 @@ where R: Rng + ?Sized { } } -/// Randomly sample exactly `amount` distinct indices from `0..length`, and -/// return them in an arbitrary order (there is no guarantee of shuffling or -/// ordering). The weights are to be provided by the input function `weights`, -/// which will be called once for each index. +/// Randomly sample exactly `amount` distinct indices from `0..length` +/// +/// Results are in arbitrary order (there is no guarantee of shuffling or +/// ordering). +/// +/// Function `weight` is called once for each index to provide weights. /// /// This method is used internally by the slice sampling methods, but it can /// sometimes be useful to have the indices themselves so this is provided as /// an alternative. /// -/// This implementation uses `O(length + amount)` space and `O(length)` time -/// if the "nightly" feature is enabled, or `O(length)` space and -/// `O(length + amount * log length)` time otherwise. +/// Error cases: +/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. +/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. /// -/// Panics if `amount > length`. +/// This implementation uses `O(length + amount)` space and `O(length)` time. +#[cfg(feature = "std")] pub fn sample_weighted( - rng: &mut R, length: usize, weight: F, amount: usize, -) -> Result + rng: &mut R, + length: usize, + weight: F, + amount: usize, +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, X: Into, { - if length > (core::u32::MAX as usize) { - sample_efraimidis_spirakis(rng, length, weight, amount) + if length > (u32::MAX as usize) { + #[cfg(target_pointer_width = "32")] + unreachable!(); + + #[cfg(target_pointer_width = "64")] + { + let amount = amount as u64; + let length = length as u64; + sample_efraimidis_spirakis(rng, length, weight, amount) + } } else { - assert!(amount <= core::u32::MAX as usize); + assert!(amount <= u32::MAX as usize); let amount = amount as u32; let length = length as u32; sample_efraimidis_spirakis(rng, length, weight, amount) } } - /// Randomly sample exactly `amount` distinct indices from `0..length`, and /// return them in an arbitrary order (there is no guarantee of shuffling or /// ordering). The weights are to be provided by the input function `weights`, /// which will be called once for each index. /// -/// This implementation uses the algorithm described by Efraimidis and Spirakis -/// in this paper: https://doi.org/10.1016/j.ipl.2005.11.003 -/// It uses `O(length + amount)` space and `O(length)` time if the -/// "nightly" feature is enabled, or `O(length)` space and `O(length -/// + amount * log length)` time otherwise. +/// This implementation is based on the algorithm A-ExpJ as found in +/// [Efraimidis and Spirakis, 2005](https://doi.org/10.1016/j.ipl.2005.11.003). +/// It uses `O(length + amount)` space and `O(length)` time. /// -/// Panics if `amount > length`. +/// Error cases: +/// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. +/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. +#[cfg(feature = "std")] fn sample_efraimidis_spirakis( - rng: &mut R, length: N, weight: F, amount: N, -) -> Result + rng: &mut R, + length: N, + weight: F, + amount: N, +) -> Result where R: Rng + ?Sized, F: Fn(usize) -> X, @@ -311,97 +354,82 @@ where N: UInt, IndexVec: From>, { + use std::{cmp::Ordering, collections::BinaryHeap}; + if amount == N::zero() { return Ok(IndexVec::U32(Vec::new())); } - if amount > length { - panic!("`amount` of samples must be less than or equal to `length`"); - } - struct Element { index: N, key: f64, } + impl PartialOrd for Element { - fn partial_cmp(&self, other: &Self) -> Option { - self.key.partial_cmp(&other.key) + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } + impl Ord for Element { - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - // partial_cmp will always produce a value, - // because we check that the weights are not nan - self.partial_cmp(other).unwrap() + fn cmp(&self, other: &Self) -> Ordering { + // unwrap() should not panic since weights should not be NaN + // We reverse so that BinaryHeap::peek shows the smallest item + self.key.partial_cmp(&other.key).unwrap().reverse() } } + impl PartialEq for Element { fn eq(&self, other: &Self) -> bool { self.key == other.key } } - impl Eq for Element {} - #[cfg(feature = "nightly")] - { - let mut candidates = Vec::with_capacity(length.as_usize()); - let mut index = N::zero(); - while index < length { - let weight = weight(index.as_usize()).into(); - if !(weight >= 0.) { - return Err(WeightedError::InvalidWeight); - } + impl Eq for Element {} - let key = rng.gen::().powf(1.0 / weight); + let mut candidates = BinaryHeap::with_capacity(amount.as_usize()); + let mut index = N::zero(); + while index < length && candidates.len() < amount.as_usize() { + let weight = weight(index.as_usize()).into(); + if weight > 0.0 { + // We use the log of the key used in A-ExpJ to improve precision + // for small weights: + let key = rng.random::().ln() / weight; candidates.push(Element { index, key }); - - index += N::one(); + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); } - // Partially sort the array to find the `amount` elements with the greatest - // keys. Do this by using `select_nth_unstable` to put the elements with - // the *smallest* keys at the beginning of the list in `O(n)` time, which - // provides equivalent information about the elements with the *greatest* keys. - let (_, mid, greater) - = candidates.select_nth_unstable(length.as_usize() - amount.as_usize()); - - let mut result: Vec = Vec::with_capacity(amount.as_usize()); - result.push(mid.index); - for element in greater { - result.push(element.index); - } - Ok(IndexVec::from(result)) - } - - #[cfg(not(feature = "nightly"))] - { - #[cfg(all(feature = "alloc", not(feature = "std")))] - use crate::alloc::collections::BinaryHeap; - #[cfg(feature = "std")] - use std::collections::BinaryHeap; - - // Partially sort the array such that the `amount` elements with the largest - // keys are first using a binary max heap. - let mut candidates = BinaryHeap::with_capacity(length.as_usize()); - let mut index = N::zero(); - while index < length { - let weight = weight(index.as_usize()).into(); - if !(weight >= 0.) { - return Err(WeightedError::InvalidWeight); - } + index += N::one(); + } - let key = rng.gen::().powf(1.0 / weight); - candidates.push(Element { index, key }); + if candidates.len() < amount.as_usize() { + return Err(WeightError::InsufficientNonZero); + } - index += N::one(); - } + let mut x = rng.random::().ln() / candidates.peek().unwrap().key; + while index < length { + let weight = weight(index.as_usize()).into(); + if weight > 0.0 { + x -= weight; + if x <= 0.0 { + let min_candidate = candidates.pop().unwrap(); + let t = (min_candidate.key * weight).exp(); + let key = rng.random_range(t..1.0).ln() / weight; + candidates.push(Element { index, key }); - let mut result: Vec = Vec::with_capacity(amount.as_usize()); - while result.len() < amount.as_usize() { - result.push(candidates.pop().unwrap().index); + x = rng.random::().ln() / candidates.peek().unwrap().key; + } + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); } - Ok(IndexVec::from(result)) + + index += N::one(); } + + Ok(IndexVec::from( + candidates.iter().map(|elt| elt.index).collect(), + )) } /// Randomly sample exactly `amount` indices from `0..length`, using Floyd's @@ -411,34 +439,21 @@ where /// /// This implementation uses `O(amount)` memory and `O(amount^2)` time. fn sample_floyd(rng: &mut R, length: u32, amount: u32) -> IndexVec -where R: Rng + ?Sized { - // For small amount we use Floyd's fully-shuffled variant. For larger - // amounts this is slow due to Vec::insert performance, so we shuffle - // afterwards. Benchmarks show little overhead from extra logic. - let floyd_shuffle = amount < 50; - +where + R: Rng + ?Sized, +{ + // Note that the values returned by `rng.random_range()` can be + // inferred from the returned vector by working backwards from + // the last entry. This bijection proves the algorithm fair. debug_assert!(amount <= length); let mut indices = Vec::with_capacity(amount as usize); for j in length - amount..length { - let t = rng.gen_range(0..=j); - if floyd_shuffle { - if let Some(pos) = indices.iter().position(|&x| x == t) { - indices.insert(pos, j); - continue; - } - } else if indices.contains(&t) { - indices.push(j); - continue; + let t = rng.random_range(..=j); + if let Some(pos) = indices.iter().position(|&x| x == t) { + indices[pos] = j; } indices.push(t); } - if !floyd_shuffle { - // Reimplement SliceRandom::shuffle with smaller indices - for i in (1..amount).rev() { - // invariant: elements with index > i have been locked in place. - indices.swap(i as usize, rng.gen_range(0..=i) as usize); - } - } IndexVec::from(indices) } @@ -455,12 +470,14 @@ where R: Rng + ?Sized { /// /// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. fn sample_inplace(rng: &mut R, length: u32, amount: u32) -> IndexVec -where R: Rng + ?Sized { +where + R: Rng + ?Sized, +{ debug_assert!(amount <= length); let mut indices: Vec = Vec::with_capacity(length as usize); indices.extend(0..length); for i in 0..amount { - let j: u32 = rng.gen_range(i..length); + let j: u32 = rng.random_range(i..length); indices.swap(i as usize, j as usize); } indices.truncate(amount as usize); @@ -468,12 +485,13 @@ where R: Rng + ?Sized { IndexVec::from(indices) } -trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform - + core::hash::Hash + core::ops::AddAssign { +trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + Hash + AddAssign { fn zero() -> Self; + #[cfg_attr(feature = "alloc", allow(dead_code))] fn one() -> Self; fn as_usize(self) -> usize; } + impl UInt for u32 { #[inline] fn zero() -> Self { @@ -490,7 +508,9 @@ impl UInt for u32 { self as usize } } -impl UInt for usize { + +#[cfg(target_pointer_width = "64")] +impl UInt for u64 { #[inline] fn zero() -> Self { 0 @@ -503,7 +523,7 @@ impl UInt for usize { #[inline] fn as_usize(self) -> usize { - self + self as usize } } @@ -526,7 +546,7 @@ where let mut cache = HashSet::with_capacity(amount.as_usize()); #[cfg(not(feature = "std"))] let mut cache = BTreeSet::new(); - let distr = Uniform::new(X::zero(), length); + let distr = Uniform::new(X::zero(), length).unwrap(); let mut indices = Vec::with_capacity(amount.as_usize()); for _ in 0..amount.as_usize() { let mut pos = distr.sample(rng); @@ -543,25 +563,17 @@ where #[cfg(test)] mod test { use super::*; + use alloc::vec; #[test] - #[cfg(feature = "serde1")] + #[cfg(feature = "serde")] fn test_serialization_index_vec() { - let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]); - let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); - match (some_index_vec, de_some_index_vec) { - (IndexVec::U32(a), IndexVec::U32(b)) => { - assert_eq!(a, b); - }, - (IndexVec::USize(a), IndexVec::USize(b)) => { - assert_eq!(a, b); - }, - _ => {panic!("failed to seralize/deserialize `IndexVec`")} - } + let some_index_vec = IndexVec::from(vec![254_u32, 234, 2, 1]); + let de_some_index_vec: IndexVec = + bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); + assert_eq!(some_index_vec, de_some_index_vec); } - #[cfg(feature = "alloc")] use alloc::vec; - #[test] fn test_sample_boundaries() { let mut r = crate::test::rng(404); @@ -619,24 +631,29 @@ mod test { assert_eq!(v1, v2); } + #[cfg(feature = "std")] #[test] fn test_sample_weighted() { let seed_rng = crate::test::rng; - for &(amount, len) in &[(0, 10), (5, 10), (10, 10)] { + for &(amount, len) in &[(0, 10), (5, 10), (9, 10)] { let v = sample_weighted(&mut seed_rng(423), len, |i| i as f64, amount).unwrap(); match v { IndexVec::U32(mut indices) => { assert_eq!(indices.len(), amount); - indices.sort(); + indices.sort_unstable(); indices.dedup(); assert_eq!(indices.len(), amount); for &i in &indices { assert!((i as usize) < len); } - }, - IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), + } + #[cfg(target_pointer_width = "64")] + _ => panic!("expected `IndexVec::U32`"), } } + + let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10); + assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); } #[test] @@ -659,17 +676,21 @@ mod test { ); }; - do_test(10, 6, &[8, 0, 3, 5, 9, 6]); // floyd - do_test(25, 10, &[18, 15, 14, 9, 0, 13, 5, 24]); // floyd - do_test(300, 8, &[30, 283, 150, 1, 73, 13, 285, 35]); // floyd - do_test(300, 80, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace - do_test(300, 180, &[31, 289, 248, 154, 5, 78, 19, 286]); // inplace - - do_test(1000_000, 8, &[ - 103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573, - ]); // floyd - do_test(1000_000, 180, &[ - 103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573, - ]); // rejection + do_test(10, 6, &[0, 9, 5, 4, 6, 8]); // floyd + do_test(25, 10, &[24, 20, 19, 9, 22, 16, 0, 14]); // floyd + do_test(300, 8, &[30, 283, 243, 150, 218, 240, 1, 189]); // floyd + do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace + do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace + + do_test( + 1_000_000, + 8, + &[103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573], + ); // floyd + do_test( + 1_000_000, + 180, + &[103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573], + ); // rejection } } diff --git a/src/seq/iterator.rs b/src/seq/iterator.rs new file mode 100644 index 00000000000..b10d205676a --- /dev/null +++ b/src/seq/iterator.rs @@ -0,0 +1,664 @@ +// Copyright 2018-2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `IteratorRandom` + +use super::coin_flipper::CoinFlipper; +#[allow(unused)] +use super::IndexedRandom; +use crate::Rng; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Extension trait on iterators, providing random sampling methods. +/// +/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` +/// and provides methods for +/// choosing one or more elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::IteratorRandom; +/// +/// let faces = "😀😎😐😕😠😢"; +/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap()); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// I am 😀! +/// ``` +pub trait IteratorRandom: Iterator + Sized { + /// Uniformly sample one element + /// + /// Assuming that the [`Iterator::size_hint`] is correct, this method + /// returns one uniformly-sampled random element of the slice, or `None` + /// only if the slice is empty. Incorrect bounds on the `size_hint` may + /// cause this method to incorrectly return `None` if fewer elements than + /// the advertised `lower` bound are present and may prevent sampling of + /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint` + /// is memory-safe, but may result in unexpected `None` result and + /// non-uniform distribution). + /// + /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is + /// a constant-time operation, this method can offer `O(1)` performance. + /// Where no size hint is + /// available, complexity is `O(n)` where `n` is the iterator length. + /// Partial hints (where `lower > 0`) also improve performance. + /// + /// Note further that [`Iterator::size_hint`] may affect the number of RNG + /// samples used as well as the result (while remaining uniform sampling). + /// Consider instead using [`IteratorRandom::choose_stable`] to avoid + /// [`Iterator`] combinators which only change size hints from affecting the + /// results. + /// + /// # Example + /// + /// ``` + /// use rand::seq::IteratorRandom; + /// + /// let words = "Mary had a little lamb".split(' '); + /// println!("{}", words.choose(&mut rand::rng()).unwrap()); + /// ``` + fn choose(mut self, rng: &mut R) -> Option + where + R: Rng + ?Sized, + { + let (mut lower, mut upper) = self.size_hint(); + let mut result = None; + + // Handling for this condition outside the loop allows the optimizer to eliminate the loop + // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. + // seq_iter_choose_from_1000. + if upper == Some(lower) { + return match lower { + 0 => None, + 1 => self.next(), + _ => self.nth(rng.random_range(..lower)), + }; + } + + let mut coin_flipper = CoinFlipper::new(rng); + let mut consumed = 0; + + // Continue until the iterator is exhausted + loop { + if lower > 1 { + let ix = coin_flipper.rng.random_range(..lower + consumed); + let skip = if ix < lower { + result = self.nth(ix); + lower - (ix + 1) + } else { + lower + }; + if upper == Some(lower) { + return result; + } + consumed += lower; + if skip > 0 { + self.nth(skip - 1); + } + } else { + let elem = self.next(); + if elem.is_none() { + return result; + } + consumed += 1; + if coin_flipper.random_ratio_one_over(consumed) { + result = elem; + } + } + + let hint = self.size_hint(); + lower = hint.0; + upper = hint.1; + } + } + + /// Uniformly sample one element (stable) + /// + /// This method is very similar to [`choose`] except that the result + /// only depends on the length of the iterator and the values produced by + /// `rng`. Notably for any iterator of a given length this will make the + /// same requests to `rng` and if the same sequence of values are produced + /// the same index will be selected from `self`. This may be useful if you + /// need consistent results no matter what type of iterator you are working + /// with. If you do not need this stability prefer [`choose`]. + /// + /// Note that this method still uses [`Iterator::size_hint`] to skip + /// constructing elements where possible, however the selection and `rng` + /// calls are the same in the face of this optimization. If you want to + /// force every element to be created regardless call `.inspect(|e| ())`. + /// + /// [`choose`]: IteratorRandom::choose + fn choose_stable(mut self, rng: &mut R) -> Option + where + R: Rng + ?Sized, + { + let mut consumed = 0; + let mut result = None; + let mut coin_flipper = CoinFlipper::new(rng); + + loop { + // Currently the only way to skip elements is `nth()`. So we need to + // store what index to access next here. + // This should be replaced by `advance_by()` once it is stable: + // https://github.com/rust-lang/rust/issues/77404 + let mut next = 0; + + let (lower, _) = self.size_hint(); + if lower >= 2 { + let highest_selected = (0..lower) + .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1)) + .last(); + + consumed += lower; + next = lower; + + if let Some(ix) = highest_selected { + result = self.nth(ix); + next -= ix + 1; + debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); + } + } + + let elem = self.nth(next); + if elem.is_none() { + return result; + } + + if coin_flipper.random_ratio_one_over(consumed + 1) { + result = elem; + } + consumed += 1; + } + } + + /// Uniformly sample `amount` distinct elements into a buffer + /// + /// Collects values at random from the iterator into a supplied buffer + /// until that buffer is filled. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// Returns the number of elements added to the buffer. This equals the length + /// of the buffer unless the iterator contains insufficient elements, in which + /// case this equals the number of elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. + fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize + where + R: Rng + ?Sized, + { + let amount = buf.len(); + let mut len = 0; + while len < amount { + if let Some(elem) = self.next() { + buf[len] = elem; + len += 1; + } else { + // Iterator exhausted; stop early + return len; + } + } + + // Continue, since the iterator was not exhausted + for (i, elem) in self.enumerate() { + let k = rng.random_range(..i + 1 + amount); + if let Some(slot) = buf.get_mut(k) { + *slot = elem; + } + } + len + } + + /// Uniformly sample `amount` distinct elements into a [`Vec`] + /// + /// This is equivalent to `choose_multiple_fill` except for the result type. + /// + /// Although the elements are selected randomly, the order of elements in + /// the buffer is neither stable nor fully random. If random ordering is + /// desired, shuffle the result. + /// + /// The length of the returned vector equals `amount` unless the iterator + /// contains insufficient elements, in which case it equals the number of + /// elements available. + /// + /// Complexity is `O(n)` where `n` is the length of the iterator. + /// For slices, prefer [`IndexedRandom::choose_multiple`]. + #[cfg(feature = "alloc")] + fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec + where + R: Rng + ?Sized, + { + let mut reservoir = Vec::with_capacity(amount); + reservoir.extend(self.by_ref().take(amount)); + + // Continue unless the iterator was exhausted + // + // note: this prevents iterators that "restart" from causing problems. + // If the iterator stops once, then so do we. + if reservoir.len() == amount { + for (i, elem) in self.enumerate() { + let k = rng.random_range(..i + 1 + amount); + if let Some(slot) = reservoir.get_mut(k) { + *slot = elem; + } + } + } else { + // Don't hang onto extra memory. There is a corner case where + // `amount` was much less than `self.len()`. + reservoir.shrink_to_fit(); + } + reservoir + } +} + +impl IteratorRandom for I where I: Iterator + Sized {} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(all(feature = "alloc", not(feature = "std")))] + use alloc::vec::Vec; + + #[derive(Clone)] + struct UnhintedIterator { + iter: I, + } + impl Iterator for UnhintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + } + + #[derive(Clone)] + struct ChunkHintedIterator { + iter: I, + chunk_remaining: usize, + chunk_size: usize, + hint_total_size: bool, + } + impl Iterator for ChunkHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + if self.chunk_remaining == 0 { + self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len()); + } + self.chunk_remaining = self.chunk_remaining.saturating_sub(1); + + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + ( + self.chunk_remaining, + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[derive(Clone)] + struct WindowHintedIterator { + iter: I, + window_size: usize, + hint_total_size: bool, + } + impl Iterator for WindowHintedIterator { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + ( + core::cmp::min(self.iter.len(), self.window_size), + if self.hint_total_size { + Some(self.iter.len()) + } else { + None + }, + ) + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose() { + let r = &mut crate::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable() { + let r = &mut crate::test::rng(109); + fn test_iter + Clone>(r: &mut R, iter: Iter) { + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + for count in chosen.iter() { + // Samples should follow Binomial(1000, 1/9) + // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x + // Note: have seen 153, which is unlikely but not impossible. + assert!( + 72 < *count && *count < 154, + "count not close to 1000/9: {}", + count + ); + } + } + + test_iter(r, 0..9); + test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); + #[cfg(feature = "alloc")] + test_iter(r, (0..9).collect::>().into_iter()); + test_iter(r, UnhintedIterator { iter: 0..9 }); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }, + ); + test_iter( + r, + ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }, + ); + test_iter( + r, + WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }, + ); + + assert_eq!((0..0).choose(r), None); + assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_iterator_choose_stable_stability() { + fn test_iter(iter: impl Iterator + Clone) -> [i32; 9] { + let r = &mut crate::test::rng(109); + let mut chosen = [0i32; 9]; + for _ in 0..1000 { + let picked = iter.clone().choose_stable(r).unwrap(); + chosen[picked] += 1; + } + chosen + } + + let reference = test_iter(0..9); + assert_eq!( + test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), + reference + ); + + #[cfg(feature = "alloc")] + assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); + assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(ChunkHintedIterator { + iter: 0..9, + chunk_size: 4, + chunk_remaining: 4, + hint_total_size: true, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: false, + }), + reference + ); + assert_eq!( + test_iter(WindowHintedIterator { + iter: 0..9, + window_size: 2, + hint_total_size: true, + }), + reference + ); + } + + #[test] + #[cfg(feature = "alloc")] + fn test_sample_iter() { + let min_val = 1; + let max_val = 100; + + let mut r = crate::test::rng(401); + let vals = (min_val..max_val).collect::>(); + let small_sample = vals.iter().choose_multiple(&mut r, 5); + let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); + + assert_eq!(small_sample.len(), 5); + assert_eq!(large_sample.len(), vals.len()); + // no randomization happens when amount >= len + assert_eq!(large_sample, vals.iter().collect::>()); + + assert!(small_sample + .iter() + .all(|e| { **e >= min_val && **e <= max_val })); + } + + #[test] + fn value_stability_choose() { + fn choose>(iter: I) -> Option { + let mut rng = crate::test::rng(411); + iter.choose(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(33)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(91) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(91) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(34) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(34) + ); + } + + #[test] + fn value_stability_choose_stable() { + fn choose>(iter: I) -> Option { + let mut rng = crate::test::rng(411); + iter.choose_stable(&mut rng) + } + + assert_eq!(choose([].iter().cloned()), None); + assert_eq!(choose(0..100), Some(27)); + assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27)); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: false, + }), + Some(27) + ); + assert_eq!( + choose(ChunkHintedIterator { + iter: 0..100, + chunk_size: 32, + chunk_remaining: 32, + hint_total_size: true, + }), + Some(27) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: false, + }), + Some(27) + ); + assert_eq!( + choose(WindowHintedIterator { + iter: 0..100, + window_size: 32, + hint_total_size: true, + }), + Some(27) + ); + } + + #[test] + fn value_stability_choose_multiple() { + fn do_test>(iter: I, v: &[u32]) { + let mut rng = crate::test::rng(412); + let mut buf = [0u32; 8]; + assert_eq!( + iter.clone().choose_multiple_fill(&mut rng, &mut buf), + v.len() + ); + assert_eq!(&buf[0..v.len()], v); + + #[cfg(feature = "alloc")] + { + let mut rng = crate::test::rng(412); + assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); + } + } + + do_test(0..4, &[0, 1, 2, 3]); + do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); + do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]); + } +} diff --git a/src/seq/mod.rs b/src/seq/mod.rs index 9e6ffaf1242..91d634d865e 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -1,4 +1,4 @@ -// Copyright 2018 Developers of the Rand project. +// Copyright 2018-2023 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -10,1346 +10,71 @@ //! //! This module provides: //! -//! * [`SliceRandom`] slice sampling and mutation -//! * [`IteratorRandom`] iterator sampling +//! * [`IndexedRandom`] for sampling slices and other indexable lists +//! * [`IndexedMutRandom`] for sampling slices and other mutably indexable lists +//! * [`SliceRandom`] for mutating slices +//! * [`IteratorRandom`] for sampling iterators //! * [`index::sample`] low-level API to choose multiple indices from //! `0..length` //! //! Also see: //! -//! * [`crate::distributions::WeightedIndex`] distribution which provides +//! * [`crate::distr::weighted::WeightedIndex`] distribution which provides //! weighted index sampling. //! //! In order to make results reproducible across 32-64 bit architectures, all //! `usize` indices are sampled as a `u32` where possible (also providing a //! small performance boost in some cases). +mod coin_flipper; +mod increasing_uniform; +mod iterator; +mod slice; #[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub mod index; - -#[cfg(feature = "alloc")] use core::ops::Index; - -#[cfg(feature = "alloc")] use alloc::vec::Vec; +#[path = "index.rs"] +mod index_; #[cfg(feature = "alloc")] -use crate::distributions::uniform::{SampleBorrow, SampleUniform}; -#[cfg(feature = "alloc")] use crate::distributions::WeightedError; -use crate::Rng; - -/// Extension trait on slices, providing random mutation and sampling methods. -/// -/// This trait is implemented on all `[T]` slice types, providing several -/// methods for choosing and shuffling elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::SliceRandom; -/// -/// let mut rng = rand::thread_rng(); -/// let mut bytes = "Hello, random!".to_string().into_bytes(); -/// bytes.shuffle(&mut rng); -/// let str = String::from_utf8(bytes).unwrap(); -/// println!("{}", str); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// l,nmroHado !le -/// ``` -pub trait SliceRandom { - /// The element type. - type Item; - - /// Returns a reference to one random element of the slice, or `None` if the - /// slice is empty. - /// - /// For slices, complexity is `O(1)`. - /// - /// # Example - /// - /// ``` - /// use rand::thread_rng; - /// use rand::seq::SliceRandom; - /// - /// let choices = [1, 2, 4, 8, 16, 32]; - /// let mut rng = thread_rng(); - /// println!("{:?}", choices.choose(&mut rng)); - /// assert_eq!(choices[..0].choose(&mut rng), None); - /// ``` - fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where R: Rng + ?Sized; - - /// Returns a mutable reference to one random element of the slice, or - /// `None` if the slice is empty. - /// - /// For slices, complexity is `O(1)`. - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where R: Rng + ?Sized; +#[doc(no_inline)] +pub use crate::distr::weighted::Error as WeightError; +pub use iterator::IteratorRandom; +#[cfg(feature = "alloc")] +pub use slice::SliceChooseIter; +pub use slice::{IndexedMutRandom, IndexedRandom, SliceRandom}; - /// Chooses `amount` elements from the slice at random, without repetition, - /// and in random order. The returned iterator is appropriate both for - /// collection into a `Vec` and filling an existing buffer (see example). - /// - /// In case this API is not sufficiently flexible, use [`index::sample`]. - /// - /// For slices, complexity is the same as [`index::sample`]. - /// - /// # Example - /// ``` - /// use rand::seq::SliceRandom; - /// - /// let mut rng = &mut rand::thread_rng(); - /// let sample = "Hello, audience!".as_bytes(); - /// - /// // collect the results into a vector: - /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); - /// - /// // store in a buffer: - /// let mut buf = [0u8; 5]; - /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { - /// *slot = *b; - /// } - /// ``` - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where R: Rng + ?Sized; +/// Low-level API for sampling indices +pub mod index { + use crate::Rng; - /// Similar to [`choose`], but where the likelihood of each outcome may be - /// specified. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// See also [`choose_weighted_mut`], [`distributions::weighted`]. - /// - /// # Example - /// - /// ``` - /// use rand::prelude::*; - /// - /// let choices = [('a', 2), ('b', 1), ('c', 1)]; - /// let mut rng = thread_rng(); - /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' - /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); - /// ``` - /// [`choose`]: SliceRandom::choose - /// [`choose_weighted_mut`]: SliceRandom::choose_weighted_mut - /// [`distributions::weighted`]: crate::distributions::weighted #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted( - &self, rng: &mut R, weight: F, - ) -> Result<&Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; + #[doc(inline)] + pub use super::index_::*; - /// Similar to [`choose_mut`], but where the likelihood of each outcome may - /// be specified. + /// Randomly sample exactly `N` distinct indices from `0..len`, and + /// return them in random order (fully shuffled). /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// This is implemented via Floyd's algorithm. Time complexity is `O(N^2)` + /// and memory complexity is `O(N)`. /// - /// For slices of length `n`, complexity is `O(n)`. - /// See also [`choose_weighted`], [`distributions::weighted`]. - /// - /// [`choose_mut`]: SliceRandom::choose_mut - /// [`choose_weighted`]: SliceRandom::choose_weighted - /// [`distributions::weighted`]: crate::distributions::weighted - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Item, WeightedError> + /// Returns `None` if (and only if) `N > len`. + pub fn sample_array(rng: &mut R, len: usize) -> Option<[usize; N]> where R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; - - /// Similar to [`choose_multiple`], but where the likelihood of each element's - /// inclusion in the output may be specified. The elements are returned in an - /// arbitrary, unspecified order. - /// - /// The specified function `weight` maps each item `x` to a relative - /// likelihood `weight(x)`. The probability of each item being selected is - /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. - /// - /// If all of the weights are equal, even if they are all zero, each element has - /// an equal likelihood of being selected. - /// - /// The complexity of this method depends on the feature `partition_at_index`. - /// If the feature is enabled, then for slices of length `n`, the complexity - /// is `O(n)` space and `O(n)` time. Otherwise, the complexity is `O(n)` space and - /// `O(n * log amount)` time. - /// - /// # Example - /// - /// ``` - /// use rand::prelude::*; - /// - /// let choices = [('a', 2), ('b', 1), ('c', 1)]; - /// let mut rng = thread_rng(); - /// // First Draw * Second Draw = total odds - /// // ----------------------- - /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. - /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. - /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. - /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); - /// ``` - /// [`choose_multiple`]: SliceRandom::choose_multiple - #[cfg(feature = "alloc")] - fn choose_multiple_weighted( - &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> X, - X: Into; - - /// Shuffle a mutable slice in place. - /// - /// For slices of length `n`, complexity is `O(n)`. - /// - /// # Example - /// - /// ``` - /// use rand::seq::SliceRandom; - /// use rand::thread_rng; - /// - /// let mut rng = thread_rng(); - /// let mut y = [1, 2, 3, 4, 5]; - /// println!("Unshuffled: {:?}", y); - /// y.shuffle(&mut rng); - /// println!("Shuffled: {:?}", y); - /// ``` - fn shuffle(&mut self, rng: &mut R) - where R: Rng + ?Sized; - - /// Shuffle a slice in place, but exit early. - /// - /// Returns two mutable slices from the source slice. The first contains - /// `amount` elements randomly permuted. The second has the remaining - /// elements that are not fully shuffled. - /// - /// This is an efficient method to select `amount` elements at random from - /// the slice, provided the slice may be mutated. - /// - /// If you only need to choose elements randomly and `amount > self.len()/2` - /// then you may improve performance by taking - /// `amount = values.len() - amount` and using only the second slice. - /// - /// If `amount` is greater than the number of elements in the slice, this - /// will perform a full shuffle. - /// - /// For slices, complexity is `O(m)` where `m = amount`. - fn partial_shuffle( - &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Item], &mut [Self::Item]) - where R: Rng + ?Sized; -} - -/// Extension trait on iterators, providing random sampling methods. -/// -/// This trait is implemented on all iterators `I` where `I: Iterator + Sized` -/// and provides methods for -/// choosing one or more elements. You must `use` this trait: -/// -/// ``` -/// use rand::seq::IteratorRandom; -/// -/// let mut rng = rand::thread_rng(); -/// -/// let faces = "😀😎😐😕😠😢"; -/// println!("I am {}!", faces.chars().choose(&mut rng).unwrap()); -/// ``` -/// Example output (non-deterministic): -/// ```none -/// I am 😀! -/// ``` -pub trait IteratorRandom: Iterator + Sized { - /// Choose one element at random from the iterator. - /// - /// Returns `None` if and only if the iterator is empty. - /// - /// This method uses [`Iterator::size_hint`] for optimisation. With an - /// accurate hint and where [`Iterator::nth`] is a constant-time operation - /// this method can offer `O(1)` performance. Where no size hint is - /// available, complexity is `O(n)` where `n` is the iterator length. - /// Partial hints (where `lower > 0`) also improve performance. - /// - /// Note that the output values and the number of RNG samples used - /// depends on size hints. In particular, `Iterator` combinators that don't - /// change the values yielded but change the size hints may result in - /// `choose` returning different elements. If you want consistent results - /// and RNG usage consider using [`IteratorRandom::choose_stable`]. - fn choose(mut self, rng: &mut R) -> Option - where R: Rng + ?Sized { - let (mut lower, mut upper) = self.size_hint(); - let mut consumed = 0; - let mut result = None; - - // Handling for this condition outside the loop allows the optimizer to eliminate the loop - // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g. - // seq_iter_choose_from_1000. - if upper == Some(lower) { - return if lower == 0 { - None - } else { - self.nth(gen_index(rng, lower)) - }; - } - - // Continue until the iterator is exhausted - loop { - if lower > 1 { - let ix = gen_index(rng, lower + consumed); - let skip = if ix < lower { - result = self.nth(ix); - lower - (ix + 1) - } else { - lower - }; - if upper == Some(lower) { - return result; - } - consumed += lower; - if skip > 0 { - self.nth(skip - 1); - } - } else { - let elem = self.next(); - if elem.is_none() { - return result; - } - consumed += 1; - if gen_index(rng, consumed) == 0 { - result = elem; - } - } - - let hint = self.size_hint(); - lower = hint.0; - upper = hint.1; - } - } - - /// Choose one element at random from the iterator. - /// - /// Returns `None` if and only if the iterator is empty. - /// - /// This method is very similar to [`choose`] except that the result - /// only depends on the length of the iterator and the values produced by - /// `rng`. Notably for any iterator of a given length this will make the - /// same requests to `rng` and if the same sequence of values are produced - /// the same index will be selected from `self`. This may be useful if you - /// need consistent results no matter what type of iterator you are working - /// with. If you do not need this stability prefer [`choose`]. - /// - /// Note that this method still uses [`Iterator::size_hint`] to skip - /// constructing elements where possible, however the selection and `rng` - /// calls are the same in the face of this optimization. If you want to - /// force every element to be created regardless call `.inspect(|e| ())`. - /// - /// [`choose`]: IteratorRandom::choose - fn choose_stable(mut self, rng: &mut R) -> Option - where R: Rng + ?Sized { - let mut consumed = 0; - let mut result = None; - - loop { - // Currently the only way to skip elements is `nth()`. So we need to - // store what index to access next here. - // This should be replaced by `advance_by()` once it is stable: - // https://github.com/rust-lang/rust/issues/77404 - let mut next = 0; - - let (lower, _) = self.size_hint(); - if lower >= 2 { - let highest_selected = (0..lower) - .filter(|ix| gen_index(rng, consumed+ix+1) == 0) - .last(); - - consumed += lower; - next = lower; - - if let Some(ix) = highest_selected { - result = self.nth(ix); - next -= ix + 1; - debug_assert!(result.is_some(), "iterator shorter than size_hint().0"); - } - } - - let elem = self.nth(next); - if elem.is_none() { - return result - } - - if gen_index(rng, consumed+1) == 0 { - result = elem; - } - consumed += 1; - } - } - - /// Collects values at random from the iterator into a supplied buffer - /// until that buffer is filled. - /// - /// Although the elements are selected randomly, the order of elements in - /// the buffer is neither stable nor fully random. If random ordering is - /// desired, shuffle the result. - /// - /// Returns the number of elements added to the buffer. This equals the length - /// of the buffer unless the iterator contains insufficient elements, in which - /// case this equals the number of elements available. - /// - /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`SliceRandom::choose_multiple`]. - fn choose_multiple_fill(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize - where R: Rng + ?Sized { - let amount = buf.len(); - let mut len = 0; - while len < amount { - if let Some(elem) = self.next() { - buf[len] = elem; - len += 1; - } else { - // Iterator exhausted; stop early - return len; - } - } - - // Continue, since the iterator was not exhausted - for (i, elem) in self.enumerate() { - let k = gen_index(rng, i + 1 + amount); - if let Some(slot) = buf.get_mut(k) { - *slot = elem; - } - } - len - } - - /// Collects `amount` values at random from the iterator into a vector. - /// - /// This is equivalent to `choose_multiple_fill` except for the result type. - /// - /// Although the elements are selected randomly, the order of elements in - /// the buffer is neither stable nor fully random. If random ordering is - /// desired, shuffle the result. - /// - /// The length of the returned vector equals `amount` unless the iterator - /// contains insufficient elements, in which case it equals the number of - /// elements available. - /// - /// Complexity is `O(n)` where `n` is the length of the iterator. - /// For slices, prefer [`SliceRandom::choose_multiple`]. - #[cfg(feature = "alloc")] - #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - fn choose_multiple(mut self, rng: &mut R, amount: usize) -> Vec - where R: Rng + ?Sized { - let mut reservoir = Vec::with_capacity(amount); - reservoir.extend(self.by_ref().take(amount)); - - // Continue unless the iterator was exhausted - // - // note: this prevents iterators that "restart" from causing problems. - // If the iterator stops once, then so do we. - if reservoir.len() == amount { - for (i, elem) in self.enumerate() { - let k = gen_index(rng, i + 1 + amount); - if let Some(slot) = reservoir.get_mut(k) { - *slot = elem; - } - } - } else { - // Don't hang onto extra memory. There is a corner case where - // `amount` was much less than `self.len()`. - reservoir.shrink_to_fit(); - } - reservoir - } -} - - -impl SliceRandom for [T] { - type Item = T; - - fn choose(&self, rng: &mut R) -> Option<&Self::Item> - where R: Rng + ?Sized { - if self.is_empty() { - None - } else { - Some(&self[gen_index(rng, self.len())]) - } - } - - fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Item> - where R: Rng + ?Sized { - if self.is_empty() { - None - } else { - let len = self.len(); - Some(&mut self[gen_index(rng, len)]) - } - } - - #[cfg(feature = "alloc")] - fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter - where R: Rng + ?Sized { - let amount = ::core::cmp::min(amount, self.len()); - SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample(rng, self.len(), amount).into_iter(), - } - } - - #[cfg(feature = "alloc")] - fn choose_weighted( - &self, rng: &mut R, weight: F, - ) -> Result<&Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, - { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new(self.iter().map(weight))?; - Ok(&self[distr.sample(rng)]) - } - - #[cfg(feature = "alloc")] - fn choose_weighted_mut( - &mut self, rng: &mut R, weight: F, - ) -> Result<&mut Self::Item, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> B, - B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, { - use crate::distributions::{Distribution, WeightedIndex}; - let distr = WeightedIndex::new(self.iter().map(weight))?; - Ok(&mut self[distr.sample(rng)]) - } - - #[cfg(feature = "alloc")] - fn choose_multiple_weighted( - &self, rng: &mut R, amount: usize, weight: F, - ) -> Result, WeightedError> - where - R: Rng + ?Sized, - F: Fn(&Self::Item) -> X, - X: Into, - { - let amount = ::core::cmp::min(amount, self.len()); - Ok(SliceChooseIter { - slice: self, - _phantom: Default::default(), - indices: index::sample_weighted( - rng, - self.len(), - |idx| weight(&self[idx]).into(), - amount, - )? - .into_iter(), - }) - } - - fn shuffle(&mut self, rng: &mut R) - where R: Rng + ?Sized { - for i in (1..self.len()).rev() { - // invariant: elements with index > i have been locked in place. - self.swap(i, gen_index(rng, i + 1)); - } - } - - fn partial_shuffle( - &mut self, rng: &mut R, amount: usize, - ) -> (&mut [Self::Item], &mut [Self::Item]) - where R: Rng + ?Sized { - // This applies Durstenfeld's algorithm for the - // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) - // for an unbiased permutation, but exits early after choosing `amount` - // elements. - - let len = self.len(); - let end = if amount >= len { 0 } else { len - amount }; - - for i in (end..len).rev() { - // invariant: elements with index > i have been locked in place. - self.swap(i, gen_index(rng, i + 1)); - } - let r = self.split_at_mut(end); - (r.1, r.0) - } -} - -impl IteratorRandom for I where I: Iterator + Sized {} - - -/// An iterator over multiple slice elements. -/// -/// This struct is created by -/// [`SliceRandom::choose_multiple`](trait.SliceRandom.html#tymethod.choose_multiple). -#[cfg(feature = "alloc")] -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug)] -pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { - slice: &'a S, - _phantom: ::core::marker::PhantomData, - indices: index::IndexVecIntoIter, -} - -#[cfg(feature = "alloc")] -impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { - type Item = &'a T; - - fn next(&mut self) -> Option { - // TODO: investigate using SliceIndex::get_unchecked when stable - self.indices.next().map(|i| &self.slice[i as usize]) - } - - fn size_hint(&self) -> (usize, Option) { - (self.indices.len(), Some(self.indices.len())) - } -} - -#[cfg(feature = "alloc")] -impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator - for SliceChooseIter<'a, S, T> -{ - fn len(&self) -> usize { - self.indices.len() - } -} - - -// Sample a number uniformly between 0 and `ubound`. Uses 32-bit sampling where -// possible, primarily in order to produce the same output on 32-bit and 64-bit -// platforms. -#[inline] -fn gen_index(rng: &mut R, ubound: usize) -> usize { - if ubound <= (core::u32::MAX as usize) { - rng.gen_range(0..ubound as u32) as usize - } else { - rng.gen_range(0..ubound) - } -} - - -#[cfg(test)] -mod test { - use super::*; - #[cfg(feature = "alloc")] use crate::Rng; - #[cfg(all(feature = "alloc", not(feature = "std")))] use alloc::vec::Vec; - - #[test] - fn test_slice_choose() { - let mut r = crate::test::rng(107); - let chars = [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - ]; - let mut chosen = [0i32; 14]; - // The below all use a binomial distribution with n=1000, p=1/14. - // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 - for _ in 0..1000 { - let picked = *chars.choose(&mut r).unwrap(); - chosen[(picked as usize) - ('a' as usize)] += 1; - } - for count in chosen.iter() { - assert!(40 < *count && *count < 106); + if N > len { + return None; } - chosen.iter_mut().for_each(|x| *x = 0); - for _ in 0..1000 { - *chosen.choose_mut(&mut r).unwrap() += 1; - } - for count in chosen.iter() { - assert!(40 < *count && *count < 106); - } - - let mut v: [isize; 0] = []; - assert_eq!(v.choose(&mut r), None); - assert_eq!(v.choose_mut(&mut r), None); - } - - #[test] - fn value_stability_slice() { - let mut r = crate::test::rng(413); - let chars = [ - 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', - ]; - let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - - assert_eq!(chars.choose(&mut r), Some(&'l')); - assert_eq!(nums.choose_mut(&mut r), Some(&mut 10)); - - #[cfg(feature = "alloc")] - assert_eq!( - &chars - .choose_multiple(&mut r, 8) - .cloned() - .collect::>(), - &['d', 'm', 'b', 'n', 'c', 'k', 'h', 'e'] - ); - - #[cfg(feature = "alloc")] - assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'f')); - #[cfg(feature = "alloc")] - assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 5)); - - let mut r = crate::test::rng(414); - nums.shuffle(&mut r); - assert_eq!(nums, [9, 5, 3, 10, 7, 12, 8, 11, 6, 4, 0, 2, 1]); - nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; - let res = nums.partial_shuffle(&mut r, 6); - assert_eq!(res.0, &mut [7, 4, 8, 6, 9, 3]); - assert_eq!(res.1, &mut [0, 1, 2, 12, 11, 5, 10]); - } - - #[derive(Clone)] - struct UnhintedIterator { - iter: I, - } - impl Iterator for UnhintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - } - - #[derive(Clone)] - struct ChunkHintedIterator { - iter: I, - chunk_remaining: usize, - chunk_size: usize, - hint_total_size: bool, - } - impl Iterator for ChunkHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - if self.chunk_remaining == 0 { - self.chunk_remaining = ::core::cmp::min(self.chunk_size, self.iter.len()); + // Floyd's algorithm + let mut indices = [0; N]; + for (i, j) in (len - N..len).enumerate() { + let t = rng.random_range(..j + 1); + if let Some(pos) = indices[0..i].iter().position(|&x| x == t) { + indices[pos] = j; } - self.chunk_remaining = self.chunk_remaining.saturating_sub(1); - - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - ( - self.chunk_remaining, - if self.hint_total_size { - Some(self.iter.len()) - } else { - None - }, - ) - } - } - - #[derive(Clone)] - struct WindowHintedIterator { - iter: I, - window_size: usize, - hint_total_size: bool, - } - impl Iterator for WindowHintedIterator { - type Item = I::Item; - - fn next(&mut self) -> Option { - self.iter.next() - } - - fn size_hint(&self) -> (usize, Option) { - ( - ::core::cmp::min(self.iter.len(), self.window_size), - if self.hint_total_size { - Some(self.iter.len()) - } else { - None - }, - ) + indices[i] = t; } - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose() { - let r = &mut crate::test::rng(109); - fn test_iter + Clone>(r: &mut R, iter: Iter) { - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose(r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - // Samples should follow Binomial(1000, 1/9) - // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x - // Note: have seen 153, which is unlikely but not impossible. - assert!( - 72 < *count && *count < 154, - "count not close to 1000/9: {}", - count - ); - } - } - - test_iter(r, 0..9); - test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); - #[cfg(feature = "alloc")] - test_iter(r, (0..9).collect::>().into_iter()); - test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }); - - assert_eq!((0..0).choose(r), None); - assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose_stable() { - let r = &mut crate::test::rng(109); - fn test_iter + Clone>(r: &mut R, iter: Iter) { - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose_stable(r).unwrap(); - chosen[picked] += 1; - } - for count in chosen.iter() { - // Samples should follow Binomial(1000, 1/9) - // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x - // Note: have seen 153, which is unlikely but not impossible. - assert!( - 72 < *count && *count < 154, - "count not close to 1000/9: {}", - count - ); - } - } - - test_iter(r, 0..9); - test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()); - #[cfg(feature = "alloc")] - test_iter(r, (0..9).collect::>().into_iter()); - test_iter(r, UnhintedIterator { iter: 0..9 }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }); - test_iter(r, ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }); - test_iter(r, WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }); - - assert_eq!((0..0).choose(r), None); - assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_iterator_choose_stable_stability() { - fn test_iter(iter: impl Iterator + Clone) -> [i32; 9] { - let r = &mut crate::test::rng(109); - let mut chosen = [0i32; 9]; - for _ in 0..1000 { - let picked = iter.clone().choose_stable(r).unwrap(); - chosen[picked] += 1; - } - chosen - } - - let reference = test_iter(0..9); - assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference); - - #[cfg(feature = "alloc")] - assert_eq!(test_iter((0..9).collect::>().into_iter()), reference); - assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(ChunkHintedIterator { - iter: 0..9, - chunk_size: 4, - chunk_remaining: 4, - hint_total_size: true, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: false, - }), reference); - assert_eq!(test_iter(WindowHintedIterator { - iter: 0..9, - window_size: 2, - hint_total_size: true, - }), reference); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_shuffle() { - let mut r = crate::test::rng(108); - let empty: &mut [isize] = &mut []; - empty.shuffle(&mut r); - let mut one = [1]; - one.shuffle(&mut r); - let b: &[_] = &[1]; - assert_eq!(one, b); - - let mut two = [1, 2]; - two.shuffle(&mut r); - assert!(two == [1, 2] || two == [2, 1]); - - fn move_last(slice: &mut [usize], pos: usize) { - // use slice[pos..].rotate_left(1); once we can use that - let last_val = slice[pos]; - for i in pos..slice.len() - 1 { - slice[i] = slice[i + 1]; - } - *slice.last_mut().unwrap() = last_val; - } - let mut counts = [0i32; 24]; - for _ in 0..10000 { - let mut arr: [usize; 4] = [0, 1, 2, 3]; - arr.shuffle(&mut r); - let mut permutation = 0usize; - let mut pos_value = counts.len(); - for i in 0..4 { - pos_value /= 4 - i; - let pos = arr.iter().position(|&x| x == i).unwrap(); - assert!(pos < (4 - i)); - permutation += pos * pos_value; - move_last(&mut arr, pos); - assert_eq!(arr[3], i); - } - for i in 0..4 { - assert_eq!(arr[i], i); - } - counts[permutation] += 1; - } - for count in counts.iter() { - // Binomial(10000, 1/24) with average 416.667 - // Octave: binocdf(n, 10000, 1/24) - // 99.9% chance samples lie within this range: - assert!(352 <= *count && *count <= 483, "count: {}", count); - } - } - - #[test] - fn test_partial_shuffle() { - let mut r = crate::test::rng(118); - - let mut empty: [u32; 0] = []; - let res = empty.partial_shuffle(&mut r, 10); - assert_eq!((res.0.len(), res.1.len()), (0, 0)); - - let mut v = [1, 2, 3, 4, 5]; - let res = v.partial_shuffle(&mut r, 2); - assert_eq!((res.0.len(), res.1.len()), (2, 3)); - assert!(res.0[0] != res.0[1]); - // First elements are only modified if selected, so at least one isn't modified: - assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); - } - - #[test] - #[cfg(feature = "alloc")] - fn test_sample_iter() { - let min_val = 1; - let max_val = 100; - - let mut r = crate::test::rng(401); - let vals = (min_val..max_val).collect::>(); - let small_sample = vals.iter().choose_multiple(&mut r, 5); - let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5); - - assert_eq!(small_sample.len(), 5); - assert_eq!(large_sample.len(), vals.len()); - // no randomization happens when amount >= len - assert_eq!(large_sample, vals.iter().collect::>()); - - assert!(small_sample - .iter() - .all(|e| { **e >= min_val && **e <= max_val })); - } - - #[test] - #[cfg(feature = "alloc")] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted() { - let mut r = crate::test::rng(406); - const N_REPS: u32 = 3000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // choose_weighted - fn get_weight(item: &(u32, T)) -> u32 { - item.0 - } - let mut chosen = [0i32; 14]; - let mut items = [(0u32, 0usize); 14]; // (weight, index) - for (i, item) in items.iter_mut().enumerate() { - *item = (weights[i], i); - } - for _ in 0..N_REPS { - let item = items.choose_weighted(&mut r, get_weight).unwrap(); - chosen[item.1] += 1; - } - verify(chosen); - - // choose_weighted_mut - let mut items = [(0u32, 0i32); 14]; // (weight, count) - for (i, item) in items.iter_mut().enumerate() { - *item = (weights[i], 0); - } - for _ in 0..N_REPS { - items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; - } - for (ch, item) in chosen.iter_mut().zip(items.iter()) { - *ch = item.1; - } - verify(chosen); - - // Check error cases - let empty_slice = &mut [10][0..0]; - assert_eq!( - empty_slice.choose_weighted(&mut r, |_| 1), - Err(WeightedError::NoItem) - ); - assert_eq!( - empty_slice.choose_weighted_mut(&mut r, |_| 1), - Err(WeightedError::NoItem) - ); - assert_eq!( - ['x'].choose_weighted_mut(&mut r, |_| 0), - Err(WeightedError::AllWeightsZero) - ); - assert_eq!( - [0, -1].choose_weighted_mut(&mut r, |x| *x), - Err(WeightedError::InvalidWeight) - ); - assert_eq!( - [-1, 0].choose_weighted_mut(&mut r, |x| *x), - Err(WeightedError::InvalidWeight) - ); - } - - #[test] - fn value_stability_choose() { - fn choose>(iter: I) -> Option { - let mut rng = crate::test::rng(411); - iter.choose(&mut rng) - } - - assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(33)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: false, - }), - Some(39) - ); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: true, - }), - Some(39) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: false, - }), - Some(90) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: true, - }), - Some(90) - ); - } - - #[test] - fn value_stability_choose_stable() { - fn choose>(iter: I) -> Option { - let mut rng = crate::test::rng(411); - iter.choose_stable(&mut rng) - } - - assert_eq!(choose([].iter().cloned()), None); - assert_eq!(choose(0..100), Some(40)); - assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40)); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: false, - }), - Some(40) - ); - assert_eq!( - choose(ChunkHintedIterator { - iter: 0..100, - chunk_size: 32, - chunk_remaining: 32, - hint_total_size: true, - }), - Some(40) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: false, - }), - Some(40) - ); - assert_eq!( - choose(WindowHintedIterator { - iter: 0..100, - window_size: 32, - hint_total_size: true, - }), - Some(40) - ); - } - - #[test] - fn value_stability_choose_multiple() { - fn do_test>(iter: I, v: &[u32]) { - let mut rng = crate::test::rng(412); - let mut buf = [0u32; 8]; - assert_eq!(iter.choose_multiple_fill(&mut rng, &mut buf), v.len()); - assert_eq!(&buf[0..v.len()], v); - } - - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); - - #[cfg(feature = "alloc")] - { - fn do_test>(iter: I, v: &[u32]) { - let mut rng = crate::test::rng(412); - assert_eq!(iter.choose_multiple(&mut rng, v.len()), v); - } - - do_test(0..4, &[0, 1, 2, 3]); - do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]); - do_test(0..100, &[58, 78, 80, 92, 43, 8, 96, 7]); - } - } - - #[test] - #[cfg(feature = "alloc")] - fn test_multiple_weighted_edge_cases() { - use super::*; - - let mut rng = crate::test::rng(413); - - // Case 1: One of the weights is 0 - let choices = [('a', 2), ('b', 1), ('c', 0)]; - for _ in 0..100 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - - assert_eq!(result.len(), 2); - assert!(!result.iter().any(|val| val.0 == 'c')); - } - - // Case 2: All of the weights are 0 - let choices = [('a', 0), ('b', 0), ('c', 0)]; - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - assert_eq!(result.len(), 2); - - // Case 3: Negative weights - let choices = [('a', -1), ('b', 1), ('c', 1)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); - - // Case 4: Empty list - let choices = []; - let result = choices - .choose_multiple_weighted(&mut rng, 0, |_: &()| 0) - .unwrap() - .collect::>(); - assert_eq!(result.len(), 0); - - // Case 5: NaN weights - let choices = [('a', core::f64::NAN), ('b', 1.0), ('c', 1.0)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); - - // Case 6: +infinity weights - let choices = [('a', core::f64::INFINITY), ('b', 1.0), ('c', 1.0)]; - for _ in 0..100 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - assert_eq!(result.len(), 2); - assert!(result.iter().any(|val| val.0 == 'a')); - } - - // Case 7: -infinity weights - let choices = [('a', core::f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; - assert_eq!( - choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap_err(), - WeightedError::InvalidWeight - ); - - // Case 8: -0 weights - let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; - assert!(choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .is_ok()); - } - - #[test] - #[cfg(feature = "alloc")] - fn test_multiple_weighted_distributions() { - use super::*; - - // The theoretical probabilities of the different outcomes are: - // AB: 0.5 * 0.5 = 0.250 - // AC: 0.5 * 0.5 = 0.250 - // BA: 0.25 * 0.67 = 0.167 - // BC: 0.25 * 0.33 = 0.082 - // CA: 0.25 * 0.67 = 0.167 - // CB: 0.25 * 0.33 = 0.082 - let choices = [('a', 2), ('b', 1), ('c', 1)]; - let mut rng = crate::test::rng(414); - - let mut results = [0i32; 3]; - let expected_results = [4167, 4167, 1666]; - for _ in 0..10000 { - let result = choices - .choose_multiple_weighted(&mut rng, 2, |item| item.1) - .unwrap() - .collect::>(); - - assert_eq!(result.len(), 2); - - match (result[0].0, result[1].0) { - ('a', 'b') | ('b', 'a') => { - results[0] += 1; - } - ('a', 'c') | ('c', 'a') => { - results[1] += 1; - } - ('b', 'c') | ('c', 'b') => { - results[2] += 1; - } - (_, _) => panic!("unexpected result"), - } - } - - let mut diffs = results - .iter() - .zip(&expected_results) - .map(|(a, b)| (a - b).abs()); - assert!(!diffs.any(|deviation| deviation > 100)); + Some(indices) } } diff --git a/src/seq/slice.rs b/src/seq/slice.rs new file mode 100644 index 00000000000..d48d9d2e9f3 --- /dev/null +++ b/src/seq/slice.rs @@ -0,0 +1,774 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! `IndexedRandom`, `IndexedMutRandom`, `SliceRandom` + +use super::increasing_uniform::IncreasingUniform; +use super::index; +#[cfg(feature = "alloc")] +use crate::distr::uniform::{SampleBorrow, SampleUniform}; +#[cfg(feature = "alloc")] +use crate::distr::weighted::{Error as WeightError, Weight}; +use crate::Rng; +use core::ops::{Index, IndexMut}; + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented on `[T]` slice types. Other types supporting +/// [`std::ops::Index`] may implement this (only [`Self::len`] must be +/// specified). +pub trait IndexedRandom: Index { + /// The length + fn len(&self) -> usize; + + /// True when the length is zero + #[inline] + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Uniformly sample one element + /// + /// Returns a reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + /// + /// # Example + /// + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let choices = [1, 2, 4, 8, 16, 32]; + /// let mut rng = rand::rng(); + /// println!("{:?}", choices.choose(&mut rng)); + /// assert_eq!(choices[..0].choose(&mut rng), None); + /// ``` + fn choose(&self, rng: &mut R) -> Option<&Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + Some(&self[rng.random_range(..self.len())]) + } + } + + /// Uniformly sample `amount` distinct elements from self + /// + /// Chooses `amount` elements from the slice at random, without repetition, + /// and in random order. The returned iterator is appropriate both for + /// collection into a `Vec` and filling an existing buffer (see example). + /// + /// In case this API is not sufficiently flexible, use [`index::sample`]. + /// + /// For slices, complexity is the same as [`index::sample`]. + /// + /// # Example + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let mut rng = &mut rand::rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// // collect the results into a vector: + /// let v: Vec = sample.choose_multiple(&mut rng, 3).cloned().collect(); + /// + /// // store in a buffer: + /// let mut buf = [0u8; 5]; + /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) { + /// *slot = *b; + /// } + /// ``` + #[cfg(feature = "alloc")] + fn choose_multiple(&self, rng: &mut R, amount: usize) -> SliceChooseIter + where + Self::Output: Sized, + R: Rng + ?Sized, + { + let amount = core::cmp::min(amount, self.len()); + SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample(rng, self.len(), amount).into_iter(), + } + } + + /// Uniformly sample a fixed-size array of distinct elements from self + /// + /// Chooses `N` elements from the slice at random, without repetition, + /// and in random order. + /// + /// For slices, complexity is the same as [`index::sample_array`]. + /// + /// # Example + /// ``` + /// use rand::seq::IndexedRandom; + /// + /// let mut rng = &mut rand::rng(); + /// let sample = "Hello, audience!".as_bytes(); + /// + /// let a: [u8; 3] = sample.choose_multiple_array(&mut rng).unwrap(); + /// ``` + fn choose_multiple_array(&self, rng: &mut R) -> Option<[Self::Output; N]> + where + Self::Output: Clone + Sized, + R: Rng + ?Sized, + { + let indices = index::sample_array(rng, self.len())?; + Some(indices.map(|index| self[index].clone())) + } + + /// Biased sampling for one element + /// + /// Returns a reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see the [`WeightedIndex`] distribution. + /// + /// See also [`choose_weighted_mut`]. + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)]; + /// let mut rng = rand::rng(); + /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c', + /// // and 'd' will never be printed + /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0); + /// ``` + /// [`choose`]: IndexedRandom::choose + /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex + #[cfg(feature = "alloc")] + fn choose_weighted( + &self, + rng: &mut R, + weight: F, + ) -> Result<&Self::Output, WeightError> + where + R: Rng + ?Sized, + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + PartialOrd, + { + use crate::distr::{weighted::WeightedIndex, Distribution}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + Ok(&self[distr.sample(rng)]) + } + + /// Biased sampling of `amount` distinct elements + /// + /// Similar to [`choose_multiple`], but where the likelihood of each element's + /// inclusion in the output may be specified. The elements are returned in an + /// arbitrary, unspecified order. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// If all of the weights are equal, even if they are all zero, each element has + /// an equal likelihood of being selected. + /// + /// This implementation uses `O(length + amount)` space and `O(length)` time + /// if the "nightly" feature is enabled, or `O(length)` space and + /// `O(length + amount * log length)` time otherwise. + /// + /// # Known issues + /// + /// The algorithm currently used to implement this method loses accuracy + /// when small values are used for weights. + /// See [#1476](https://github.com/rust-random/rand/issues/1476). + /// + /// # Example + /// + /// ``` + /// use rand::prelude::*; + /// + /// let choices = [('a', 2), ('b', 1), ('c', 1)]; + /// let mut rng = rand::rng(); + /// // First Draw * Second Draw = total odds + /// // ----------------------- + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order. + /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order. + /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order. + /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::>()); + /// ``` + /// [`choose_multiple`]: IndexedRandom::choose_multiple + // Note: this is feature-gated on std due to usage of f64::powf. + // If necessary, we may use alloc+libm as an alternative (see PR #1089). + #[cfg(feature = "std")] + fn choose_multiple_weighted( + &self, + rng: &mut R, + amount: usize, + weight: F, + ) -> Result, WeightError> + where + Self::Output: Sized, + R: Rng + ?Sized, + F: Fn(&Self::Output) -> X, + X: Into, + { + let amount = core::cmp::min(amount, self.len()); + Ok(SliceChooseIter { + slice: self, + _phantom: Default::default(), + indices: index::sample_weighted( + rng, + self.len(), + |idx| weight(&self[idx]).into(), + amount, + )? + .into_iter(), + }) + } +} + +/// Extension trait on indexable lists, providing random sampling methods. +/// +/// This trait is implemented automatically for every type implementing +/// [`IndexedRandom`] and [`std::ops::IndexMut`]. +pub trait IndexedMutRandom: IndexedRandom + IndexMut { + /// Uniformly sample one element (mut) + /// + /// Returns a mutable reference to one uniformly-sampled random element of + /// the slice, or `None` if the slice is empty. + /// + /// For slices, complexity is `O(1)`. + fn choose_mut(&mut self, rng: &mut R) -> Option<&mut Self::Output> + where + R: Rng + ?Sized, + { + if self.is_empty() { + None + } else { + let len = self.len(); + Some(&mut self[rng.random_range(..len)]) + } + } + + /// Biased sampling for one element (mut) + /// + /// Returns a mutable reference to one element of the slice, sampled according + /// to the provided weights. Returns `None` only if the slice is empty. + /// + /// The specified function `weight` maps each item `x` to a relative + /// likelihood `weight(x)`. The probability of each item being selected is + /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// For more information about the underlying algorithm, + /// see the [`WeightedIndex`] distribution. + /// + /// See also [`choose_weighted`]. + /// + /// [`choose_mut`]: IndexedMutRandom::choose_mut + /// [`choose_weighted`]: IndexedRandom::choose_weighted + /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex + #[cfg(feature = "alloc")] + fn choose_weighted_mut( + &mut self, + rng: &mut R, + weight: F, + ) -> Result<&mut Self::Output, WeightError> + where + R: Rng + ?Sized, + F: Fn(&Self::Output) -> B, + B: SampleBorrow, + X: SampleUniform + Weight + PartialOrd, + { + use crate::distr::{weighted::WeightedIndex, Distribution}; + let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?; + let index = distr.sample(rng); + Ok(&mut self[index]) + } +} + +/// Extension trait on slices, providing shuffling methods. +/// +/// This trait is implemented on all `[T]` slice types, providing several +/// methods for choosing and shuffling elements. You must `use` this trait: +/// +/// ``` +/// use rand::seq::SliceRandom; +/// +/// let mut rng = rand::rng(); +/// let mut bytes = "Hello, random!".to_string().into_bytes(); +/// bytes.shuffle(&mut rng); +/// let str = String::from_utf8(bytes).unwrap(); +/// println!("{}", str); +/// ``` +/// Example output (non-deterministic): +/// ```none +/// l,nmroHado !le +/// ``` +pub trait SliceRandom: IndexedMutRandom { + /// Shuffle a mutable slice in place. + /// + /// For slices of length `n`, complexity is `O(n)`. + /// The resulting permutation is picked uniformly from the set of all possible permutations. + /// + /// # Example + /// + /// ``` + /// use rand::seq::SliceRandom; + /// + /// let mut rng = rand::rng(); + /// let mut y = [1, 2, 3, 4, 5]; + /// println!("Unshuffled: {:?}", y); + /// y.shuffle(&mut rng); + /// println!("Shuffled: {:?}", y); + /// ``` + fn shuffle(&mut self, rng: &mut R) + where + R: Rng + ?Sized; + + /// Shuffle a slice in place, but exit early. + /// + /// Returns two mutable slices from the source slice. The first contains + /// `amount` elements randomly permuted. The second has the remaining + /// elements that are not fully shuffled. + /// + /// This is an efficient method to select `amount` elements at random from + /// the slice, provided the slice may be mutated. + /// + /// If you only need to choose elements randomly and `amount > self.len()/2` + /// then you may improve performance by taking + /// `amount = self.len() - amount` and using only the second slice. + /// + /// If `amount` is greater than the number of elements in the slice, this + /// will perform a full shuffle. + /// + /// For slices, complexity is `O(m)` where `m = amount`. + fn partial_shuffle( + &mut self, + rng: &mut R, + amount: usize, + ) -> (&mut [Self::Output], &mut [Self::Output]) + where + Self::Output: Sized, + R: Rng + ?Sized; +} + +impl IndexedRandom for [T] { + fn len(&self) -> usize { + self.len() + } +} + +impl + ?Sized> IndexedMutRandom for IR {} + +impl SliceRandom for [T] { + fn shuffle(&mut self, rng: &mut R) + where + R: Rng + ?Sized, + { + if self.len() <= 1 { + // There is no need to shuffle an empty or single element slice + return; + } + self.partial_shuffle(rng, self.len()); + } + + fn partial_shuffle(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T]) + where + R: Rng + ?Sized, + { + let m = self.len().saturating_sub(amount); + + // The algorithm below is based on Durstenfeld's algorithm for the + // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm) + // for an unbiased permutation. + // It ensures that the last `amount` elements of the slice + // are randomly selected from the whole slice. + + // `IncreasingUniform::next_index()` is faster than `Rng::random_range` + // but only works for 32 bit integers + // So we must use the slow method if the slice is longer than that. + if self.len() < (u32::MAX as usize) { + let mut chooser = IncreasingUniform::new(rng, m as u32); + for i in m..self.len() { + let index = chooser.next_index(); + self.swap(i, index); + } + } else { + for i in m..self.len() { + let index = rng.random_range(..i + 1); + self.swap(i, index); + } + } + let r = self.split_at_mut(m); + (r.1, r.0) + } +} + +/// An iterator over multiple slice elements. +/// +/// This struct is created by +/// [`IndexedRandom::choose_multiple`](trait.IndexedRandom.html#tymethod.choose_multiple). +#[cfg(feature = "alloc")] +#[derive(Debug)] +pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> { + slice: &'a S, + _phantom: core::marker::PhantomData, + indices: index::IndexVecIntoIter, +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> Iterator for SliceChooseIter<'a, S, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + // TODO: investigate using SliceIndex::get_unchecked when stable + self.indices.next().map(|i| &self.slice[i]) + } + + fn size_hint(&self) -> (usize, Option) { + (self.indices.len(), Some(self.indices.len())) + } +} + +#[cfg(feature = "alloc")] +impl<'a, S: Index + ?Sized + 'a, T: 'a> ExactSizeIterator + for SliceChooseIter<'a, S, T> +{ + fn len(&self) -> usize { + self.indices.len() + } +} + +#[cfg(test)] +mod test { + use super::*; + #[cfg(feature = "alloc")] + use alloc::vec::Vec; + + #[test] + fn test_slice_choose() { + let mut r = crate::test::rng(107); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut chosen = [0i32; 14]; + // The below all use a binomial distribution with n=1000, p=1/14. + // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5 + for _ in 0..1000 { + let picked = *chars.choose(&mut r).unwrap(); + chosen[(picked as usize) - ('a' as usize)] += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + chosen.iter_mut().for_each(|x| *x = 0); + for _ in 0..1000 { + *chosen.choose_mut(&mut r).unwrap() += 1; + } + for count in chosen.iter() { + assert!(40 < *count && *count < 106); + } + + let mut v: [isize; 0] = []; + assert_eq!(v.choose(&mut r), None); + assert_eq!(v.choose_mut(&mut r), None); + } + + #[test] + fn value_stability_slice() { + let mut r = crate::test::rng(413); + let chars = [ + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + ]; + let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + + assert_eq!(chars.choose(&mut r), Some(&'l')); + assert_eq!(nums.choose_mut(&mut r), Some(&mut 3)); + + assert_eq!( + &chars.choose_multiple_array(&mut r), + &Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k']) + ); + + #[cfg(feature = "alloc")] + assert_eq!( + &chars + .choose_multiple(&mut r, 8) + .cloned() + .collect::>(), + &['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f'] + ); + + #[cfg(feature = "alloc")] + assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i')); + #[cfg(feature = "alloc")] + assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2)); + + let mut r = crate::test::rng(414); + nums.shuffle(&mut r); + assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]); + nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let res = nums.partial_shuffle(&mut r, 6); + assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]); + assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]); + } + + #[test] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_shuffle() { + let mut r = crate::test::rng(108); + let empty: &mut [isize] = &mut []; + empty.shuffle(&mut r); + let mut one = [1]; + one.shuffle(&mut r); + let b: &[_] = &[1]; + assert_eq!(one, b); + + let mut two = [1, 2]; + two.shuffle(&mut r); + assert!(two == [1, 2] || two == [2, 1]); + + fn move_last(slice: &mut [usize], pos: usize) { + // use slice[pos..].rotate_left(1); once we can use that + let last_val = slice[pos]; + for i in pos..slice.len() - 1 { + slice[i] = slice[i + 1]; + } + *slice.last_mut().unwrap() = last_val; + } + let mut counts = [0i32; 24]; + for _ in 0..10000 { + let mut arr: [usize; 4] = [0, 1, 2, 3]; + arr.shuffle(&mut r); + let mut permutation = 0usize; + let mut pos_value = counts.len(); + for i in 0..4 { + pos_value /= 4 - i; + let pos = arr.iter().position(|&x| x == i).unwrap(); + assert!(pos < (4 - i)); + permutation += pos * pos_value; + move_last(&mut arr, pos); + assert_eq!(arr[3], i); + } + for (i, &a) in arr.iter().enumerate() { + assert_eq!(a, i); + } + counts[permutation] += 1; + } + for count in counts.iter() { + // Binomial(10000, 1/24) with average 416.667 + // Octave: binocdf(n, 10000, 1/24) + // 99.9% chance samples lie within this range: + assert!(352 <= *count && *count <= 483, "count: {}", count); + } + } + + #[test] + fn test_partial_shuffle() { + let mut r = crate::test::rng(118); + + let mut empty: [u32; 0] = []; + let res = empty.partial_shuffle(&mut r, 10); + assert_eq!((res.0.len(), res.1.len()), (0, 0)); + + let mut v = [1, 2, 3, 4, 5]; + let res = v.partial_shuffle(&mut r, 2); + assert_eq!((res.0.len(), res.1.len()), (2, 3)); + assert!(res.0[0] != res.0[1]); + // First elements are only modified if selected, so at least one isn't modified: + assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3); + } + + #[test] + #[cfg(feature = "alloc")] + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weighted() { + let mut r = crate::test::rng(406); + const N_REPS: u32 = 3000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // choose_weighted + fn get_weight(item: &(u32, T)) -> u32 { + item.0 + } + let mut chosen = [0i32; 14]; + let mut items = [(0u32, 0usize); 14]; // (weight, index) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], i); + } + for _ in 0..N_REPS { + let item = items.choose_weighted(&mut r, get_weight).unwrap(); + chosen[item.1] += 1; + } + verify(chosen); + + // choose_weighted_mut + let mut items = [(0u32, 0i32); 14]; // (weight, count) + for (i, item) in items.iter_mut().enumerate() { + *item = (weights[i], 0); + } + for _ in 0..N_REPS { + items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1; + } + for (ch, item) in chosen.iter_mut().zip(items.iter()) { + *ch = item.1; + } + verify(chosen); + + // Check error cases + let empty_slice = &mut [10][0..0]; + assert_eq!( + empty_slice.choose_weighted(&mut r, |_| 1), + Err(WeightError::InvalidInput) + ); + assert_eq!( + empty_slice.choose_weighted_mut(&mut r, |_| 1), + Err(WeightError::InvalidInput) + ); + assert_eq!( + ['x'].choose_weighted_mut(&mut r, |_| 0), + Err(WeightError::InsufficientNonZero) + ); + assert_eq!( + [0, -1].choose_weighted_mut(&mut r, |x| *x), + Err(WeightError::InvalidWeight) + ); + assert_eq!( + [-1, 0].choose_weighted_mut(&mut r, |x| *x), + Err(WeightError::InvalidWeight) + ); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_edge_cases() { + use super::*; + + let mut rng = crate::test::rng(413); + + // Case 1: One of the weights is 0 + let choices = [('a', 2), ('b', 1), ('c', 0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + assert!(!result.iter().any(|val| val.0 == 'c')); + } + + // Case 2: All of the weights are 0 + let choices = [('a', 0), ('b', 0), ('c', 0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); + + // Case 3: Negative weights + let choices = [('a', -1), ('b', 1), ('c', 1)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 4: Empty list + let choices = []; + let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0); + assert_eq!(r.unwrap().count(), 0); + + // Case 5: NaN weights + let choices = [('a', f64::NAN), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 6: +infinity weights + let choices = [('a', f64::INFINITY), ('b', 1.0), ('c', 1.0)]; + for _ in 0..100 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + assert_eq!(result.len(), 2); + assert!(result.iter().any(|val| val.0 == 'a')); + } + + // Case 7: -infinity weights + let choices = [('a', f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert_eq!(r.unwrap_err(), WeightError::InvalidWeight); + + // Case 8: -0 weights + let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)]; + let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); + assert!(r.is_ok()); + } + + #[test] + #[cfg(feature = "std")] + fn test_multiple_weighted_distributions() { + use super::*; + + // The theoretical probabilities of the different outcomes are: + // AB: 0.5 * 0.667 = 0.3333 + // AC: 0.5 * 0.333 = 0.1667 + // BA: 0.333 * 0.75 = 0.25 + // BC: 0.333 * 0.25 = 0.0833 + // CA: 0.167 * 0.6 = 0.1 + // CB: 0.167 * 0.4 = 0.0667 + let choices = [('a', 3), ('b', 2), ('c', 1)]; + let mut rng = crate::test::rng(414); + + let mut results = [0i32; 3]; + let expected_results = [5833, 2667, 1500]; + for _ in 0..10000 { + let result = choices + .choose_multiple_weighted(&mut rng, 2, |item| item.1) + .unwrap() + .collect::>(); + + assert_eq!(result.len(), 2); + + match (result[0].0, result[1].0) { + ('a', 'b') | ('b', 'a') => { + results[0] += 1; + } + ('a', 'c') | ('c', 'a') => { + results[1] += 1; + } + ('b', 'c') | ('c', 'b') => { + results[2] += 1; + } + (_, _) => panic!("unexpected result"), + } + } + + let mut diffs = results + .iter() + .zip(&expected_results) + .map(|(a, b)| (a - b).abs()); + assert!(!diffs.any(|deviation| deviation > 100)); + } +} diff --git a/utils/ziggurat_tables.py b/utils/ziggurat_tables.py index 88cfdab6ba2..87a766ccc36 100755 --- a/utils/ziggurat_tables.py +++ b/utils/ziggurat_tables.py @@ -10,7 +10,7 @@ # except according to those terms. # This creates the tables used for distributions implemented using the -# ziggurat algorithm in `rand::distributions;`. They are +# ziggurat algorithm in `rand::distr;`. They are # (basically) the tables as used in the ZIGNOR variant (Doornik 2005). # They are changed rarely, so the generated file should be checked in # to git.