diff --git a/.github/workflows/benches.yml b/.github/workflows/benches.yml index 22b4baa8dce..368023aba9d 100644 --- a/.github/workflows/benches.yml +++ b/.github/workflows/benches.yml @@ -5,13 +5,11 @@ on: branches: [ master ] paths-ignore: - "**.md" - - "distr_test/**" - "examples/**" pull_request: branches: [ master ] paths-ignore: - "**.md" - - "distr_test/**" - "examples/**" defaults: diff --git a/.github/workflows/distr_test.yml b/.github/workflows/distr_test.yml deleted file mode 100644 index f2b7f814c98..00000000000 --- a/.github/workflows/distr_test.yml +++ /dev/null @@ -1,43 +0,0 @@ -name: distr_test - -on: - push: - branches: [ master ] - paths-ignore: - - "**.md" - - "benches/**" - - "examples/**" - pull_request: - branches: [ master ] - paths-ignore: - - "**.md" - - "benches/**" - - "examples/**" - -defaults: - run: - working-directory: ./distr_test - -jobs: - clippy-fmt: - name: "distr_test: Check Clippy and rustfmt" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master - with: - toolchain: stable - components: clippy, rustfmt - - name: Rustfmt - run: cargo fmt -- --check - - name: Clippy - run: cargo clippy --all-targets -- -D warnings - ks-tests: - name: "distr_test: Run Komogorov Smirnov tests" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@master - with: - toolchain: nightly - - run: cargo test --release diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 293d5f4942d..ad0cf1425cc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,13 +6,11 @@ on: paths-ignore: - "**.md" - "benches/**" - - "distr_test/**" pull_request: branches: [ master, '0.[0-9]+' ] paths-ignore: - "**.md" - "benches/**" - - "distr_test/**" permissions: contents: read # to fetch code (actions/checkout) @@ -47,8 +45,6 @@ jobs: run: cargo doc --all-features --no-deps - name: rand_core run: cargo doc --all-features --package rand_core --no-deps - - name: rand_distr - run: cargo doc --all-features --package rand_distr --no-deps - name: rand_chacha run: cargo doc --all-features --package rand_chacha --no-deps - name: rand_pcg @@ -122,11 +118,6 @@ jobs: cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features cargo test --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml --no-default-features --features=os_rng - - name: Test rand_distr - run: | - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features - cargo test --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --no-default-features --features=std,std_math - name: Test rand_pcg run: cargo test --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde - name: Test rand_chacha @@ -162,7 +153,6 @@ jobs: cross test --no-fail-fast --target ${{ matrix.target }} --features=serde,log,small_rng cross test --no-fail-fast --target ${{ matrix.target }} --examples cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_core/Cargo.toml - cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_distr/Cargo.toml --features=serde cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_pcg/Cargo.toml --features=serde cross test --no-fail-fast --target ${{ matrix.target }} --manifest-path rand_chacha/Cargo.toml @@ -182,7 +172,6 @@ jobs: cargo miri test --manifest-path rand_core/Cargo.toml cargo miri test --manifest-path rand_core/Cargo.toml --features=serde cargo miri test --manifest-path rand_core/Cargo.toml --no-default-features - #cargo miri test --manifest-path rand_distr/Cargo.toml # no unsafe and lots of slow tests cargo miri test --manifest-path rand_pcg/Cargo.toml --features=serde cargo miri test --manifest-path rand_chacha/Cargo.toml --no-default-features diff --git a/CHANGELOG.md b/CHANGELOG.md index fded9d79aca..891db26a9f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,25 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md). You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful. +## [Unreleased] +### Deprecated +- Deprecate `rand::rngs::mock` module and `StepRng` generator (#1634) + +## [0.9.1] - 2025-04-17 +### Security and unsafe +- Revise "not a crypto library" policy again (#1565) +- Remove `zerocopy` dependency from `rand` (#1579) + +### Fixes +- Fix feature `simd_support` for recent nightly rust (#1586) + +### Changes +- Allow `fn rand::seq::index::sample_weighted` and `fn IndexedRandom::choose_multiple_weighted` to return fewer than `amount` results (#1623), reverting an undocumented change (#1382) to the previous release. + +### Additions +- Add `rand::distr::Alphabetic` distribution. (#1587) +- Re-export `rand_core` (#1604) + ## [0.9.0] - 2025-01-27 ### Security and unsafe - Policy: "rand is not a crypto library" (#1514) diff --git a/Cargo.toml b/Cargo.toml index 956f12741fc..523c8d3c867 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand" -version = "0.9.0" +version = "0.9.1" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -43,7 +43,7 @@ alloc = [] os_rng = ["rand_core/os_rng"] # Option (requires nightly Rust): experimental SIMD support -simd_support = ["zerocopy/simd-nightly"] +simd_support = [] # Option (enabled by default): enable StdRng std_rng = ["dep:rand_chacha"] @@ -65,7 +65,6 @@ log = ["dep:log"] [workspace] members = [ "rand_core", - "rand_distr", "rand_chacha", "rand_pcg", ] @@ -76,7 +75,6 @@ rand_core = { path = "rand_core", version = "0.9.0", default-features = false } log = { version = "0.4.4", optional = true } serde = { version = "1.0.103", features = ["derive"], optional = true } rand_chacha = { path = "rand_chacha", version = "0.9.0", default-features = false, optional = true } -zerocopy = { version = "0.8.0", default-features = false, features = ["simd"] } [dev-dependencies] rand_pcg = { path = "rand_pcg", version = "0.9.0" } diff --git a/README.md b/README.md index 740807a9669..e8b6fe3d337 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,12 @@ Rand is a set of crates supporting (pseudo-)random generators: With broad support for random value generation and random processes: -- [`StandardUniform`](https://docs.rs/rand/latest/rand/distributions/struct.StandardUniform.html) random value sampling, - [`Uniform`](https://docs.rs/rand/latest/rand/distributions/struct.Uniform.html)-ranged value sampling +- [`StandardUniform`](https://docs.rs/rand/latest/rand/distr/struct.StandardUniform.html) random value sampling, + [`Uniform`](https://docs.rs/rand/latest/rand/distr/struct.Uniform.html)-ranged value sampling and [more](https://docs.rs/rand/latest/rand/distr/index.html) - Samplers for a large number of non-uniform random number distributions via our own [`rand_distr`](https://docs.rs/rand_distr) and via - the [`statrs`](https://docs.rs/statrs/0.13.0/statrs/) + the [`statrs`](https://docs.rs/statrs) - Random processes (mostly choose and shuffle) via [`rand::seq`](https://docs.rs/rand/latest/rand/seq/index.html) traits All with: @@ -39,12 +39,11 @@ Rand **is not**: not simplicity. If you prefer a small-and-simple library, there are alternatives including [fastrand](https://crates.io/crates/fastrand) and [oorandom](https://crates.io/crates/oorandom). -- A cryptography library. Rand provides functionality for generating - unpredictable random data (potentially applicable depending on requirements) - but does not provide high-level cryptography functionality. - -Rand is a community project and cannot provide legally-binding guarantees of -security. +- Primarily a cryptographic library. `rand` does provide some generators which + aim to support unpredictable value generation under certain constraints; + see [SECURITY.md](https://github.com/rust-random/rand/blob/master/SECURITY.md) for details. + Users are expected to determine for themselves + whether `rand`'s functionality meets their own security requirements. Documentation: @@ -56,11 +55,11 @@ Documentation: ## Versions Rand is *mature* (suitable for general usage, with infrequent breaking releases -which minimise breakage) but not yet at 1.0. Current versions are: +which minimise breakage) but not yet at 1.0. Current `MAJOR.MINOR` versions are: - Version 0.9 was released in January 2025. -See the [CHANGELOG](CHANGELOG.md) or [Upgrade Guide](https://rust-random.github.io/book/update.html) for more details. +See the [CHANGELOG](https://github.com/rust-random/rand/blob/master/CHANGELOG.md) or [Upgrade Guide](https://rust-random.github.io/book/update.html) for more details. ## Crate Features @@ -70,6 +69,7 @@ Rand is built with these features enabled by default: - `alloc` (implied by `std`) enables functionality requiring an allocator - `os_rng` (implied by `std`) enables `rngs::OsRng`, using the [getrandom] crate - `std_rng` enables inclusion of `StdRng`, `ThreadRng` +- `small_rng` enables inclusion of the `SmallRng` PRNG Optionally, the following dependencies can be enabled: @@ -77,10 +77,14 @@ Optionally, the following dependencies can be enabled: Additionally, these features configure Rand: -- `small_rng` enables inclusion of the `SmallRng` PRNG - `nightly` includes some additions requiring nightly Rust - `simd_support` (experimental) enables sampling of SIMD values (uniformly random SIMD integers and floats), requiring nightly Rust +- `unbiased` use unbiased sampling for algorithms supporting this option: Uniform distribution. + + (By default, bias affecting no more than one in 2^48 samples is accepted.) + + Note: enabling this option is expected to affect reproducibility of results. Note that nightly features are not stable and therefore not all library and compiler versions will be compatible. This is especially true of Rand's @@ -97,23 +101,20 @@ Many (but not all) algorithms are intended to have reproducible output. Read mor The Rand library supports a variety of CPU architectures. Platform integration is outsourced to [getrandom]. -### WASM support +### WebAssembly support -Seeding entropy from OS on WASM target `wasm32-unknown-unknown` is not -*automatically* supported by `rand` or `getrandom`. If you are fine with -seeding the generator manually, you can disable the `os_rng` feature -and use the methods on the `SeedableRng` trait. To enable seeding from OS, -either use a different target such as `wasm32-wasi` or add a direct -dependency on [getrandom] with the `js` feature (if the target supports -JavaScript). See -[getrandom#WebAssembly support](https://docs.rs/getrandom/latest/getrandom/#webassembly-support). +The [WASI](https://github.com/WebAssembly/WASI/tree/main) and Emscripten +targets are directly supported. The `wasm32-unknown-unknown` target is not +*automatically* supported. To enable support for this target, refer to the +[`getrandom` documentation for WebAssembly](https://docs.rs/getrandom/latest/getrandom/#webassembly-support). +Alternatively, the `os_rng` feature may be disabled. # License Rand is distributed under the terms of both the MIT license and the Apache License (Version 2.0). -See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and -[COPYRIGHT](COPYRIGHT) for details. +See [LICENSE-APACHE](https://github.com/rust-random/rand/blob/master/LICENSE-APACHE) and [LICENSE-MIT](https://github.com/rust-random/rand/blob/master/LICENSE-MIT), and +[COPYRIGHT](https://github.com/rust-random/rand/blob/master/COPYRIGHT) for details. [getrandom]: https://crates.io/crates/getrandom diff --git a/SECURITY.md b/SECURITY.md index 26cf7c12fc5..f1a61b0d208 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -10,12 +10,24 @@ security. ### Marker traits Rand provides the marker traits `CryptoRng`, `TryCryptoRng` and -`CryptoBlockRng`. Generators implementing one of these traits and used in a way -which meets the following additional constraints: - -- Instances of seedable RNGs (those implementing `SeedableRng`) are - constructed with cryptographically secure seed values -- The state (memory) of the RNG and its seed value are not exposed +`CryptoBlockRng`. Generators (RNGs) implementing one of these traits which are +used according to these additional constraints: + +- The generator may be constructed using `std::default::Default` where the + generator supports this trait. Note that generators should *only* support + `Default` where the `default()` instance is appropriately seeded: for + example `OsRng` has no state and thus has a trivial `default()` instance + while `ThreadRng::default()` returns a handle to a thread-local instance + seeded using `OsRng`. +- The generator may be constructed using `rand_core::SeedableRng` in any of + the following ways where the generator supports this trait: + + - Via `SeedableRng::from_seed` using a cryptographically secure seed value + - Via `SeedableRng::from_rng` or `try_from_rng` using a cryptographically + secure source `rng` + - Via `SeedableRng::from_os_rng` or `try_from_os_rng` +- The state (memory) of the generator and its seed value (or source `rng`) are + not exposed are expected to provide the following: @@ -34,48 +46,44 @@ are expected to provide the following: `OsRng` is a stateless "generator" implemented via [getrandom]. As such, it has no possible state to leak and cannot be improperly seeded. -`ThreadRng` will periodically reseed itself, thus placing an upper bound on the -number of bits of output from an instance before any advantage an attacker may -have gained through state-compromising side-channel attacks is lost. +`StdRng` is a `CryptoRng` and `SeedableRng` using a pseudo-random algorithm +selected for good security and performance qualities. Since it does not offer +reproducibility of output, its algorithm may be changed in any release version. + +`ChaCha12Rng` and `ChaCha20Rng` are selected pseudo-random generators +distributed by the `rand` project which meet the requirements of the `CryptoRng` +trait and implement `SeedableRng` with a commitment to reproducibility of +results. + +`ThreadRng` is a conveniently-packaged generator over `StdRng` offering +automatic seeding from `OsRng`, periodic reseeding and thread locality. +This random source is intended to offer a good compromise between cryptographic +security, fast generation with reasonably low memory and initialization cost +overheads, and robustness against misuse. [getrandom]: https://crates.io/crates/getrandom ### Distributions -Additionally, derivations from such an RNG (including the `Rng` trait, -implementations of the `Distribution` trait, and `seq` algorithms) should not -introduce significant bias other than that expected from the operation in -question (e.g. bias from a weighted distribution). +Methods of the `Rng` trait, functionality of the `rand::seq` module and +implementators of the `Distribution` trait are expected, while using a +cryptographically secure `CryptoRng` instance meeting the above constraints, +to not introduce significant bias to their operation beyond what would be +expected of the operation. Note that the usage of 'significant' here permits +some bias, as noted for example in the documentation of the `Uniform` +distribution. ## Supported Versions -We will attempt to uphold these premises in the following crate versions, -provided that only the latest patch version is used, and with potential -exceptions for theoretical issues without a known exploit: - -| Crate | Versions | Exceptions | -| ----- | -------- | ---------- | -| `rand` | 0.8 | | -| `rand` | 0.7 | | -| `rand` | 0.5, 0.6 | Jitter | -| `rand` | 0.4 | Jitter, ISAAC | -| `rand_core` | 0.2 - 0.6 | | -| `rand_chacha` | 0.1 - 0.3 | | +We aim to provide security fixes in the form of a new patch version for the +latest release version of `rand` and its dependencies `rand_core` and +`rand_chacha`, as well as for prior major and minor releases which were, at some +time during the previous 12 months, the latest release version. -Explanation of exceptions: - -- Jitter: `JitterRng` is used as an entropy source when the primary source - fails; this source may not be secure against side-channel attacks, see #699. -- ISAAC: the [ISAAC](https://burtleburtle.net/bob/rand/isaacafa.html) RNG used - to implement `ThreadRng` is difficult to analyse and thus cannot provide - strong assertions of security. - -## Known issues +## Reporting a Vulnerability -In `rand` version 0.3 (0.3.18 and later), if `OsRng` fails, `ThreadRng` is -seeded from the system time in an insecure manner. +If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. -## Reporting a Vulnerability +Please disclose it at [security advisory](https://github.com/rust-random/rand/security/advisories/new). -To report a vulnerability, [open a new issue](https://github.com/rust-random/rand/issues/new). -Once the issue is resolved, the vulnerability should be [reported to RustSec](https://github.com/RustSec/advisory-db/blob/master/CONTRIBUTING.md). +This project is maintained by a team of volunteers on a reasonable-effort basis. As such, please give us at least 90 days to work on a fix before public exposure. diff --git a/benches/Cargo.toml b/benches/Cargo.toml index a143bff3c02..adb9aadd84b 100644 --- a/benches/Cargo.toml +++ b/benches/Cargo.toml @@ -4,13 +4,16 @@ version = "0.1.0" edition = "2021" publish = false +[features] +# Option (requires nightly Rust): experimental SIMD support +simd_support = ["rand/simd_support"] + [dependencies] [dev-dependencies] rand = { path = "..", features = ["small_rng", "nightly"] } rand_pcg = { path = "../rand_pcg" } rand_chacha = { path = "../rand_chacha" } -rand_distr = { path = "../rand_distr" } criterion = "0.5" criterion-cycles-per-byte = "0.6" @@ -22,10 +25,6 @@ harness = false name = "bool" harness = false -[[bench]] -name = "distr" -harness = false - [[bench]] name = "generators" harness = false @@ -38,6 +37,10 @@ harness = false name = "shuffle" harness = false +[[bench]] +name = "simd" +harness = false + [[bench]] name = "standard" harness = false diff --git a/benches/benches/distr.rs b/benches/benches/distr.rs deleted file mode 100644 index 3a76211972d..00000000000 --- a/benches/benches/distr.rs +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2018-2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use criterion_cycles_per_byte::CyclesPerByte; - -use rand::prelude::*; -use rand_distr::weighted::*; -use rand_distr::*; - -// At this time, distributions are optimised for 64-bit platforms. -use rand_pcg::Pcg64Mcg; - -const ITER_ELTS: u64 = 100; - -macro_rules! distr_int { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); - let distr = $distr; - - c.iter(|| distr.sample(&mut rng)); - }); - }; -} - -macro_rules! distr_float { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); - let distr = $distr; - - c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng)); - }); - }; -} - -macro_rules! distr_arr { - ($group:ident, $fnn:expr, $ty:ty, $distr:expr) => { - $group.bench_function($fnn, |c| { - let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); - let distr = $distr; - - c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng)); - }); - }; -} - -macro_rules! sample_binomial { - ($group:ident, $name:expr, $n:expr, $p:expr) => { - distr_int!($group, $name, u64, Binomial::new($n, $p).unwrap()) - }; -} - -fn bench(c: &mut Criterion) { - let mut g = c.benchmark_group("exp"); - distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap()); - distr_float!(g, "exp1_specialized", f64, Exp1); - distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("normal"); - distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap()); - distr_float!(g, "standardnormal_specialized", f64, StandardNormal); - distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap()); - distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap()); - g.throughput(Throughput::Elements(ITER_ELTS)); - g.bench_function("iter", |c| { - use core::f64::consts::{E, PI}; - let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); - let distr = Normal::new(-E, PI).unwrap(); - - c.iter(|| { - distr - .sample_iter(&mut rng) - .take(ITER_ELTS as usize) - .fold(0.0, |a, r| a + r) - }); - }); - g.finish(); - - let mut g = c.benchmark_group("skew_normal"); - distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap()); - distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap()); - distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("gamma"); - distr_float!(g, "large_shape", f64, Gamma::new(10., 1.0).unwrap()); - distr_float!(g, "small_shape", f64, Gamma::new(0.1, 1.0).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("beta"); - distr_float!(g, "small_param", f64, Beta::new(0.1, 0.1).unwrap()); - distr_float!(g, "large_param_similar", f64, Beta::new(101., 95.).unwrap()); - distr_float!(g, "large_param_different", f64, Beta::new(10., 1000.).unwrap()); - distr_float!(g, "mixed_param", f64, Beta::new(0.5, 100.).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("cauchy"); - distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("triangular"); - distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("geometric"); - distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap()); - distr_int!(g, "standard_geometric", u64, StandardGeometric); - g.finish(); - - let mut g = c.benchmark_group("weighted"); - distr_int!(g, "i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0 / 3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap()); - distr_int!(g, "large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap()); - distr_int!(g, "alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!(g, "alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap()); - distr_int!( - g, - "alias_method_f64", - usize, - WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0 / 3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap() - ); - distr_int!( - g, - "alias_method_large_set", - usize, - WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap() - ); - g.finish(); - - let mut g = c.benchmark_group("binomial"); - sample_binomial!(g, "small", 1_000_000, 1e-30); - sample_binomial!(g, "1", 1, 0.9); - sample_binomial!(g, "10", 10, 0.9); - sample_binomial!(g, "100", 100, 0.99); - sample_binomial!(g, "1000", 1000, 0.01); - sample_binomial!(g, "1e12", 1_000_000_000_000, 0.2); - g.finish(); - - let mut g = c.benchmark_group("poisson"); - for lambda in [1f64, 4.0, 10.0, 100.0].into_iter() { - let name = format!("{lambda}"); - distr_float!(g, name, f64, Poisson::new(lambda).unwrap()); - } - g.throughput(Throughput::Elements(ITER_ELTS)); - g.bench_function("variable", |c| { - let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); - let ldistr = Uniform::new(0.1, 10.0).unwrap(); - - c.iter(|| { - let l = rng.sample(ldistr); - let distr = Poisson::new(l * l).unwrap(); - Distribution::::sample_iter(&distr, &mut rng) - .take(ITER_ELTS as usize) - .fold(0.0, |a, r| a + r) - }) - }); - g.finish(); - - let mut g = c.benchmark_group("zipf"); - distr_float!(g, "zipf", f64, Zipf::new(10.0, 1.5).unwrap()); - distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap()); - g.finish(); - - let mut g = c.benchmark_group("bernoulli"); - g.bench_function("bernoulli", |c| { - let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); - let distr = Bernoulli::new(0.18).unwrap(); - c.iter(|| distr.sample(&mut rng)) - }); - g.finish(); - - let mut g = c.benchmark_group("unit"); - distr_arr!(g, "circle", [f64; 2], UnitCircle); - distr_arr!(g, "sphere", [f64; 3], UnitSphere); - g.finish(); -} - -criterion_group!( - name = benches; - config = Criterion::default().with_measurement(CyclesPerByte) - .warm_up_time(core::time::Duration::from_secs(1)) - .measurement_time(core::time::Duration::from_secs(2)); - targets = bench -); -criterion_main!(benches); diff --git a/benches/benches/generators.rs b/benches/benches/generators.rs index 64325ceb9ee..31f08a02408 100644 --- a/benches/benches/generators.rs +++ b/benches/benches/generators.rs @@ -10,8 +10,8 @@ use core::time::Duration; use criterion::measurement::WallTime; use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion}; use rand::prelude::*; +use rand::rngs::OsRng; use rand::rngs::ReseedingRng; -use rand::rngs::{mock::StepRng, OsRng}; use rand_chacha::rand_core::UnwrapErr; use rand_chacha::{ChaCha12Rng, ChaCha20Core, ChaCha20Rng, ChaCha8Rng}; use rand_pcg::{Pcg32, Pcg64, Pcg64Dxsm, Pcg64Mcg}; @@ -39,7 +39,6 @@ pub fn random_bytes(c: &mut Criterion) { }); } - bench(&mut g, "step", StepRng::new(0, 1)); bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); @@ -68,7 +67,6 @@ pub fn random_u32(c: &mut Criterion) { }); } - bench(&mut g, "step", StepRng::new(0, 1)); bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); @@ -97,7 +95,6 @@ pub fn random_u64(c: &mut Criterion) { }); } - bench(&mut g, "step", StepRng::new(0, 1)); bench(&mut g, "pcg32", Pcg32::from_rng(&mut rand::rng())); bench(&mut g, "pcg64", Pcg64::from_rng(&mut rand::rng())); bench(&mut g, "pcg64mcg", Pcg64Mcg::from_rng(&mut rand::rng())); diff --git a/benches/benches/simd.rs b/benches/benches/simd.rs new file mode 100644 index 00000000000..f1723245977 --- /dev/null +++ b/benches/benches/simd.rs @@ -0,0 +1,76 @@ +// Copyright 2018-2023 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! Generating SIMD / wide types + +#![cfg_attr(feature = "simd_support", feature(portable_simd))] + +use criterion::{criterion_group, criterion_main, Criterion}; + +criterion_group!( + name = benches; + config = Criterion::default(); + targets = simd +); +criterion_main!(benches); + +#[cfg(not(feature = "simd_support"))] +pub fn simd(_: &mut Criterion) {} + +#[cfg(feature = "simd_support")] +pub fn simd(c: &mut Criterion) { + use rand::prelude::*; + use rand_pcg::Pcg64Mcg; + + let mut g = c.benchmark_group("random_simd"); + + g.bench_function("u128", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("m128i", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("m256i", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("m512i", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u64x2", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u32x4", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u32x8", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u16x8", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); + + g.bench_function("u8x16", |b| { + let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); + b.iter(|| rng.random::()); + }); +} diff --git a/benches/benches/standard.rs b/benches/benches/standard.rs index ac38f0225f8..de95fb5ba69 100644 --- a/benches/benches/standard.rs +++ b/benches/benches/standard.rs @@ -9,9 +9,8 @@ use core::time::Duration; use criterion::measurement::WallTime; use criterion::{criterion_group, criterion_main, BenchmarkGroup, Criterion}; -use rand::distr::{Alphanumeric, StandardUniform}; +use rand::distr::{Alphabetic, Alphanumeric, Open01, OpenClosed01, StandardUniform}; use rand::prelude::*; -use rand_distr::{Open01, OpenClosed01}; use rand_pcg::Pcg64Mcg; criterion_group!( @@ -25,7 +24,7 @@ fn bench_ty(g: &mut BenchmarkGroup, name: &str) where D: Distribution + Default, { - g.throughput(criterion::Throughput::Bytes(size_of::() as u64)); + g.throughput(criterion::Throughput::Bytes(core::mem::size_of::() as u64)); g.bench_function(name, |b| { let mut rng = Pcg64Mcg::from_rng(&mut rand::rng()); @@ -53,6 +52,7 @@ pub fn bench(c: &mut Criterion) { do_ty!(f32, f64); do_ty!(char); + bench_ty::(&mut g, "Alphabetic"); bench_ty::(&mut g, "Alphanumeric"); bench_ty::(&mut g, "Open01/f32"); diff --git a/benches/benches/uniform.rs b/benches/benches/uniform.rs index ab1b0ed4149..1f1ed49681d 100644 --- a/benches/benches/uniform.rs +++ b/benches/benches/uniform.rs @@ -8,12 +8,16 @@ //! Implement benchmarks for uniform distributions over integer types +#![cfg_attr(feature = "simd_support", feature(portable_simd))] + use core::time::Duration; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use rand::distr::uniform::{SampleRange, Uniform}; use rand::prelude::*; use rand_chacha::ChaCha8Rng; use rand_pcg::{Pcg32, Pcg64}; +#[cfg(feature = "simd_support")] +use std::simd::{num::SimdUint, Simd}; const WARM_UP_TIME: Duration = Duration::from_millis(1000); const MEASUREMENT_TIME: Duration = Duration::from_secs(3); @@ -21,53 +25,97 @@ const SAMPLE_SIZE: usize = 100_000; const N_RESAMPLES: usize = 10_000; macro_rules! sample { - ($R:ty, $T:ty, $U:ty, $g:expr) => { + (@range $T:ty, $U:ty, 1, $rng:ident) => {{ + assert_eq!(<$T>::BITS, <$U>::BITS); + let bits = (<$T>::BITS / 2); + let mask = (1 as $U).wrapping_neg() >> bits; + let x = $rng.random::<$U>(); + ((x >> bits) * (x & mask)) as $T + }}; + + (@range $T:ty, $U:ty, $len:tt, $rng:ident) => {{ + let bits = (<$T>::BITS / 2); + let mask = Simd::splat((1 as $U).wrapping_neg() >> bits); + let bits = Simd::splat(bits as $U); + let x = $rng.random::>(); + ((x >> bits) * (x & mask)).cast() + }}; + + (@MIN $T:ty, 1) => { + <$T>::MIN + }; + + (@MIN $T:ty, $len:tt) => { + Simd::<$T, $len>::splat(<$T>::MIN) + }; + + (@wrapping_add $lhs:expr, $rhs:expr, 1) => { + $lhs.wrapping_add($rhs) + }; + + (@wrapping_add $lhs:expr, $rhs:expr, $len:tt) => { + ($lhs + $rhs) + }; + + ($R:ty, $T:ty, $U:ty, $len:tt, $g:expr) => { $g.bench_function(BenchmarkId::new(stringify!($R), "single"), |b| { let mut rng = <$R>::from_rng(&mut rand::rng()); - let x = rng.random::<$U>(); - let bits = (<$T>::BITS / 2); - let mask = (1 as $U).wrapping_neg() >> bits; - let range = (x >> bits) * (x & mask); - let low = <$T>::MIN; - let high = low.wrapping_add(range as $T); + let range = sample!(@range $T, $U, $len, rng); + let low = sample!(@MIN $T, $len); + let high = sample!(@wrapping_add low, range, $len); b.iter(|| (low..=high).sample_single(&mut rng)); }); $g.bench_function(BenchmarkId::new(stringify!($R), "distr"), |b| { let mut rng = <$R>::from_rng(&mut rand::rng()); - let x = rng.random::<$U>(); - let bits = (<$T>::BITS / 2); - let mask = (1 as $U).wrapping_neg() >> bits; - let range = (x >> bits) * (x & mask); - let low = <$T>::MIN; - let high = low.wrapping_add(range as $T); - let dist = Uniform::<$T>::new_inclusive(<$T>::MIN, high).unwrap(); + let range = sample!(@range $T, $U, $len, rng); + let low = sample!(@MIN $T, $len); + let high = sample!(@wrapping_add low, range, $len); + let dist = Uniform::new_inclusive(low, high).unwrap(); b.iter(|| dist.sample(&mut rng)); }); }; - ($c:expr, $T:ty, $U:ty) => {{ - let mut g = $c.benchmark_group(concat!("sample", stringify!($T))); + // Entrypoint: + // $T is the output type (integer) + // $U is the unsigned version of the output type + // $len is the width for SIMD or 1 for non-SIMD + ($c:expr, $T:ty, $U:ty, $len:tt) => {{ + let mut g = $c.benchmark_group(concat!("sample_", stringify!($T), "x", stringify!($len))); g.sample_size(SAMPLE_SIZE); g.warm_up_time(WARM_UP_TIME); g.measurement_time(MEASUREMENT_TIME); g.nresamples(N_RESAMPLES); - sample!(SmallRng, $T, $U, g); - sample!(ChaCha8Rng, $T, $U, g); - sample!(Pcg32, $T, $U, g); - sample!(Pcg64, $T, $U, g); + sample!(SmallRng, $T, $U, $len, g); + sample!(ChaCha8Rng, $T, $U, $len, g); + sample!(Pcg32, $T, $U, $len, g); + sample!(Pcg64, $T, $U, $len, g); g.finish(); }}; } fn sample(c: &mut Criterion) { - sample!(c, i8, u8); - sample!(c, i16, u16); - sample!(c, i32, u32); - sample!(c, i64, u64); - sample!(c, i128, u128); + sample!(c, i8, u8, 1); + sample!(c, i16, u16, 1); + sample!(c, i32, u32, 1); + sample!(c, i64, u64, 1); + sample!(c, i128, u128, 1); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 8); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 16); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 32); + #[cfg(feature = "simd_support")] + sample!(c, u8, u8, 64); + #[cfg(feature = "simd_support")] + sample!(c, i16, u16, 8); + #[cfg(feature = "simd_support")] + sample!(c, i16, u16, 16); + #[cfg(feature = "simd_support")] + sample!(c, i16, u16, 32); } criterion_group! { diff --git a/distr_test/Cargo.toml b/distr_test/Cargo.toml deleted file mode 100644 index d9d7fe2c274..00000000000 --- a/distr_test/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "distr_test" -version = "0.1.0" -edition = "2021" -publish = false - -[dev-dependencies] -rand_distr = { path = "../rand_distr", version = "0.5.0", default-features = false, features = ["alloc"] } -rand = { path = "..", version = "0.9.0", features = ["small_rng"] } -num-traits = "0.2.19" -# Special functions for testing distributions -special = "0.11.0" -spfunc = "0.1.0" -# Cdf implementation -statrs = "0.17.1" diff --git a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs deleted file mode 100644 index f417c630ae2..00000000000 --- a/distr_test/tests/cdf.rs +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use core::f64; - -use special::{Beta, Gamma, Primitive}; -use statrs::distribution::ContinuousCDF; -use statrs::distribution::DiscreteCDF; - -mod ks; -use ks::test_continuous; -use ks::test_discrete; - -#[test] -fn normal() { - let parameters = [ - (0.0, 1.0), - (0.0, 0.1), - (1.0, 10.0), - (1.0, 100.0), - (-1.0, 0.00001), - (-1.0, 0.0000001), - ]; - - for (seed, (mean, std_dev)) in parameters.into_iter().enumerate() { - test_continuous( - seed as u64, - rand_distr::Normal::new(mean, std_dev).unwrap(), - |x| { - statrs::distribution::Normal::new(mean, std_dev) - .unwrap() - .cdf(x) - }, - ); - } -} - -#[test] -fn cauchy() { - let parameters = [ - (0.0, 1.0), - (0.0, 0.1), - (1.0, 10.0), - (1.0, 100.0), - (-1.0, 0.00001), - (-1.0, 0.0000001), - ]; - - for (seed, (median, scale)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Cauchy::new(median, scale).unwrap(); - test_continuous(seed as u64, dist, |x| { - statrs::distribution::Cauchy::new(median, scale) - .unwrap() - .cdf(x) - }); - } -} - -#[test] -fn uniform() { - fn cdf(x: f64, a: f64, b: f64) -> f64 { - if x < a { - 0.0 - } else if x < b { - (x - a) / (b - a) - } else { - 1.0 - } - } - - let parameters = [(0.0, 1.0), (-1.0, 1.0), (0.0, 100.0), (-100.0, 100.0)]; - - for (seed, (a, b)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Uniform::new(a, b).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, a, b)); - } -} - -#[test] -fn log_normal() { - let parameters = [ - (0.0, 1.0), - (0.0, 0.1), - (0.5, 0.7), - (1.0, 10.0), - (1.0, 100.0), - ]; - - for (seed, (mean, std_dev)) in parameters.into_iter().enumerate() { - let dist = rand_distr::LogNormal::new(mean, std_dev).unwrap(); - test_continuous(seed as u64, dist, |x| { - statrs::distribution::LogNormal::new(mean, std_dev) - .unwrap() - .cdf(x) - }); - } -} - -#[test] -fn pareto() { - let parameters = [ - (1.0, 1.0), - (1.0, 0.1), - (1.0, 10.0), - (1.0, 100.0), - (0.1, 1.0), - (10.0, 1.0), - (100.0, 1.0), - ]; - - for (seed, (scale, alpha)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Pareto::new(scale, alpha).unwrap(); - test_continuous(seed as u64, dist, |x| { - statrs::distribution::Pareto::new(scale, alpha) - .unwrap() - .cdf(x) - }); - } -} - -#[test] -fn exp() { - fn cdf(x: f64, lambda: f64) -> f64 { - 1.0 - (-lambda * x).exp() - } - - let parameters = [0.5, 1.0, 7.5, 32.0, 100.0]; - - for (seed, lambda) in parameters.into_iter().enumerate() { - let dist = rand_distr::Exp::new(lambda).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, lambda)); - } -} - -#[test] -fn weibull() { - fn cdf(x: f64, lambda: f64, k: f64) -> f64 { - if x < 0.0 { - return 0.0; - } - - 1.0 - (-(x / lambda).powf(k)).exp() - } - - let parameters = [ - (0.5, 1.0), - (1.0, 1.0), - (10.0, 0.1), - (0.1, 10.0), - (15.0, 20.0), - (1000.0, 0.01), - ]; - - for (seed, (lambda, k)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Weibull::new(lambda, k).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, lambda, k)); - } -} - -#[test] -fn gumbel() { - fn cdf(x: f64, mu: f64, beta: f64) -> f64 { - (-(-(x - mu) / beta).exp()).exp() - } - - let parameters = [ - (0.0, 1.0), - (1.0, 2.0), - (-1.0, 0.5), - (10.0, 0.1), - (100.0, 0.0001), - ]; - - for (seed, (mu, beta)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Gumbel::new(mu, beta).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, mu, beta)); - } -} - -#[test] -fn frechet() { - fn cdf(x: f64, alpha: f64, s: f64, m: f64) -> f64 { - if x < m { - return 0.0; - } - - (-((x - m) / s).powf(-alpha)).exp() - } - - let parameters = [ - (0.5, 2.0, 1.0), - (1.0, 1.0, 1.0), - (10.0, 0.1, 1.0), - (100.0, 0.0001, 1.0), - (0.9999, 2.0, 1.0), - ]; - - for (seed, (alpha, s, m)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Frechet::new(m, s, alpha).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, alpha, s, m)); - } -} - -#[test] -fn gamma() { - fn cdf(x: f64, shape: f64, scale: f64) -> f64 { - if x < 0.0 { - return 0.0; - } - - (x / scale).inc_gamma(shape) - } - - let parameters = [ - (0.5, 2.0), - (1.0, 1.0), - (10.0, 0.1), - (100.0, 0.0001), - (0.9999, 2.0), - ]; - - for (seed, (shape, scale)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Gamma::new(shape, scale).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, shape, scale)); - } -} - -#[test] -fn chi_squared() { - fn cdf(x: f64, k: f64) -> f64 { - if x < 0.0 { - return 0.0; - } - - (x / 2.0).inc_gamma(k / 2.0) - } - - let parameters = [0.1, 1.0, 2.0, 10.0, 100.0, 1000.0]; - - for (seed, k) in parameters.into_iter().enumerate() { - let dist = rand_distr::ChiSquared::new(k).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, k)); - } -} -#[test] -fn studend_t() { - fn cdf(x: f64, df: f64) -> f64 { - let h = df / (df + x.powi(2)); - let ib = 0.5 * h.inc_beta(df / 2.0, 0.5, 0.5.ln_beta(df / 2.0)); - if x < 0.0 { - ib - } else { - 1.0 - ib - } - } - - let parameters = [1.0, 10.0, 50.0]; - - for (seed, df) in parameters.into_iter().enumerate() { - let dist = rand_distr::StudentT::new(df).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, df)); - } -} - -#[test] -fn fisher_f() { - fn cdf(x: f64, m: f64, n: f64) -> f64 { - if (m == 1.0 && x <= 0.0) || x < 0.0 { - 0.0 - } else { - let k = m * x / (m * x + n); - let d1 = m / 2.0; - let d2 = n / 2.0; - k.inc_beta(d1, d2, d1.ln_beta(d2)) - } - } - - let parameters = [(1.0, 1.0), (1.0, 2.0), (2.0, 1.0), (50.0, 1.0)]; - - for (seed, (m, n)) in parameters.into_iter().enumerate() { - let dist = rand_distr::FisherF::new(m, n).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, m, n)); - } -} - -#[test] -fn beta() { - fn cdf(x: f64, alpha: f64, beta: f64) -> f64 { - if x < 0.0 { - return 0.0; - } - if x > 1.0 { - return 1.0; - } - let ln_beta_ab = alpha.ln_beta(beta); - x.inc_beta(alpha, beta, ln_beta_ab) - } - - let parameters = [(0.5, 0.5), (2.0, 3.5), (10.0, 1.0), (100.0, 50.0)]; - - for (seed, (alpha, beta)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Beta::new(alpha, beta).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, alpha, beta)); - } -} - -#[test] -fn triangular() { - fn cdf(x: f64, a: f64, b: f64, c: f64) -> f64 { - if x <= a { - 0.0 - } else if a < x && x <= c { - (x - a).powi(2) / ((b - a) * (c - a)) - } else if c < x && x < b { - 1.0 - (b - x).powi(2) / ((b - a) * (b - c)) - } else { - 1.0 - } - } - - let parameters = [ - (0.0, 1.0, 0.0001), - (0.0, 1.0, 0.9999), - (0.0, 1.0, 0.5), - (0.0, 100.0, 50.0), - (-100.0, 100.0, 0.0), - ]; - - for (seed, (a, b, c)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Triangular::new(a, b, c).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, a, b, c)); - } -} - -fn binomial_cdf(k: i64, p: f64, n: u64) -> f64 { - if k < 0 { - return 0.0; - } - let k = k as u64; - if k >= n { - return 1.0; - } - - let a = (n - k) as f64; - let b = k as f64 + 1.0; - - let q = 1.0 - p; - - let ln_beta_ab = a.ln_beta(b); - - q.inc_beta(a, b, ln_beta_ab) -} - -#[test] -fn binomial() { - let parameters = [ - (0.5, 10), - (0.5, 100), - (0.1, 10), - (0.0000001, 1000000), - (0.0000001, 10), - (0.9999, 2), - ]; - - for (seed, (p, n)) in parameters.into_iter().enumerate() { - test_discrete(seed as u64, rand_distr::Binomial::new(n, p).unwrap(), |k| { - binomial_cdf(k, p, n) - }); - } -} - -#[test] -fn geometric() { - fn cdf(k: i64, p: f64) -> f64 { - if k < 0 { - 0.0 - } else { - 1.0 - (1.0 - p).powi(1 + k as i32) - } - } - - let parameters = [0.3, 0.5, 0.7, 0.0000001, 0.9999]; - - for (seed, p) in parameters.into_iter().enumerate() { - let dist = rand_distr::Geometric::new(p).unwrap(); - test_discrete(seed as u64, dist, |k| cdf(k, p)); - } -} - -#[test] -fn hypergeometric() { - fn cdf(x: i64, n: u64, k: u64, n_: u64) -> f64 { - let min = if n_ + k > n { n_ + k - n } else { 0 }; - let max = k.min(n_); - if x < min as i64 { - return 0.0; - } else if x >= max as i64 { - return 1.0; - } - - (min..x as u64 + 1).fold(0.0, |acc, k_| { - acc + (ln_binomial(k, k_) + ln_binomial(n - k, n_ - k_) - ln_binomial(n, n_)).exp() - }) - } - - let parameters = [ - (15, 13, 10), - (25, 15, 5), - (60, 10, 7), - (70, 20, 50), - (100, 50, 10), - (100, 50, 49), - ]; - - for (seed, (n, k, n_)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Hypergeometric::new(n, k, n_).unwrap(); - test_discrete(seed as u64, dist, |x| cdf(x, n, k, n_)); - } -} - -#[test] -fn poisson() { - use rand_distr::Poisson; - let parameters = [ - 0.1, 1.0, 7.5, - 45.0, // 1e9, passed case but too slow - // 1.844E+19, // fail case - ]; - - for (seed, lambda) in parameters.into_iter().enumerate() { - let dist = Poisson::new(lambda).unwrap(); - let analytic = statrs::distribution::Poisson::new(lambda).unwrap(); - test_discrete::, _>(seed as u64, dist, |k| { - if k < 0 { - 0.0 - } else { - analytic.cdf(k as u64) - } - }); - } -} - -fn ln_factorial(n: u64) -> f64 { - (n as f64 + 1.0).lgamma().0 -} - -fn ln_binomial(n: u64, k: u64) -> f64 { - ln_factorial(n) - ln_factorial(k) - ln_factorial(n - k) -} diff --git a/distr_test/tests/ks/mod.rs b/distr_test/tests/ks/mod.rs deleted file mode 100644 index ab94db6e1f4..00000000000 --- a/distr_test/tests/ks/mod.rs +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// [1] Nonparametric Goodness-of-Fit Tests for Discrete Null Distributions -// by Taylor B. Arnold and John W. Emerson -// http://www.stat.yale.edu/~jay/EmersonMaterials/DiscreteGOF.pdf - -#![allow(dead_code)] - -use num_traits::AsPrimitive; -use rand::SeedableRng; -use rand_distr::Distribution; - -/// Empirical Cumulative Distribution Function (ECDF) -struct Ecdf { - sorted_samples: Vec, -} - -impl Ecdf { - fn new(mut samples: Vec) -> Self { - samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); - Self { - sorted_samples: samples, - } - } - - /// Returns the step points of the ECDF - /// The ECDF is a step function that increases by 1/n at each sample point - /// The function is continuous from the right, so we give the bigger value at the step points - /// First point is (-inf, 0.0), last point is (max(samples), 1.0) - fn step_points(&self) -> Vec<(f64, f64)> { - let mut points = Vec::with_capacity(self.sorted_samples.len() + 1); - let mut last = f64::NEG_INFINITY; - let mut count = 0; - let n = self.sorted_samples.len() as f64; - for &x in &self.sorted_samples { - if x != last { - points.push((last, count as f64 / n)); - last = x; - } - count += 1; - } - points.push((last, count as f64 / n)); - points - } -} - -fn kolmogorov_smirnov_statistic_continuous(ecdf: Ecdf, cdf: impl Fn(f64) -> f64) -> f64 { - // We implement equation (3) from [1] - - let mut max_diff: f64 = 0.; - - let step_points = ecdf.step_points(); // x_i in the paper - for i in 1..step_points.len() { - let (x_i, f_i) = step_points[i]; - let (_, f_i_1) = step_points[i - 1]; - let cdf_i = cdf(x_i); - let max_1 = (cdf_i - f_i).abs(); - let max_2 = (cdf_i - f_i_1).abs(); - - max_diff = max_diff.max(max_1).max(max_2); - } - max_diff -} - -fn kolmogorov_smirnov_statistic_discrete(ecdf: Ecdf, cdf: impl Fn(i64) -> f64) -> f64 { - // We implement equation (4) from [1] - - let mut max_diff: f64 = 0.; - - let step_points = ecdf.step_points(); // x_i in the paper - for i in 1..step_points.len() { - let (x_i, f_i) = step_points[i]; - let (_, f_i_1) = step_points[i - 1]; - let max_1 = (cdf(x_i as i64) - f_i).abs(); - let max_2 = (cdf(x_i as i64 - 1) - f_i_1).abs(); // -1 is the same as -epsilon, because we have integer support - - max_diff = max_diff.max(max_1).max(max_2); - } - max_diff -} - -const SAMPLE_SIZE: u64 = 1_000_000; - -fn critical_value() -> f64 { - // If the sampler is correct, we expect less than 0.001 false positives (alpha = 0.001). - // Passing this does not prove that the sampler is correct but is a good indication. - 1.95 / (SAMPLE_SIZE as f64).sqrt() -} - -fn sample_ecdf(seed: u64, dist: impl Distribution) -> Ecdf -where - T: AsPrimitive, -{ - let mut rng = rand::rngs::SmallRng::seed_from_u64(seed); - let samples = (0..SAMPLE_SIZE) - .map(|_| dist.sample(&mut rng).as_()) - .collect(); - Ecdf::new(samples) -} - -/// Tests a distribution against an analytical CDF. -/// The CDF has to be continuous. -pub fn test_continuous(seed: u64, dist: impl Distribution, cdf: impl Fn(f64) -> f64) { - let ecdf = sample_ecdf(seed, dist); - let ks_statistic = kolmogorov_smirnov_statistic_continuous(ecdf, cdf); - - let critical_value = critical_value(); - - println!("KS statistic: {}", ks_statistic); - println!("Critical value: {}", critical_value); - assert!(ks_statistic < critical_value); -} - -/// Tests a distribution over integers against an analytical CDF. -/// The analytical CDF must not have jump points which are not integers. -pub fn test_discrete(seed: u64, dist: D, cdf: F) -where - I: AsPrimitive, - D: Distribution, - F: Fn(i64) -> f64, -{ - let ecdf = sample_ecdf(seed, dist); - let ks_statistic = kolmogorov_smirnov_statistic_discrete(ecdf, cdf); - - // This critical value is bigger than it could be for discrete distributions, but because of large sample sizes this should not matter too much - let critical_value = critical_value(); - - println!("KS statistic: {}", ks_statistic); - println!("Critical value: {}", critical_value); - assert!(ks_statistic < critical_value); -} diff --git a/distr_test/tests/skew_normal.rs b/distr_test/tests/skew_normal.rs deleted file mode 100644 index 0e6b7b3a028..00000000000 --- a/distr_test/tests/skew_normal.rs +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -mod ks; -use ks::test_continuous; -use special::Primitive; - -#[test] -fn skew_normal() { - fn cdf(x: f64, location: f64, scale: f64, shape: f64) -> f64 { - let norm = (x - location) / scale; - phi(norm) - 2.0 * owen_t(norm, shape) - } - - let parameters = [(0.0, 1.0, 5.0), (1.0, 10.0, -5.0), (-1.0, 0.00001, 0.0)]; - - for (seed, (location, scale, shape)) in parameters.into_iter().enumerate() { - let dist = rand_distr::SkewNormal::new(location, scale, shape).unwrap(); - test_continuous(seed as u64, dist, |x| cdf(x, location, scale, shape)); - } -} - -/// [1] Patefield, M. (2000). Fast and Accurate Calculation of Owen’s T Function. -/// Journal of Statistical Software, 5(5), 1–25. -/// https://doi.org/10.18637/jss.v005.i05 -/// -/// This function is ported to Rust from the Fortran code provided in the paper -fn owen_t(h: f64, a: f64) -> f64 { - let absh = h.abs(); - let absa = a.abs(); - let ah = absa * absh; - - let mut t; - if absa <= 1.0 { - t = tf(absh, absa, ah); - } else if absh <= 0.67 { - t = 0.25 - znorm1(absh) * znorm1(ah) - tf(ah, 1.0 / absa, absh); - } else { - let normh = znorm2(absh); - let normah = znorm2(ah); - t = 0.5 * (normh + normah) - normh * normah - tf(ah, 1.0 / absa, absh); - } - - if a < 0.0 { - t = -t; - } - - fn tf(h: f64, a: f64, ah: f64) -> f64 { - let rtwopi = 0.159_154_943_091_895_35; - let rrtpi = 0.398_942_280_401_432_7; - - let c2 = [ - 0.999_999_999_999_999_9, - -0.999_999_999_999_888, - 0.999_999_999_982_907_5, - -0.999_999_998_962_825, - 0.999_999_966_604_593_7, - -0.999_999_339_862_724_7, - 0.999_991_256_111_369_6, - -0.999_917_776_244_633_8, - 0.999_428_355_558_701_4, - -0.996_973_117_207_23, - 0.987_514_480_372_753, - -0.959_158_579_805_728_8, - 0.892_463_055_110_067_1, - -0.768_934_259_904_64, - 0.588_935_284_684_846_9, - -0.383_803_451_604_402_55, - 0.203_176_017_010_453, - -8.281_363_160_700_499e-2, - 2.416_798_473_575_957_8e-2, - -4.467_656_666_397_183e-3, - 3.914_116_940_237_383_6e-4, - ]; - - let pts = [ - 3.508_203_967_645_171_6e-3, - 3.127_904_233_803_075_6e-2, - 8.526_682_628_321_945e-2, - 0.162_450_717_308_122_77, - 0.258_511_960_491_254_36, - 0.368_075_538_406_975_3, - 0.485_010_929_056_047, - 0.602_775_141_526_185_7, - 0.714_778_842_177_532_3, - 0.814_755_109_887_601, - 0.897_110_297_559_489_7, - 0.957_238_080_859_442_6, - 0.991_788_329_746_297, - ]; - - let wts = [ - 1.883_143_811_532_350_3e-2, - 1.856_708_624_397_765e-2, - 1.804_209_346_122_338_5e-2, - 1.726_382_960_639_875_2e-2, - 1.624_321_997_598_985_8e-2, - 1.499_459_203_411_670_5e-2, - 1.353_547_446_966_209e-2, - 1.188_635_160_582_016_5e-2, - 1.007_037_724_277_743_2e-2, - 8.113_054_574_229_958e-3, - 6.041_900_952_847_024e-3, - 3.886_221_701_074_205_7e-3, - 1.679_303_108_454_609e-3, - ]; - - let hrange = [ - 0.02, 0.06, 0.09, 0.125, 0.26, 0.4, 0.6, 1.6, 1.7, 2.33, 2.4, 3.36, 3.4, 4.8, - ]; - let arange = [0.025, 0.09, 0.15, 0.36, 0.5, 0.9, 0.99999]; - - let select = [ - [1, 1, 2, 13, 13, 13, 13, 13, 13, 13, 13, 16, 16, 16, 9], - [1, 2, 2, 3, 3, 5, 5, 14, 14, 15, 15, 16, 16, 16, 9], - [2, 2, 3, 3, 3, 5, 5, 15, 15, 15, 15, 16, 16, 16, 10], - [2, 2, 3, 5, 5, 5, 5, 7, 7, 16, 16, 16, 16, 16, 10], - [2, 3, 3, 5, 5, 6, 6, 8, 8, 17, 17, 17, 12, 12, 11], - [2, 3, 5, 5, 5, 6, 6, 8, 8, 17, 17, 17, 12, 12, 12], - [2, 3, 4, 4, 6, 6, 8, 8, 17, 17, 17, 17, 17, 12, 12], - [2, 3, 4, 4, 6, 6, 18, 18, 18, 18, 17, 17, 17, 12, 12], - ]; - - let ihint = hrange.iter().position(|&r| h < r).unwrap_or(14); - - let iaint = arange.iter().position(|&r| a < r).unwrap_or(7); - - let icode = select[iaint][ihint]; - let m = [ - 2, 3, 4, 5, 7, 10, 12, 18, 10, 20, 30, 20, 4, 7, 8, 20, 13, 0, - ][icode - 1]; - let method = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 5, 6][icode - 1]; - - match method { - 1 => { - let hs = -0.5 * h * h; - let dhs = hs.exp(); - let as_ = a * a; - let mut j = 1; - let mut jj = 1; - let mut aj = rtwopi * a; - let mut tf = rtwopi * a.atan(); - let mut dj = dhs - 1.0; - let mut gj = hs * dhs; - loop { - tf += dj * aj / (jj as f64); - if j >= m { - return tf; - } - j += 1; - jj += 2; - aj *= as_; - dj = gj - dj; - gj *= hs / (j as f64); - } - } - 2 => { - let maxii = m + m + 1; - let mut ii = 1; - let mut tf = 0.0; - let hs = h * h; - let as_ = -a * a; - let mut vi = rrtpi * a * (-0.5 * ah * ah).exp(); - let mut z = znorm1(ah) / h; - let y = 1.0 / hs; - loop { - tf += z; - if ii >= maxii { - tf *= rrtpi * (-0.5 * hs).exp(); - return tf; - } - z = y * (vi - (ii as f64) * z); - vi *= as_; - ii += 2; - } - } - 3 => { - let mut i = 1; - let mut ii = 1; - let mut tf = 0.0; - let hs = h * h; - let as_ = a * a; - let mut vi = rrtpi * a * (-0.5 * ah * ah).exp(); - let mut zi = znorm1(ah) / h; - let y = 1.0 / hs; - loop { - tf += zi * c2[i - 1]; - if i > m { - tf *= rrtpi * (-0.5 * hs).exp(); - return tf; - } - zi = y * ((ii as f64) * zi - vi); - vi *= as_; - i += 1; - ii += 2; - } - } - 4 => { - let maxii = m + m + 1; - let mut ii = 1; - let mut tf = 0.0; - let hs = h * h; - let as_ = -a * a; - let mut ai = rtwopi * a * (-0.5 * hs * (1.0 - as_)).exp(); - let mut yi = 1.0; - loop { - tf += ai * yi; - if ii >= maxii { - return tf; - } - ii += 2; - yi = (1.0 - hs * yi) / (ii as f64); - ai *= as_; - } - } - 5 => { - let mut tf = 0.0; - let as_ = a * a; - let hs = -0.5 * h * h; - for i in 0..m { - let r = 1.0 + as_ * pts[i]; - tf += wts[i] * (hs * r).exp() / r; - } - tf *= a; - tf - } - 6 => { - let normh = znorm2(h); - let mut tf = 0.5 * normh * (1.0 - normh); - let y = 1.0 - a; - let r = (y / (1.0 + a)).atan(); - if r != 0.0 { - tf -= rtwopi * r * (-0.5 * y * h * h / r).exp(); - } - tf - } - _ => 0.0, - } - } - - // P(0 ≤ Z ≤ x) - fn znorm1(x: f64) -> f64 { - phi(x) - 0.5 - } - - // P(x ≤ Z < ∞) - fn znorm2(x: f64) -> f64 { - 1.0 - phi(x) - } - - t -} - -fn normal_cdf(x: f64, mean: f64, std_dev: f64) -> f64 { - 0.5 * ((mean - x) / (std_dev * core::f64::consts::SQRT_2)).erfc() -} - -/// standard normal cdf -fn phi(x: f64) -> f64 { - normal_cdf(x, 0.0, 1.0) -} diff --git a/distr_test/tests/weighted.rs b/distr_test/tests/weighted.rs deleted file mode 100644 index 73df7beb9bc..00000000000 --- a/distr_test/tests/weighted.rs +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2024 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -mod ks; -use ks::test_discrete; -use rand::distr::Distribution; -use rand::seq::{IndexedRandom, IteratorRandom}; -use rand_distr::weighted::*; - -/// Takes the unnormalized pdf and creates the cdf of a discrete distribution -fn make_cdf(num: usize, f: impl Fn(i64) -> f64) -> impl Fn(i64) -> f64 { - let mut cdf = Vec::with_capacity(num); - let mut ac = 0.0; - for i in 0..num { - ac += f(i as i64); - cdf.push(ac); - } - - let frac = 1.0 / ac; - for x in &mut cdf { - *x *= frac; - } - - move |i| { - if i < 0 { - 0.0 - } else { - cdf[i as usize] - } - } -} - -#[test] -fn weighted_index() { - fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { - let distr = WeightedIndex::new((0..num).map(|i| weight(i as i64))).unwrap(); - test_discrete(0, distr, make_cdf(num, weight)); - } - - test_weights(100, |_| 1.0); - test_weights(100, |i| ((i + 1) as f64).ln()); - test_weights(100, |i| i as f64); - test_weights(100, |i| (i as f64).powi(3)); - test_weights(100, |i| 1.0 / ((i + 1) as f64)); -} - -#[test] -fn weighted_alias_index() { - fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { - let weights = (0..num).map(|i| weight(i as i64)).collect(); - let distr = WeightedAliasIndex::new(weights).unwrap(); - test_discrete(0, distr, make_cdf(num, weight)); - } - - test_weights(100, |_| 1.0); - test_weights(100, |i| ((i + 1) as f64).ln()); - test_weights(100, |i| i as f64); - test_weights(100, |i| (i as f64).powi(3)); - test_weights(100, |i| 1.0 / ((i + 1) as f64)); -} - -#[test] -fn weighted_tree_index() { - fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { - let distr = WeightedTreeIndex::new((0..num).map(|i| weight(i as i64))).unwrap(); - test_discrete(0, distr, make_cdf(num, weight)); - } - - test_weights(100, |_| 1.0); - test_weights(100, |i| ((i + 1) as f64).ln()); - test_weights(100, |i| i as f64); - test_weights(100, |i| (i as f64).powi(3)); - test_weights(100, |i| 1.0 / ((i + 1) as f64)); -} - -#[test] -fn choose_weighted_indexed() { - struct Adapter f64>(Vec, F); - impl f64> Distribution for Adapter { - fn sample(&self, rng: &mut R) -> i64 { - *IndexedRandom::choose_weighted(&self.0[..], rng, |i| (self.1)(*i)).unwrap() - } - } - - fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { - let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); - test_discrete(0, distr, make_cdf(num, &weight)); - } - - test_weights(100, |_| 1.0); - test_weights(100, |i| ((i + 1) as f64).ln()); - test_weights(100, |i| i as f64); - test_weights(100, |i| (i as f64).powi(3)); - test_weights(100, |i| 1.0 / ((i + 1) as f64)); -} - -#[test] -fn choose_one_weighted_indexed() { - struct Adapter f64>(Vec, F); - impl f64> Distribution for Adapter { - fn sample(&self, rng: &mut R) -> i64 { - *IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 1, |i| (self.1)(*i)) - .unwrap() - .next() - .unwrap() - } - } - - fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { - let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); - test_discrete(0, distr, make_cdf(num, &weight)); - } - - test_weights(100, |_| 1.0); - test_weights(100, |i| ((i + 1) as f64).ln()); - test_weights(100, |i| i as f64); - test_weights(100, |i| (i as f64).powi(3)); - test_weights(100, |i| 1.0 / ((i + 1) as f64)); -} - -#[test] -fn choose_two_weighted_indexed() { - struct Adapter f64>(Vec, F); - impl f64> Distribution for Adapter { - fn sample(&self, rng: &mut R) -> i64 { - let mut iter = - IndexedRandom::choose_multiple_weighted(&self.0[..], rng, 2, |i| (self.1)(*i)) - .unwrap(); - let mut a = *iter.next().unwrap(); - let mut b = *iter.next().unwrap(); - assert!(iter.next().is_none()); - if b < a { - std::mem::swap(&mut a, &mut b); - } - a * self.0.len() as i64 + b - } - } - - fn test_weights(num: usize, weight: impl Fn(i64) -> f64) { - let distr = Adapter((0..num).map(|i| i as i64).collect(), &weight); - - let pmf1 = (0..num).map(|i| weight(i as i64)).collect::>(); - let sum: f64 = pmf1.iter().sum(); - let frac = 1.0 / sum; - - let mut ac = 0.0; - let mut cdf = Vec::with_capacity(num * num); - for a in 0..num { - for b in 0..num { - if a < b { - let pa = pmf1[a] * frac; - let pab = pa * pmf1[b] / (sum - pmf1[a]); - - let pb = pmf1[b] * frac; - let pba = pb * pmf1[a] / (sum - pmf1[b]); - - ac += pab + pba; - } - cdf.push(ac); - } - } - assert!((cdf.last().unwrap() - 1.0).abs() < 1e-9); - - let cdf = |i| { - if i < 0 { - 0.0 - } else { - cdf[i as usize] - } - }; - - test_discrete(0, distr, cdf); - } - - test_weights(100, |_| 1.0); - test_weights(100, |i| ((i + 1) as f64).ln()); - test_weights(100, |i| i as f64); - test_weights(100, |i| (i as f64).powi(3)); - test_weights(100, |i| 1.0 / ((i + 1) as f64)); - test_weights(10, |i| ((i + 1) as f64).powi(-8)); -} - -#[test] -fn choose_iterator() { - struct Adapter(I); - impl> Distribution for Adapter { - fn sample(&self, rng: &mut R) -> i64 { - IteratorRandom::choose(self.0.clone(), rng).unwrap() - } - } - - let distr = Adapter((0..100).map(|i| i as i64)); - test_discrete(0, distr, make_cdf(100, |_| 1.0)); -} - -#[test] -fn choose_stable_iterator() { - struct Adapter(I); - impl> Distribution for Adapter { - fn sample(&self, rng: &mut R) -> i64 { - IteratorRandom::choose_stable(self.0.clone(), rng).unwrap() - } - } - - let distr = Adapter((0..100).map(|i| i as i64)); - test_discrete(0, distr, make_cdf(100, |_| 1.0)); -} - -#[test] -fn choose_two_iterator() { - struct Adapter(I); - impl> Distribution for Adapter { - fn sample(&self, rng: &mut R) -> i64 { - let mut buf = [0; 2]; - IteratorRandom::choose_multiple_fill(self.0.clone(), rng, &mut buf); - buf.sort_unstable(); - assert!(buf[0] < 99 && buf[1] >= 1); - let a = buf[0]; - 4950 - (99 - a) * (100 - a) / 2 + buf[1] - a - 1 - } - } - - let distr = Adapter((0..100).map(|i| i as i64)); - - test_discrete( - 0, - distr, - |i| if i < 0 { 0.0 } else { (i + 1) as f64 / 4950.0 }, - ); -} diff --git a/distr_test/tests/zeta.rs b/distr_test/tests/zeta.rs deleted file mode 100644 index 6e5ab1f594e..00000000000 --- a/distr_test/tests/zeta.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -mod ks; -use ks::test_discrete; - -#[test] -fn zeta() { - fn cdf(k: i64, s: f64) -> f64 { - use spfunc::zeta::zeta as zeta_func; - if k < 1 { - return 0.0; - } - - gen_harmonic(k as u64, s) / zeta_func(s) - } - - let parameters = [2.0, 3.7, 5.0, 100.0]; - - for (seed, s) in parameters.into_iter().enumerate() { - let dist = rand_distr::Zeta::new(s).unwrap(); - test_discrete(seed as u64, dist, |k| cdf(k, s)); - } -} - -#[test] -fn zipf() { - fn cdf(k: i64, n: u64, s: f64) -> f64 { - if k < 1 { - return 0.0; - } - if k > n as i64 { - return 1.0; - } - gen_harmonic(k as u64, s) / gen_harmonic(n, s) - } - - let parameters = [(1000, 1.0), (500, 2.0), (1000, 0.5)]; - - for (seed, (n, x)) in parameters.into_iter().enumerate() { - let dist = rand_distr::Zipf::new(n as f64, x).unwrap(); - test_discrete(seed as u64, dist, |k| cdf(k, n, x)); - } -} - -fn gen_harmonic(n: u64, m: f64) -> f64 { - match n { - 0 => 1.0, - _ => (0..n).fold(0.0, |acc, x| acc + (x as f64 + 1.0).powf(-m)), - } -} diff --git a/rand_chacha/Cargo.toml b/rand_chacha/Cargo.toml index 7052dd48e4b..e2f313d2e8e 100644 --- a/rand_chacha/Cargo.toml +++ b/rand_chacha/Cargo.toml @@ -26,7 +26,7 @@ serde = { version = "1.0", features = ["derive"], optional = true } [dev-dependencies] # Only to test serde -serde_json = "1.0" +serde_json = "1.0.120" rand_core = { path = "../rand_core", version = "0.9.0", features = ["os_rng"] } [features] diff --git a/rand_core/CHANGELOG.md b/rand_core/CHANGELOG.md index 3b3064db71b..7318dffa878 100644 --- a/rand_core/CHANGELOG.md +++ b/rand_core/CHANGELOG.md @@ -4,6 +4,20 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.3] — 2025-02-29 +### Other +- Remove `zerocopy` dependency (#1607) +- Deprecate `rand_core::impls::fill_via_u32_chunks`, `fill_via_u64_chunks` (#1607) + +## [0.9.2] - 2025-02-22 +### API changes +- Relax `Sized` bound on impls of `TryRngCore`, `TryCryptoRng` and `UnwrapMut` (#1593) +- Add `UnwrapMut::re` to reborrow the inner rng with a tighter lifetime (#1595) + +## [0.9.1] - 2025-02-16 +### API changes +- Add `TryRngCore::unwrap_mut`, providing an impl of `RngCore` over `&mut rng` (#1589) + ## [0.9.0] - 2025-01-27 ### Dependencies and features - Bump the MSRV to 1.63.0 (#1207, #1246, #1269, #1341, #1416, #1536); note that 1.60.0 may work for dependents when using `--ignore-rust-version` @@ -15,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### API changes - Allow `rand_core::impls::fill_via_u*_chunks` to mutate source (#1182) - Add fn `RngCore::read_adapter` implementing `std::io::Read` (#1267) -- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` (#1273) +- Add trait `CryptoBlockRng: BlockRngCore`; make `trait CryptoRng: RngCore` replacing `CryptoRngCore` (#1273) - Add traits `TryRngCore`, `TryCryptoRng` (#1424, #1499) - Rename `fn SeedableRng::from_rng` -> `try_from_rng` and add infallible variant `fn from_rng` (#1424) - Rename `fn SeedableRng::from_entropy` -> `from_os_rng` and add fallible variant `fn try_from_os_rng` (#1424) diff --git a/rand_core/Cargo.toml b/rand_core/Cargo.toml index d1d9e66d7fa..899c359554c 100644 --- a/rand_core/Cargo.toml +++ b/rand_core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rand_core" -version = "0.9.0" +version = "0.9.3" authors = ["The Rand Project Developers", "The Rust Project Developers"] license = "MIT OR Apache-2.0" readme = "README.md" @@ -32,4 +32,3 @@ serde = ["dep:serde"] # enables serde for BlockRng wrapper [dependencies] serde = { version = "1", features = ["derive"], optional = true } getrandom = { version = "0.3.0", optional = true } -zerocopy = { version = "0.8.0", default-features = false } diff --git a/rand_core/README.md b/rand_core/README.md index b95287c4e70..05d9fbf6cb0 100644 --- a/rand_core/README.md +++ b/rand_core/README.md @@ -41,8 +41,9 @@ The traits and error types are also available via `rand`. ## Versions The current version is: -``` -rand_core = "=0.9.0-beta.1" + +```toml +rand_core = "0.9.3" ``` diff --git a/rand_core/src/.le.rs.kate-swp b/rand_core/src/.le.rs.kate-swp new file mode 100644 index 00000000000..0debd30bbe9 Binary files /dev/null and b/rand_core/src/.le.rs.kate-swp differ diff --git a/rand_core/src/block.rs b/rand_core/src/block.rs index aa2252e6da2..667cc0bca6a 100644 --- a/rand_core/src/block.rs +++ b/rand_core/src/block.rs @@ -53,7 +53,7 @@ //! [`BlockRngCore`]: crate::block::BlockRngCore //! [`fill_bytes`]: RngCore::fill_bytes -use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks}; +use crate::impls::fill_via_chunks; use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore}; use core::fmt; #[cfg(feature = "serde")] @@ -197,7 +197,7 @@ impl> RngCore for BlockRng { fn next_u64(&mut self) -> u64 { let read_u64 = |results: &[u32], index| { let data = &results[index..=index + 1]; - u64::from(data[1]) << 32 | u64::from(data[0]) + (u64::from(data[1]) << 32) | u64::from(data[0]) }; let len = self.results.as_ref().len(); @@ -225,10 +225,8 @@ impl> RngCore for BlockRng { if self.index >= self.results.as_ref().len() { self.generate_and_set(0); } - let (consumed_u32, filled_u8) = fill_via_u32_chunks( - &mut self.results.as_mut()[self.index..], - &mut dest[read_len..], - ); + let (consumed_u32, filled_u8) = + fill_via_chunks(&self.results.as_mut()[self.index..], &mut dest[read_len..]); self.index += consumed_u32; read_len += filled_u8; @@ -390,10 +388,8 @@ impl> RngCore for BlockRng64 { self.index = 0; } - let (consumed_u64, filled_u8) = fill_via_u64_chunks( - &mut self.results.as_mut()[self.index..], - &mut dest[read_len..], - ); + let (consumed_u64, filled_u8) = + fill_via_chunks(&self.results.as_mut()[self.index..], &mut dest[read_len..]); self.index += consumed_u64; read_len += filled_u8; diff --git a/rand_core/src/impls.rs b/rand_core/src/impls.rs index 584a4c16f10..4cced0d6f0b 100644 --- a/rand_core/src/impls.rs +++ b/rand_core/src/impls.rs @@ -18,8 +18,6 @@ //! non-reproducible sources (e.g. `OsRng`) need not bother with it. use crate::RngCore; -use core::cmp::min; -use zerocopy::{Immutable, IntoBytes}; /// Implement `next_u64` via `next_u32`, little-endian order. pub fn next_u64_via_u32(rng: &mut R) -> u64 { @@ -53,41 +51,52 @@ pub fn fill_bytes_via_next(rng: &mut R, dest: &mut [u8]) { } } -trait Observable: IntoBytes + Immutable + Copy { - fn to_le(self) -> Self; +pub(crate) trait Observable: Copy { + type Bytes: Sized + AsRef<[u8]>; + fn to_le_bytes(self) -> Self::Bytes; } impl Observable for u32 { - fn to_le(self) -> Self { - self.to_le() + type Bytes = [u8; 4]; + + fn to_le_bytes(self) -> Self::Bytes { + Self::to_le_bytes(self) } } impl Observable for u64 { - fn to_le(self) -> Self { - self.to_le() + type Bytes = [u8; 8]; + + fn to_le_bytes(self) -> Self::Bytes { + Self::to_le_bytes(self) } } /// Fill dest from src /// -/// Returns `(n, byte_len)`. `src[..n]` is consumed (and possibly mutated), +/// Returns `(n, byte_len)`. `src[..n]` is consumed, /// `dest[..byte_len]` is filled. `src[n..]` and `dest[byte_len..]` are left /// unaltered. -fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usize) { +pub(crate) fn fill_via_chunks(src: &[T], dest: &mut [u8]) -> (usize, usize) { let size = core::mem::size_of::(); - let byte_len = min(core::mem::size_of_val(src), dest.len()); - let num_chunks = (byte_len + size - 1) / size; - - // Byte-swap for portability of results. This must happen before copying - // since the size of dest is not guaranteed to be a multiple of T or to be - // sufficiently aligned. - if cfg!(target_endian = "big") { - for x in &mut src[..num_chunks] { - *x = x.to_le(); - } - } - dest[..byte_len].copy_from_slice(&<[T]>::as_bytes(&src[..num_chunks])[..byte_len]); + // Always use little endian for portability of results. + let mut dest = dest.chunks_exact_mut(size); + let mut src = src.iter(); + + let zipped = dest.by_ref().zip(src.by_ref()); + let num_chunks = zipped.len(); + zipped.for_each(|(dest, src)| dest.copy_from_slice(src.to_le_bytes().as_ref())); + + let byte_len = num_chunks * size; + if let Some(src) = src.next() { + // We have consumed all full chunks of dest, but not src. + let dest = dest.into_remainder(); + let n = dest.len(); + if n > 0 { + dest.copy_from_slice(&src.to_le_bytes().as_ref()[..n]); + return (num_chunks + 1, byte_len + n); + } + } (num_chunks, byte_len) } @@ -96,8 +105,8 @@ fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usi /// /// The return values are `(consumed_u32, filled_u8)`. /// -/// On big-endian systems, endianness of `src[..consumed_u32]` values is -/// swapped. No other adjustments to `src` are made. +/// `src` is not modified; it is taken as a `&mut` reference for backward +/// compatibility with previous versions that did change it. /// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. @@ -124,6 +133,7 @@ fn fill_via_chunks(src: &mut [T], dest: &mut [u8]) -> (usize, usi /// } /// } /// ``` +#[deprecated(since = "0.9.3", note = "use BlockRng instead")] pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { fill_via_chunks(src, dest) } @@ -133,8 +143,8 @@ pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { /// /// The return values are `(consumed_u64, filled_u8)`. /// -/// On big-endian systems, endianness of `src[..consumed_u64]` values is -/// swapped. No other adjustments to `src` are made. +/// `src` is not modified; it is taken as a `&mut` reference for backward +/// compatibility with previous versions that did change it. /// /// `filled_u8` is the number of filled bytes in `dest`, which may be less than /// the length of `dest`. @@ -142,6 +152,7 @@ pub fn fill_via_u32_chunks(src: &mut [u32], dest: &mut [u8]) -> (usize, usize) { /// as `filled_u8 / 8` rounded up. /// /// See `fill_via_u32_chunks` for an example. +#[deprecated(since = "0.9.3", note = "use BlockRng64 instead")] pub fn fill_via_u64_chunks(src: &mut [u64], dest: &mut [u8]) -> (usize, usize) { fill_via_chunks(src, dest) } @@ -166,41 +177,41 @@ mod test { #[test] fn test_fill_via_u32_chunks() { - let src_orig = [1, 2, 3]; + let src_orig = [1u32, 2, 3]; let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 11)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (3, 11)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 13]; - assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (3, 12)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (3, 12)); assert_eq!(dst, [1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u32_chunks(&mut src, &mut dst), (2, 5)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (2, 5)); assert_eq!(dst, [1, 0, 0, 0, 2]); } #[test] fn test_fill_via_u64_chunks() { - let src_orig = [1, 2]; + let src_orig = [1u64, 2]; let mut src = src_orig; let mut dst = [0u8; 11]; - assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 11)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (2, 11)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 17]; - assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (2, 16)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (2, 16)); assert_eq!(dst, [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]); let mut src = src_orig; let mut dst = [0u8; 5]; - assert_eq!(fill_via_u64_chunks(&mut src, &mut dst), (1, 5)); + assert_eq!(fill_via_chunks(&mut src, &mut dst), (1, 5)); assert_eq!(dst, [1, 0, 0, 0, 0]); } } diff --git a/rand_core/src/le.rs b/rand_core/src/le.rs index cee84c2f327..6c4d7c82ad0 100644 --- a/rand_core/src/le.rs +++ b/rand_core/src/le.rs @@ -11,11 +11,14 @@ //! Little-Endian order has been chosen for internal usage; this makes some //! useful functions available. -/// Reads unsigned 32 bit integers from `src` into `dst`. +/// Fills `dst: &mut [u32]` from `src` +/// +/// Reads use Little-Endian byte order, allowing portable reproduction of `dst` +/// from a byte slice. /// /// # Panics /// -/// If `dst` has insufficient space (`4*dst.len() < src.len()`). +/// If `src` has insufficient length (if `src.len() < 4*dst.len()`). #[inline] #[track_caller] pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { @@ -25,11 +28,11 @@ pub fn read_u32_into(src: &[u8], dst: &mut [u32]) { } } -/// Reads unsigned 64 bit integers from `src` into `dst`. +/// Fills `dst: &mut [u64]` from `src` /// /// # Panics /// -/// If `dst` has insufficient space (`8*dst.len() < src.len()`). +/// If `src` has insufficient length (if `src.len() < 8*dst.len()`). #[inline] #[track_caller] pub fn read_u64_into(src: &[u8], dst: &mut [u64]) { diff --git a/rand_core/src/lib.rs b/rand_core/src/lib.rs index 9faff9c752f..6c007797806 100644 --- a/rand_core/src/lib.rs +++ b/rand_core/src/lib.rs @@ -31,6 +31,7 @@ )] #![deny(missing_docs)] #![deny(missing_debug_implementations)] +#![deny(clippy::undocumented_unsafe_blocks)] #![doc(test(attr(allow(unused_variables), deny(warnings))))] #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![no_std] @@ -175,32 +176,32 @@ where } } -/// A marker trait used to indicate that an [`RngCore`] implementation is -/// supposed to be cryptographically secure. +/// A marker trait over [`RngCore`] for securely unpredictable RNGs /// -/// *Cryptographically secure generators*, also known as *CSPRNGs*, should -/// satisfy an additional properties over other generators: given the first -/// *k* bits of an algorithm's output -/// sequence, it should not be possible using polynomial-time algorithms to -/// predict the next bit with probability significantly greater than 50%. -/// -/// Some generators may satisfy an additional property, however this is not -/// required by this trait: if the CSPRNG's state is revealed, it should not be -/// computationally-feasible to reconstruct output prior to this. Some other -/// generators allow backwards-computation and are considered *reversible*. +/// This marker trait indicates that the implementing generator is intended, +/// when correctly seeded and protected from side-channel attacks such as a +/// leaking of state, to be a cryptographically secure generator. This trait is +/// provided as a tool to aid review of cryptographic code, but does not by +/// itself guarantee suitability for cryptographic applications. /// -/// Note that this trait is provided for guidance only and cannot guarantee -/// suitability for cryptographic applications. In general it should only be -/// implemented for well-reviewed code implementing well-regarded algorithms. +/// Implementors of `CryptoRng` automatically implement the [`TryCryptoRng`] +/// trait. /// -/// Note also that use of a `CryptoRng` does not protect against other -/// weaknesses such as seeding from a weak entropy source or leaking state. +/// Implementors of `CryptoRng` should only implement [`Default`] if the +/// `default()` instances are themselves secure generators: for example if the +/// implementing type is a stateless interface over a secure external generator +/// (like [`OsRng`]) or if the `default()` instance uses a strong, fresh seed. /// -/// Note that implementors of [`CryptoRng`] also automatically implement -/// the [`TryCryptoRng`] trait. +/// Formally, a CSPRNG (Cryptographically Secure Pseudo-Random Number Generator) +/// should satisfy an additional property over other generators: assuming that +/// the generator has been appropriately seeded and has unknown state, then +/// given the first *k* bits of an algorithm's output +/// sequence, it should not be possible using polynomial-time algorithms to +/// predict the next bit with probability significantly greater than 50%. /// -/// [`BlockRngCore`]: block::BlockRngCore -/// [`Infallible`]: core::convert::Infallible +/// An optional property of CSPRNGs is backtracking resistance: if the CSPRNG's +/// state is revealed, it will not be computationally-feasible to reconstruct +/// prior output values. This property is not required by `CryptoRng`. pub trait CryptoRng: RngCore {} impl CryptoRng for T where T::Target: CryptoRng {} @@ -236,6 +237,11 @@ pub trait TryRngCore { UnwrapErr(self) } + /// Wrap RNG with the [`UnwrapMut`] wrapper. + fn unwrap_mut(&mut self) -> UnwrapMut<'_, Self> { + UnwrapMut(self) + } + /// Convert an [`RngCore`] to a [`RngReadAdapter`]. #[cfg(feature = "std")] fn read_adapter(&mut self) -> RngReadAdapter<'_, Self> @@ -249,7 +255,7 @@ pub trait TryRngCore { // Note that, unfortunately, this blanket impl prevents us from implementing // `TryRngCore` for types which can be dereferenced to `TryRngCore`, i.e. `TryRngCore` // will not be automatically implemented for `&mut R`, `Box`, etc. -impl TryRngCore for R { +impl TryRngCore for R { type Error = core::convert::Infallible; #[inline] @@ -269,13 +275,23 @@ impl TryRngCore for R { } } -/// A marker trait used to indicate that a [`TryRngCore`] implementation is -/// supposed to be cryptographically secure. +/// A marker trait over [`TryRngCore`] for securely unpredictable RNGs +/// +/// This trait is like [`CryptoRng`] but for the trait [`TryRngCore`]. /// -/// See [`CryptoRng`] docs for more information about cryptographically secure generators. +/// This marker trait indicates that the implementing generator is intended, +/// when correctly seeded and protected from side-channel attacks such as a +/// leaking of state, to be a cryptographically secure generator. This trait is +/// provided as a tool to aid review of cryptographic code, but does not by +/// itself guarantee suitability for cryptographic applications. +/// +/// Implementors of `TryCryptoRng` should only implement [`Default`] if the +/// `default()` instances are themselves secure generators: for example if the +/// implementing type is a stateless interface over a secure external generator +/// (like [`OsRng`]) or if the `default()` instance uses a strong, fresh seed. pub trait TryCryptoRng: TryRngCore {} -impl TryCryptoRng for R {} +impl TryCryptoRng for R {} /// Wrapper around [`TryRngCore`] implementation which implements [`RngCore`] /// by panicking on potential errors. @@ -301,11 +317,57 @@ impl RngCore for UnwrapErr { impl CryptoRng for UnwrapErr {} +/// Wrapper around [`TryRngCore`] implementation which implements [`RngCore`] +/// by panicking on potential errors. +#[derive(Debug, Eq, PartialEq, Hash)] +pub struct UnwrapMut<'r, R: TryRngCore + ?Sized>(pub &'r mut R); + +impl<'r, R: TryRngCore + ?Sized> UnwrapMut<'r, R> { + /// Reborrow with a new lifetime + /// + /// Rust allows references like `&T` or `&mut T` to be "reborrowed" through + /// coercion: essentially, the pointer is copied under a new, shorter, lifetime. + /// Until rfcs#1403 lands, reborrows on user types require a method call. + #[inline(always)] + pub fn re<'b>(&'b mut self) -> UnwrapMut<'b, R> + where + 'r: 'b, + { + UnwrapMut(self.0) + } +} + +impl RngCore for UnwrapMut<'_, R> { + #[inline] + fn next_u32(&mut self) -> u32 { + self.0.try_next_u32().unwrap() + } + + #[inline] + fn next_u64(&mut self) -> u64 { + self.0.try_next_u64().unwrap() + } + + #[inline] + fn fill_bytes(&mut self, dst: &mut [u8]) { + self.0.try_fill_bytes(dst).unwrap() + } +} + +impl CryptoRng for UnwrapMut<'_, R> {} + /// A random number generator that can be explicitly seeded. /// /// This trait encapsulates the low-level functionality common to all /// pseudo-random number generators (PRNGs, or algorithmic generators). /// +/// A generator implementing `SeedableRng` will usually be deterministic, but +/// beware that portability and reproducibility of results **is not implied**. +/// Refer to documentation of the generator, noting that generators named after +/// a specific algorithm are usually tested for reproducibility against a +/// reference vector, while `SmallRng` and `StdRng` specifically opt out of +/// reproducibility guarantees. +/// /// [`rand`]: https://docs.rs/rand pub trait SeedableRng: Sized { /// Seed type, which is restricted to types mutably-dereferenceable as `u8` @@ -593,4 +655,118 @@ mod test { // value-breakage test: assert_eq!(results[0], 5029875928683246316); } + + // A stub RNG. + struct SomeRng; + + impl RngCore for SomeRng { + fn next_u32(&mut self) -> u32 { + unimplemented!() + } + fn next_u64(&mut self) -> u64 { + unimplemented!() + } + fn fill_bytes(&mut self, _: &mut [u8]) { + unimplemented!() + } + } + + impl CryptoRng for SomeRng {} + + #[test] + fn dyn_rngcore_to_tryrngcore() { + // Illustrates the need for `+ ?Sized` bound in `impl TryRngCore for R`. + + // A method in another crate taking a fallible RNG + fn third_party_api(_rng: &mut (impl TryRngCore + ?Sized)) -> bool { + true + } + + // A method in our crate requiring an infallible RNG + fn my_api(rng: &mut dyn RngCore) -> bool { + // We want to call the method above + third_party_api(rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn dyn_cryptorng_to_trycryptorng() { + // Illustrates the need for `+ ?Sized` bound in `impl TryCryptoRng for R`. + + // A method in another crate taking a fallible RNG + fn third_party_api(_rng: &mut (impl TryCryptoRng + ?Sized)) -> bool { + true + } + + // A method in our crate requiring an infallible RNG + fn my_api(rng: &mut dyn CryptoRng) -> bool { + // We want to call the method above + third_party_api(rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn dyn_unwrap_mut_tryrngcore() { + // Illustrates the need for `+ ?Sized` bound in + // `impl RngCore for UnwrapMut<'_, R>`. + + fn third_party_api(_rng: &mut impl RngCore) -> bool { + true + } + + fn my_api(rng: &mut (impl TryRngCore + ?Sized)) -> bool { + let mut infallible_rng = rng.unwrap_mut(); + third_party_api(&mut infallible_rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn dyn_unwrap_mut_trycryptorng() { + // Illustrates the need for `+ ?Sized` bound in + // `impl CryptoRng for UnwrapMut<'_, R>`. + + fn third_party_api(_rng: &mut impl CryptoRng) -> bool { + true + } + + fn my_api(rng: &mut (impl TryCryptoRng + ?Sized)) -> bool { + let mut infallible_rng = rng.unwrap_mut(); + third_party_api(&mut infallible_rng) + } + + assert!(my_api(&mut SomeRng)); + } + + #[test] + fn reborrow_unwrap_mut() { + struct FourRng; + + impl TryRngCore for FourRng { + type Error = core::convert::Infallible; + fn try_next_u32(&mut self) -> Result { + Ok(4) + } + fn try_next_u64(&mut self) -> Result { + unimplemented!() + } + fn try_fill_bytes(&mut self, _: &mut [u8]) -> Result<(), Self::Error> { + unimplemented!() + } + } + + let mut rng = FourRng; + let mut rng = rng.unwrap_mut(); + + assert_eq!(rng.next_u32(), 4); + let mut rng2 = rng.re(); + assert_eq!(rng2.next_u32(), 4); + drop(rng2); + assert_eq!(rng.next_u32(), 4); + } } diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md deleted file mode 100644 index 81fa3a3c4bc..00000000000 --- a/rand_distr/CHANGELOG.md +++ /dev/null @@ -1,110 +0,0 @@ -# Changelog -All notable changes to this project will be documented in this file. - -The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - -## [0.5.0] - 2025-01-27 - -### Dependencies and features -- Bump the MSRV to 1.61.0 (#1207, #1246, #1269, #1341, #1416); note that 1.60.0 may work for dependents when using `--ignore-rust-version` -- Update to `rand` v0.9.0 (#1558) -- Rename feature `serde1` to `serde` (#1477) - -### API changes -- Make distributions comparable with `PartialEq` (#1218) -- `Dirichlet` now uses `const` generics, which means that its size is required at compile time (#1292) -- The `Dirichlet::new_with_size` constructor was removed (#1292) -- Add `WeightedIndexTree` (#1372, #1444) -- Add `PertBuilder` to allow specification of `mean` or `mode` (#1452) -- Rename `Zeta`'s parameter `a` to `s` (#1466) -- Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480) -- Remove support for usage of `isize` as a `WeightedAliasIndex` weight (#1487) -- Change parameter type of `Zipf::new`: `n` is now floating-point (#1518) - -### API changes: renames -- Move `Slice` -> `slice::Choose`, `EmptySlice` -> `slice::Empty` (#1548) -- Rename trait `DistString` -> `SampleString` (#1548) -- Rename `DistIter` -> `Iter`, `DistMap` -> `Map` (#1548) -- Move `{Weight, WeightError, WeightedIndex}` -> `weighted::{Weight, Error, WeightedIndex}` (#1548) -- Move `weighted_alias::{AliasableWeight, WeightedAliasIndex}` -> `weighted::{..}` (#1548) -- Move `weighted_tree::WeightedTreeIndex` -> `weighted::WeightedTreeIndex` (#1548) - -### Testing -- Add Kolmogorov Smirnov tests for distributions (#1494, #1504, #1525, #1530) - -### Fixes -- Fix Knuth's method so `Poisson` doesn't return -1.0 for small lambda (#1284) -- Fix `Poisson` distribution instantiation so it return an error if lambda is infinite (#1291) -- Fix Dirichlet sample for small alpha values to avoid NaN samples (#1209) -- Fix infinite loop in `Binomial` distribution (#1325) -- Fix `Pert` distribution where `mode` is close to `(min + max) / 2` (#1452) -- Fix panic in Binomial (#1484) -- Limit the maximal acceptable lambda for `Poisson` to solve (#1312) (#1498) -- Fix bug in `Hypergeometric`, this is a Value-breaking change (#1510) - -### Other changes -- Remove unused fields from `Gamma`, `NormalInverseGaussian` and `Zipf` distributions (#1184) - This breaks serialization compatibility with older versions. -- Add plots for `rand_distr` distributions to documentation (#1434) -- Move some of the computations in Binomial from `sample` to `new` (#1484) - -## [0.4.3] - 2021-12-30 -- Fix `no_std` build (#1208) - -## [0.4.2] - 2021-09-18 -- New `Zeta` and `Zipf` distributions (#1136) -- New `SkewNormal` distribution (#1149) -- New `Gumbel` and `Frechet` distributions (#1168, #1171) - -## [0.4.1] - 2021-06-15 -- Empirically test PDF of normal distribution (#1121) -- Correctly document `no_std` support (#1100) -- Add `std_math` feature to prefer `std` over `libm` for floating point math (#1100) -- Add mean and std_dev accessors to Normal (#1114) -- Make sure all distributions and their error types implement `Error`, `Display`, `Clone`, - `Copy`, `PartialEq` and `Eq` as appropriate (#1126) -- Port benchmarks to use Criterion crate (#1116) -- Support serde for distributions (#1141) - -## [0.4.0] - 2020-12-18 -- Bump `rand` to v0.8.0 -- New `Geometric`, `StandardGeometric` and `Hypergeometric` distributions (#1062) -- New `Beta` sampling algorithm for improved performance and accuracy (#1000) -- `Normal` and `LogNormal` now support `from_mean_cv` and `from_zscore` (#1044) -- Variants of `NormalError` changed (#1044) - -## [0.3.0] - 2020-08-25 -- Move alias method for `WeightedIndex` from `rand` (#945) -- Rename `WeightedIndex` to `WeightedAliasIndex` (#1008) -- Replace custom `Float` trait with `num-traits::Float` (#987) -- Enable `no_std` support via `num-traits` math functions (#987) -- Remove `Distribution` impl for `Poisson` (#987) -- Tweak `Dirichlet` and `alias_method` to use boxed slice instead of `Vec` (#987) -- Use whitelist for package contents, reducing size by 5kb (#983) -- Add case `lambda = 0` in the parametrization of `Exp` (#972) -- Implement inverse Gaussian distribution (#954) -- Reformatting and use of `rustfmt::skip` (#926) -- All error types now implement `std::error::Error` (#919) -- Re-exported `rand::distributions::BernoulliError` (#919) -- Add value stability tests for distributions (#891) - -## [0.2.2] - 2019-09-10 -- Fix version requirement on rand lib (#847) -- Clippy fixes & suppression (#840) - -## [0.2.1] - 2019-06-29 -- Update dependency to support Rand 0.7 -- Doc link fixes - -## [0.2.0] - 2019-06-06 -- Remove `new` constructors for zero-sized types -- Add Pert distribution -- Fix undefined behavior in `Poisson` -- Make all distributions return `Result`s instead of panicking -- Implement `f32` support for most distributions -- Rename `UnitSphereSurface` to `UnitSphere` -- Implement `UnitBall` and `UnitDisc` - -## [0.1.0] - 2019-06-06 -Initial release. This is equivalent to the code in `rand` 0.6.5. diff --git a/rand_distr/COPYRIGHT b/rand_distr/COPYRIGHT deleted file mode 100644 index 468d907caf9..00000000000 --- a/rand_distr/COPYRIGHT +++ /dev/null @@ -1,12 +0,0 @@ -Copyrights in the Rand project are retained by their contributors. No -copyright assignment is required to contribute to the Rand project. - -For full authorship information, see the version control history. - -Except as otherwise noted (below and/or in individual files), Rand is -licensed under the Apache License, Version 2.0 or - or the MIT license - or , at your option. - -The Rand project includes code from the Rust project -published under these same licenses. diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml deleted file mode 100644 index dd55673777c..00000000000 --- a/rand_distr/Cargo.toml +++ /dev/null @@ -1,48 +0,0 @@ -[package] -name = "rand_distr" -version = "0.5.0" -authors = ["The Rand Project Developers"] -license = "MIT OR Apache-2.0" -readme = "README.md" -repository = "https://github.com/rust-random/rand" -documentation = "https://docs.rs/rand_distr" -homepage = "https://rust-random.github.io/book" -description = """ -Sampling from random number distributions -""" -keywords = ["random", "rng", "distribution", "probability"] -categories = ["algorithms", "no-std"] -edition = "2021" -rust-version = "1.63" -include = ["/src", "LICENSE-*", "README.md", "CHANGELOG.md", "COPYRIGHT"] - -[package.metadata.docs.rs] -features = ["serde"] -rustdoc-args = ["--generate-link-to-definition"] - -[features] -default = ["std"] -std = ["alloc", "rand/std"] -alloc = ["rand/alloc"] - -# Use std's floating-point arithmetic instead of libm. -# Note that any other crate depending on `num-traits`'s `std` -# feature (default-enabled) will have the same effect. -std_math = ["num-traits/std"] - -serde = ["dep:serde", "dep:serde_with", "rand/serde"] - -[dependencies] -rand = { path = "..", version = "0.9.0", default-features = false } -num-traits = { version = "0.2", default-features = false, features = ["libm"] } -serde = { version = "1.0.103", features = ["derive"], optional = true } -serde_with = { version = ">= 3.0, <= 3.11", optional = true } - -[dev-dependencies] -rand_pcg = { version = "0.9.0", path = "../rand_pcg" } -# For inline examples -rand = { path = "..", version = "0.9.0", features = ["small_rng"] } -# Histogram implementation for testing uniformity -average = { version = "0.15", features = [ "std" ] } -# Special functions for testing distributions -special = "0.11.0" diff --git a/rand_distr/LICENSE-APACHE b/rand_distr/LICENSE-APACHE deleted file mode 100644 index 455787c2334..00000000000 --- a/rand_distr/LICENSE-APACHE +++ /dev/null @@ -1,187 +0,0 @@ - Apache License - Version 2.0, January 2004 - https://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. diff --git a/rand_distr/LICENSE-MIT b/rand_distr/LICENSE-MIT deleted file mode 100644 index cf656074cbf..00000000000 --- a/rand_distr/LICENSE-MIT +++ /dev/null @@ -1,25 +0,0 @@ -Copyright 2018 Developers of the Rand project - -Permission is hereby granted, free of charge, to any -person obtaining a copy of this software and associated -documentation files (the "Software"), to deal in the -Software without restriction, including without -limitation the rights to use, copy, modify, merge, -publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software -is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice -shall be included in all copies or substantial portions -of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/rand_distr/README.md b/rand_distr/README.md deleted file mode 100644 index 193d54123d1..00000000000 --- a/rand_distr/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# rand_distr - -[![Test Status](https://github.com/rust-random/rand/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/rust-random/rand/actions) -[![Latest version](https://img.shields.io/crates/v/rand_distr.svg)](https://crates.io/crates/rand_distr) -[![Book](https://img.shields.io/badge/book-master-yellow.svg)](https://rust-random.github.io/book/) -[![API](https://img.shields.io/badge/api-master-yellow.svg)](https://rust-random.github.io/rand/rand_distr) -[![API](https://docs.rs/rand_distr/badge.svg)](https://docs.rs/rand_distr) - -Implements a full suite of random number distribution sampling routines. - -This crate is a superset of the [rand::distr] module, including support -for sampling from Beta, Binomial, Cauchy, ChiSquared, Dirichlet, Exponential, -FisherF, Gamma, Geometric, Hypergeometric, InverseGaussian, LogNormal, Normal, -Pareto, PERT, Poisson, StudentT, Triangular and Weibull distributions. Sampling -from the unit ball, unit circle, unit disc and unit sphere surfaces is also -supported. - -It is worth mentioning the [statrs] crate which provides similar functionality -along with various support functions, including PDF and CDF computation. In -contrast, this `rand_distr` crate focuses on sampling from distributions. - -## Portability and libm - -The floating point functions from `num_traits` and `libm` are used to support -`no_std` environments and ensure reproducibility. If the floating point -functions from `std` are preferred, which may provide better accuracy and -performance but may produce different random values, the `std_math` feature -can be enabled. (Note that any other crate depending on `num-traits` with the -`std` feature (default-enabled) will have the same effect.) - -## Crate features - -- `std` (enabled by default): `rand_distr` implements the `Error` trait for - its error types. Implies `alloc` and `rand/std`. -- `alloc` (enabled by default): required for some distributions when not using - `std` (in particular, `Dirichlet` and `WeightedAliasIndex`). -- `std_math`: see above on portability and libm -- `serde`: implement (de)seriaialization using `serde` - -## Links - -- [API documentation (master)](https://rust-random.github.io/rand/rand_distr) -- [API documentation (docs.rs)](https://docs.rs/rand_distr) -- [Changelog](CHANGELOG.md) -- [The Rand project](https://github.com/rust-random/rand) - - -[statrs]: https://github.com/boxtown/statrs -[rand::distr]: https://rust-random.github.io/rand/rand/distr/index.html - -## License - -`rand_distr` is distributed under the terms of both the MIT license and the -Apache License (Version 2.0). - -See [LICENSE-APACHE](LICENSE-APACHE) and [LICENSE-MIT](LICENSE-MIT), and -[COPYRIGHT](COPYRIGHT) for details. diff --git a/rand_distr/src/beta.rs b/rand_distr/src/beta.rs deleted file mode 100644 index 4dc297cfd50..00000000000 --- a/rand_distr/src/beta.rs +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Beta distribution. - -use crate::{Distribution, Open01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// The algorithm used for sampling the Beta distribution. -/// -/// Reference: -/// -/// R. C. H. Cheng (1978). -/// Generating beta variates with nonintegral shape parameters. -/// Communications of the ACM 21, 317-322. -/// https://doi.org/10.1145/359460.359482 -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -enum BetaAlgorithm { - BB(BB), - BC(BC), -} - -/// Algorithm BB for `min(alpha, beta) > 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -struct BB { - alpha: N, - beta: N, - gamma: N, -} - -/// Algorithm BC for `min(alpha, beta) <= 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -struct BC { - alpha: N, - beta: N, - kappa1: N, - kappa2: N, -} - -/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`. -/// -/// The Beta distribution is a continuous probability distribution -/// defined on the interval `[0, 1]`. It is the conjugate prior for the -/// parameter `p` of the [`Binomial`][crate::Binomial] distribution. -/// -/// It has two shape parameters `α` (alpha) and `β` (beta) which control -/// the shape of the distribution. Both `a` and `β` must be greater than zero. -/// The distribution is symmetric when `α = β`. -/// -/// # Plot -/// -/// The plot shows the Beta distribution with various combinations -/// of `α` and `β`. -/// -/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Beta}; -/// -/// let beta = Beta::new(2.0, 5.0).unwrap(); -/// let v = beta.sample(&mut rand::rng()); -/// println!("{} is from a Beta(2, 5) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct Beta -where - F: Float, - Open01: Distribution, -{ - a: F, - b: F, - switched_params: bool, - algorithm: BetaAlgorithm, -} - -/// Error type returned from [`Beta::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum Error { - /// `alpha <= 0` or `nan`. - AlphaTooSmall, - /// `beta <= 0` or `nan`. - BetaTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::AlphaTooSmall => "alpha is not positive in beta distribution", - Error::BetaTooSmall => "beta is not positive in beta distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Beta -where - F: Float, - Open01: Distribution, -{ - /// Construct an object representing the `Beta(alpha, beta)` - /// distribution. - pub fn new(alpha: F, beta: F) -> Result, Error> { - if !(alpha > F::zero()) { - return Err(Error::AlphaTooSmall); - } - if !(beta > F::zero()) { - return Err(Error::BetaTooSmall); - } - // From now on, we use the notation from the reference, - // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. - let (a0, b0) = (alpha, beta); - let (a, b, switched_params) = if a0 < b0 { - (a0, b0, false) - } else { - (b0, a0, true) - }; - if a > F::one() { - // Algorithm BB - let alpha = a + b; - - let two = F::from(2.).unwrap(); - let beta_numer = alpha - two; - let beta_denom = two * a * b - alpha; - let beta = (beta_numer / beta_denom).sqrt(); - - let gamma = a + F::one() / beta; - - Ok(Beta { - a, - b, - switched_params, - algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }), - }) - } else { - // Algorithm BC - // - // Here `a` is the maximum instead of the minimum. - let (a, b, switched_params) = (b, a, !switched_params); - let alpha = a + b; - let beta = F::one() / b; - let delta = F::one() + a - b; - let kappa1 = delta - * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b) - / (a * beta - F::from(14. / 18.).unwrap()); - let kappa2 = F::from(0.25).unwrap() - + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b; - - Ok(Beta { - a, - b, - switched_params, - algorithm: BetaAlgorithm::BC(BC { - alpha, - beta, - kappa1, - kappa2, - }), - }) - } - } -} - -impl Distribution for Beta -where - F: Float, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let mut w; - match self.algorithm { - BetaAlgorithm::BB(algo) => { - loop { - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - let z = u1 * u1 * u2; - let r = algo.gamma * v - F::from(4.).unwrap().ln(); - let s = self.a + r - w; - // 2. - if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z { - break; - } - // 3. - let t = z.ln(); - if s >= t { - break; - } - // 4. - if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { - break; - } - } - } - BetaAlgorithm::BC(algo) => { - loop { - let z; - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - if u1 < F::from(0.5).unwrap() { - // 2. - let y = u1 * u2; - z = u1 * y; - if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { - continue; - } - } else { - // 3. - z = u1 * u1 * u2; - if z <= F::from(0.25).unwrap() { - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - break; - } - // 4. - if z >= algo.kappa2 { - continue; - } - } - // 5. - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) - - F::from(4.).unwrap().ln() - < z.ln()) - { - break; - }; - } - } - }; - // 5. for BB, 6. for BC - if !self.switched_params { - if w == F::infinity() { - // Assuming `b` is finite, for large `w`: - return F::one(); - } - w / (self.b + w) - } else { - self.b / (self.b + w) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_beta() { - let beta = Beta::new(1.0, 2.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - beta.sample(&mut rng); - } - } - - #[test] - #[should_panic] - fn test_beta_invalid_dof() { - Beta::new(0., 0.).unwrap(); - } - - #[test] - fn test_beta_small_param() { - let beta = Beta::::new(1e-3, 1e-3).unwrap(); - let mut rng = crate::test::rng(206); - for i in 0..1000 { - assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); - } - } - - #[test] - fn beta_distributions_can_be_compared() { - assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs deleted file mode 100644 index d6dfceae777..00000000000 --- a/rand_distr/src/binomial.rs +++ /dev/null @@ -1,457 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2016-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The binomial distribution `Binomial(n, p)`. - -use crate::{Distribution, Uniform}; -use core::cmp::Ordering; -use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; -use rand::Rng; - -/// The [binomial distribution](https://en.wikipedia.org/wiki/Binomial_distribution) `Binomial(n, p)`. -/// -/// The binomial distribution is a discrete probability distribution -/// which describes the probability of seeing `k` successes in `n` -/// independent trials, each of which has success probability `p`. -/// -/// # Density function -/// -/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. -/// -/// # Plot -/// -/// The following plot of the binomial distribution illustrates the -/// probability of `k` successes out of `n = 10` trials with `p = 0.2` -/// and `p = 0.6` for `0 <= k <= n`. -/// -/// ![Binomial distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/binomial.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Binomial, Distribution}; -/// -/// let bin = Binomial::new(20, 0.3).unwrap(); -/// let v = bin.sample(&mut rand::rng()); -/// println!("{} is from a binomial distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Binomial { - method: Method, -} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -enum Method { - Binv(Binv, bool), - Btpe(Btpe, bool), - Poisson(crate::poisson::KnuthMethod), - Constant(u64), -} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -struct Binv { - r: f64, - s: f64, - a: f64, - n: u64, -} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -struct Btpe { - n: u64, - p: f64, - m: i64, - p1: f64, -} - -/// Error type returned from [`Binomial::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -// Marked non_exhaustive to allow a new error code in the solution to #1378. -#[non_exhaustive] -pub enum Error { - /// `p < 0` or `nan`. - ProbabilityTooSmall, - /// `p > 1`. - ProbabilityTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ProbabilityTooSmall => "p < 0 or is NaN in binomial distribution", - Error::ProbabilityTooLarge => "p > 1 in binomial distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Binomial { - /// Construct a new `Binomial` with the given shape parameters `n` (number - /// of trials) and `p` (probability of success). - pub fn new(n: u64, p: f64) -> Result { - if !(p >= 0.0) { - return Err(Error::ProbabilityTooSmall); - } - if !(p <= 1.0) { - return Err(Error::ProbabilityTooLarge); - } - - if p == 0.0 { - return Ok(Binomial { - method: Method::Constant(0), - }); - } - - if p == 1.0 { - return Ok(Binomial { - method: Method::Constant(n), - }); - } - - // The binomial distribution is symmetrical with respect to p -> 1-p - let flipped = p > 0.5; - let p = if flipped { 1.0 - p } else { p }; - - // For small n * min(p, 1 - p), the BINV algorithm based on the inverse - // transformation of the binomial distribution is efficient. Otherwise, - // the BTPE algorithm is used. - // - // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1988. Binomial - // random variate generation. Commun. ACM 31, 2 (February 1988), - // 216-222. http://dx.doi.org/10.1145/42372.42381 - - // Threshold for preferring the BINV algorithm. The paper suggests 10, - // Ranlib uses 30, and GSL uses 14. - const BINV_THRESHOLD: f64 = 10.; - - let np = n as f64 * p; - let method = if np < BINV_THRESHOLD { - let q = 1.0 - p; - if q == 1.0 { - // p is so small that this is extremely close to a Poisson distribution. - // The flipped case cannot occur here. - Method::Poisson(crate::poisson::KnuthMethod::new(np)) - } else { - let s = p / q; - Method::Binv( - Binv { - r: q.powf(n as f64), - s, - a: (n as f64 + 1.0) * s, - n, - }, - flipped, - ) - } - } else { - let q = 1.0 - p; - let npq = np * q; - let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5; - let f_m = np + p; - let m = f64_to_i64(f_m); - Method::Btpe(Btpe { n, p, m, p1 }, flipped) - }; - Ok(Binomial { method }) - } -} - -/// Convert a `f64` to an `i64`, panicking on overflow. -fn f64_to_i64(x: f64) -> i64 { - assert!(x < (i64::MAX as f64)); - x as i64 -} - -fn binv(binv: Binv, flipped: bool, rng: &mut R) -> u64 { - // Same value as in GSL. - // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. - // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. - // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. - const BINV_MAX_X: u64 = 110; - - let sample = 'outer: loop { - let mut r = binv.r; - let mut u: f64 = rng.random(); - let mut x = 0; - - while u > r { - u -= r; - x += 1; - if x > BINV_MAX_X { - continue 'outer; - } - r *= binv.a / (x as f64) - binv.s; - } - break x; - }; - - if flipped { - binv.n - sample - } else { - sample - } -} - -#[allow(clippy::many_single_char_names)] // Same names as in the reference. -fn btpe(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 { - // Threshold for using the squeeze algorithm. This can be freely - // chosen based on performance. Ranlib and GSL use 20. - const SQUEEZE_THRESHOLD: i64 = 20; - - // Step 0: Calculate constants as functions of `n` and `p`. - let n = btpe.n as f64; - let np = n * btpe.p; - let q = 1. - btpe.p; - let npq = np * q; - let f_m = np + btpe.p; - let m = btpe.m; - // radius of triangle region, since height=1 also area of region - let p1 = btpe.p1; - // tip of triangle - let x_m = (m as f64) + 0.5; - // left edge of triangle - let x_l = x_m - p1; - // right edge of triangle - let x_r = x_m + p1; - let c = 0.134 + 20.5 / (15.3 + (m as f64)); - // p1 + area of parallelogram region - let p2 = p1 * (1. + 2. * c); - - fn lambda(a: f64) -> f64 { - a * (1. + 0.5 * a) - } - - let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p)); - let lambda_r = lambda((x_r - f_m) / (x_r * q)); - - let p3 = p2 + c / lambda_l; - - let p4 = p3 + c / lambda_r; - - // return value - let mut y: i64; - - let gen_u = Uniform::new(0., p4).unwrap(); - let gen_v = Uniform::new(0., 1.).unwrap(); - - loop { - // Step 1: Generate `u` for selecting the region. If region 1 is - // selected, generate a triangularly distributed variate. - let u = gen_u.sample(rng); - let mut v = gen_v.sample(rng); - if !(u > p1) { - y = f64_to_i64(x_m - p1 * v + u); - break; - } - - if !(u > p2) { - // Step 2: Region 2, parallelograms. Check if region 2 is - // used. If so, generate `y`. - let x = x_l + (u - p1) / c; - v = v * c + 1.0 - (x - x_m).abs() / p1; - if v > 1. { - continue; - } else { - y = f64_to_i64(x); - } - } else if !(u > p3) { - // Step 3: Region 3, left exponential tail. - y = f64_to_i64(x_l + v.ln() / lambda_l); - if y < 0 { - continue; - } else { - v *= (u - p2) * lambda_l; - } - } else { - // Step 4: Region 4, right exponential tail. - y = f64_to_i64(x_r - v.ln() / lambda_r); - if y > 0 && (y as u64) > btpe.n { - continue; - } else { - v *= (u - p3) * lambda_r; - } - } - - // Step 5: Acceptance/rejection comparison. - - // Step 5.0: Test for appropriate method of evaluating f(y). - let k = (y - m).abs(); - if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) { - // Step 5.1: Evaluate f(y) via the recursive relationship. Start the - // search from the mode. - let s = btpe.p / q; - let a = s * (n + 1.); - let mut f = 1.0; - match m.cmp(&y) { - Ordering::Less => { - let mut i = m; - loop { - i += 1; - f *= a / (i as f64) - s; - if i == y { - break; - } - } - } - Ordering::Greater => { - let mut i = y; - loop { - i += 1; - f /= a / (i as f64) - s; - if i == m { - break; - } - } - } - Ordering::Equal => {} - } - if v > f { - continue; - } else { - break; - } - } - - // Step 5.2: Squeezing. Check the value of ln(v) against upper and - // lower bound of ln(f(y)). - let k = k as f64; - let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5); - let t = -0.5 * k * k / npq; - let alpha = v.ln(); - if alpha < t - rho { - break; - } - if alpha > t + rho { - continue; - } - - // Step 5.3: Final acceptance/rejection test. - let x1 = (y + 1) as f64; - let f1 = (m + 1) as f64; - let z = (f64_to_i64(n) + 1 - m) as f64; - let w = (f64_to_i64(n) - y + 1) as f64; - - fn stirling(a: f64) -> f64 { - let a2 = a * a; - (13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320. - } - - if alpha - > x_m * (f1 / x1).ln() - + (n - (m as f64) + 0.5) * (z / w).ln() - + ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln() - // We use the signs from the GSL implementation, which are - // different than the ones in the reference. According to - // the GSL authors, the new signs were verified to be - // correct by one of the original designers of the - // algorithm. - + stirling(f1) - + stirling(z) - - stirling(x1) - - stirling(w) - { - continue; - } - - break; - } - assert!(y >= 0); - let y = y as u64; - - if flipped { - btpe.n - y - } else { - y - } -} - -impl Distribution for Binomial { - fn sample(&self, rng: &mut R) -> u64 { - match self.method { - Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng), - Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng), - Method::Poisson(poisson) => poisson.sample(rng) as u64, - Method::Constant(c) => c, - } - } -} - -#[cfg(test)] -mod test { - use super::Binomial; - use crate::Distribution; - use rand::Rng; - - fn test_binomial_mean_and_variance(n: u64, p: f64, rng: &mut R) { - let binomial = Binomial::new(n, p).unwrap(); - - let expected_mean = n as f64 * p; - let expected_variance = n as f64 * p * (1.0 - p); - - let mut results = [0.0; 1000]; - for i in results.iter_mut() { - *i = binomial.sample(rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 50.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn test_binomial() { - let mut rng = crate::test::rng(351); - test_binomial_mean_and_variance(150, 0.1, &mut rng); - test_binomial_mean_and_variance(70, 0.6, &mut rng); - test_binomial_mean_and_variance(40, 0.5, &mut rng); - test_binomial_mean_and_variance(20, 0.7, &mut rng); - test_binomial_mean_and_variance(20, 0.5, &mut rng); - test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng); - test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng); - } - - #[test] - fn test_binomial_end_points() { - let mut rng = crate::test::rng(352); - assert_eq!(rng.sample(Binomial::new(20, 0.0).unwrap()), 0); - assert_eq!(rng.sample(Binomial::new(20, 1.0).unwrap()), 20); - } - - #[test] - #[should_panic] - fn test_binomial_invalid_lambda_neg() { - Binomial::new(20, -10.0).unwrap(); - } - - #[test] - fn binomial_distributions_can_be_compared() { - assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0)); - } - - #[test] - fn binomial_avoid_infinite_loop() { - let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap(); - let mut sum: u64 = 0; - let mut rng = crate::test::rng(742); - for _ in 0..100_000 { - sum = sum.wrapping_add(dist.sample(&mut rng)); - } - assert_ne!(sum, 0); - } -} diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs deleted file mode 100644 index 8f0faad3863..00000000000 --- a/rand_distr/src/cauchy.rs +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2016-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Cauchy distribution `Cauchy(x₀, γ)`. - -use crate::{Distribution, StandardUniform}; -use core::fmt; -use num_traits::{Float, FloatConst}; -use rand::Rng; - -/// The [Cauchy distribution](https://en.wikipedia.org/wiki/Cauchy_distribution) `Cauchy(x₀, γ)`. -/// -/// The Cauchy distribution is a continuous probability distribution with -/// parameters `x₀` (median) and `γ` (scale). -/// It describes the distribution of the ratio of two independent -/// normally distributed random variables with means `x₀` and scales `γ`. -/// In other words, if `X` and `Y` are independent normally distributed -/// random variables with means `x₀` and scales `γ`, respectively, then -/// `X / Y` is `Cauchy(x₀, γ)` distributed. -/// -/// # Density function -/// -/// `f(x) = 1 / (π * γ * (1 + ((x - x₀) / γ)²))` -/// -/// # Plot -/// -/// The plot illustrates the Cauchy distribution with various values of `x₀` and `γ`. -/// Note how the median parameter `x₀` shifts the distribution along the x-axis, -/// and how the scale `γ` changes the density around the median. -/// -/// The standard Cauchy distribution is the special case with `x₀ = 0` and `γ = 1`, -/// which corresponds to the ratio of two [`StandardNormal`](crate::StandardNormal) distributions. -/// -/// ![Cauchy distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/cauchy.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Cauchy, Distribution}; -/// -/// let cau = Cauchy::new(2.0, 5.0).unwrap(); -/// let v = cau.sample(&mut rand::rng()); -/// println!("{} is from a Cauchy(2, 5) distribution", v); -/// ``` -/// -/// # Notes -/// -/// Note that at least for `f32`, results are not fully portable due to minor -/// differences in the target system's *tan* implementation, `tanf`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Cauchy -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - median: F, - scale: F, -} - -/// Error type returned from [`Cauchy::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `scale <= 0` or `nan`. - ScaleTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => "scale is not positive in Cauchy distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Cauchy -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - /// Construct a new `Cauchy` with the given shape parameters - /// `median` the peak location and `scale` the scale factor. - pub fn new(median: F, scale: F) -> Result, Error> { - if !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - Ok(Cauchy { median, scale }) - } -} - -impl Distribution for Cauchy -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - // sample from [0, 1) - let x = StandardUniform.sample(rng); - // get standard cauchy random number - // note that π/2 is not exactly representable, even if x=0.5 the result is finite - let comp_dev = (F::PI() * x).tan(); - // shift and scale according to parameters - self.median + self.scale * comp_dev - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn median(numbers: &mut [f64]) -> f64 { - sort(numbers); - let mid = numbers.len() / 2; - numbers[mid] - } - - fn sort(numbers: &mut [f64]) { - numbers.sort_by(|a, b| a.partial_cmp(b).unwrap()); - } - - #[test] - fn test_cauchy_averages() { - // NOTE: given that the variance and mean are undefined, - // this test does not have any rigorous statistical meaning. - let cauchy = Cauchy::new(10.0, 5.0).unwrap(); - let mut rng = crate::test::rng(123); - let mut numbers: [f64; 1000] = [0.0; 1000]; - let mut sum = 0.0; - for number in &mut numbers[..] { - *number = cauchy.sample(&mut rng); - sum += *number; - } - let median = median(&mut numbers); - #[cfg(feature = "std")] - std::println!("Cauchy median: {}", median); - assert!((median - 10.0).abs() < 0.4); // not 100% certain, but probable enough - let mean = sum / 1000.0; - #[cfg(feature = "std")] - std::println!("Cauchy mean: {}", mean); - // for a Cauchy distribution the mean should not converge - assert!((mean - 10.0).abs() > 0.4); // not 100% certain, but probable enough - } - - #[test] - #[should_panic] - fn test_cauchy_invalid_scale_zero() { - Cauchy::new(0.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_cauchy_invalid_scale_neg() { - Cauchy::new(0.0, -10.0).unwrap(); - } - - #[test] - fn value_stability() { - fn gen_samples(m: F, s: F, buf: &mut [F]) - where - StandardUniform: Distribution, - { - let distr = Cauchy::new(m, s).unwrap(); - let mut rng = crate::test::rng(353); - for x in buf { - *x = rng.sample(distr); - } - } - - let mut buf = [0.0; 4]; - gen_samples(100f64, 10.0, &mut buf); - assert_eq!( - &buf, - &[ - 77.93369152808678, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925 - ] - ); - - // Unfortunately this test is not fully portable due to reliance on the - // system's implementation of tanf (see doc on Cauchy struct). - let mut buf = [0.0; 4]; - gen_samples(10f32, 7.0, &mut buf); - let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; - for (a, b) in buf.iter().zip(expected.iter()) { - assert_almost_eq!(*a, *b, 1e-5); - } - } - - #[test] - fn cauchy_distributions_can_be_compared() { - assert_eq!(Cauchy::new(1.0, 2.0), Cauchy::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/chi_squared.rs b/rand_distr/src/chi_squared.rs deleted file mode 100644 index 409a78bb311..00000000000 --- a/rand_distr/src/chi_squared.rs +++ /dev/null @@ -1,179 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Chi-squared distribution. - -use self::ChiSquaredRepr::*; - -use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; -use core::fmt; -use num_traits::Float; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`. -/// -/// The chi-squared distribution is a continuous probability -/// distribution with parameter `k > 0` degrees of freedom. -/// -/// For `k > 0` integral, this distribution is the sum of the squares -/// of `k` independent standard normal random variables. For other -/// `k`, this uses the equivalent characterisation -/// `χ²(k) = Gamma(k/2, 2)`. -/// -/// # Plot -/// -/// The plot shows the chi-squared distribution with various degrees -/// of freedom. -/// -/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{ChiSquared, Distribution}; -/// -/// let chi = ChiSquared::new(11.0).unwrap(); -/// let v = chi.sample(&mut rand::rng()); -/// println!("{} is from a χ²(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: ChiSquaredRepr, -} - -/// Error type returned from [`ChiSquared::new`] and [`StudentT::new`](crate::StudentT::new). -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum Error { - /// `0.5 * k <= 0` or `nan`. - DoFTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::DoFTooSmall => { - "degrees-of-freedom k is not positive in chi-squared distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -enum ChiSquaredRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, - // e.g. when alpha = 1/2 as it would be for this case, so special- - // casing and using the definition of N(0,1)^2 is faster. - DoFExactlyOne, - DoFAnythingElse(Gamma), -} - -impl ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new chi-squared distribution with degrees-of-freedom - /// `k`. - pub fn new(k: F) -> Result, Error> { - let repr = if k == F::one() { - DoFExactlyOne - } else { - if !(F::from(0.5).unwrap() * k > F::zero()) { - return Err(Error::DoFTooSmall); - } - DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) - }; - Ok(ChiSquared { repr }) - } -} -impl Distribution for ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - DoFExactlyOne => { - // k == 1 => N(0,1)^2 - let norm: F = rng.sample(StandardNormal); - norm * norm - } - DoFAnythingElse(ref g) => g.sample(rng), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_chi_squared_one() { - let chi = ChiSquared::new(1.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_small() { - let chi = ChiSquared::new(0.5).unwrap(); - let mut rng = crate::test::rng(202); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_large() { - let chi = ChiSquared::new(30.0).unwrap(); - let mut rng = crate::test::rng(203); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - #[should_panic] - fn test_chi_squared_invalid_dof() { - ChiSquared::new(-1.0).unwrap(); - } - - #[test] - fn gamma_distributions_can_be_compared() { - assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); - } - - #[test] - fn chi_squared_distributions_can_be_compared() { - assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0)); - } -} diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs deleted file mode 100644 index ac17fa2e298..00000000000 --- a/rand_distr/src/dirichlet.rs +++ /dev/null @@ -1,446 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The dirichlet distribution `Dirichlet(α₁, α₂, ..., αₙ)`. - -#![cfg(feature = "alloc")] -use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal}; -use core::fmt; -use num_traits::{Float, NumCast}; -use rand::Rng; -#[cfg(feature = "serde")] -use serde_with::serde_as; - -use alloc::{boxed::Box, vec, vec::Vec}; - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", serde_as)] -struct DirichletFromGamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - samplers: [Gamma; N], -} - -/// Error type returned from [`DirchletFromGamma::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum DirichletFromGammaError { - /// Gamma::new(a, 1) failed. - GammmaNewFailed, - - /// gamma_dists.try_into() failed (in theory, this should not happen). - GammaArrayCreationFailed, -} - -impl DirichletFromGamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a new `DirichletFromGamma` with the given parameters `alpha`. - /// - /// This function is part of a private implementation detail. - /// It assumes that the input is correct, so no validation of alpha is done. - #[inline] - fn new(alpha: [F; N]) -> Result, DirichletFromGammaError> { - let mut gamma_dists = Vec::new(); - for a in alpha { - let dist = - Gamma::new(a, F::one()).map_err(|_| DirichletFromGammaError::GammmaNewFailed)?; - gamma_dists.push(dist); - } - Ok(DirichletFromGamma { - samplers: gamma_dists - .try_into() - .map_err(|_| DirichletFromGammaError::GammaArrayCreationFailed)?, - }) - } -} - -impl Distribution<[F; N]> for DirichletFromGamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; - let mut sum = F::zero(); - - for (s, g) in samples.iter_mut().zip(self.samplers.iter()) { - *s = g.sample(rng); - sum = sum + *s; - } - let invacc = F::one() / sum; - for s in samples.iter_mut() { - *s = *s * invacc; - } - samples - } -} - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -struct DirichletFromBeta -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - samplers: Box<[Beta]>, -} - -/// Error type returned from [`DirchletFromBeta::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum DirichletFromBetaError { - /// Beta::new(a, b) failed. - BetaNewFailed, -} - -impl DirichletFromBeta -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a new `DirichletFromBeta` with the given parameters `alpha`. - /// - /// This function is part of a private implementation detail. - /// It assumes that the input is correct, so no validation of alpha is done. - #[inline] - fn new(alpha: [F; N]) -> Result, DirichletFromBetaError> { - // `alpha_rev_csum` is the reverse of the cumulative sum of the - // reverse of `alpha[1..]`. E.g. if `alpha = [a0, a1, a2, a3]`, then - // `alpha_rev_csum` is `[a1 + a2 + a3, a2 + a3, a3]`. - // Note that instances of DirichletFromBeta will always have N >= 2, - // so the subtractions of 1, 2 and 3 from N in the following are safe. - let mut alpha_rev_csum = vec![alpha[N - 1]; N - 1]; - for k in 0..(N - 2) { - alpha_rev_csum[N - 3 - k] = alpha_rev_csum[N - 2 - k] + alpha[N - 2 - k]; - } - - // Zip `alpha[..(N-1)]` and `alpha_rev_csum`; for the example - // `alpha = [a0, a1, a2, a3]`, the zip result holds the tuples - // `[(a0, a1+a2+a3), (a1, a2+a3), (a2, a3)]`. - // Then pass each tuple to `Beta::new()` to create the `Beta` - // instances. - let mut beta_dists = Vec::new(); - for (&a, &b) in alpha[..(N - 1)].iter().zip(alpha_rev_csum.iter()) { - let dist = Beta::new(a, b).map_err(|_| DirichletFromBetaError::BetaNewFailed)?; - beta_dists.push(dist); - } - Ok(DirichletFromBeta { - samplers: beta_dists.into_boxed_slice(), - }) - } -} - -impl Distribution<[F; N]> for DirichletFromBeta -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - let mut samples = [F::zero(); N]; - let mut acc = F::one(); - - for (s, beta) in samples.iter_mut().zip(self.samplers.iter()) { - let beta_sample = beta.sample(rng); - *s = acc * beta_sample; - acc = acc * (F::one() - beta_sample); - } - samples[N - 1] = acc; - samples - } -} - -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", serde_as)] -enum DirichletRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Dirichlet distribution that generates samples using the Gamma distribution. - FromGamma(DirichletFromGamma), - - /// Dirichlet distribution that generates samples using the Beta distribution. - FromBeta(DirichletFromBeta), -} - -/// The [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) `Dirichlet(α₁, α₂, ..., αₖ)`. -/// -/// The Dirichlet distribution is a family of continuous multivariate -/// probability distributions parameterized by a vector of positive -/// real numbers `α₁, α₂, ..., αₖ`, where `k` is the number of dimensions -/// of the distribution. The distribution is supported on the `k-1`-dimensional -/// simplex, which is the set of points `x = [x₁, x₂, ..., xₖ]` such that -/// `0 ≤ xᵢ ≤ 1` and `∑ xᵢ = 1`. -/// It is a multivariate generalization of the [`Beta`](crate::Beta) distribution. -/// The distribution is symmetric when all `αᵢ` are equal. -/// -/// # Plot -/// -/// The following plot illustrates the 2-dimensional simplices for various -/// 3-dimensional Dirichlet distributions. -/// -/// ![Dirichlet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/dirichlet.png) -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Dirichlet; -/// -/// let dirichlet = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); -/// let samples = dirichlet.sample(&mut rand::rng()); -/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); -/// ``` -#[cfg_attr(feature = "serde", serde_as)] -#[derive(Clone, Debug, PartialEq)] -pub struct Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: DirichletRepr, -} - -/// Error type returned from [`Dirichlet::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `alpha.len() < 2`. - AlphaTooShort, - /// `alpha <= 0.0` or `nan`. - AlphaTooSmall, - /// `alpha` is subnormal. - /// Variate generation methods are not reliable with subnormal inputs. - AlphaSubnormal, - /// `alpha` is infinite. - AlphaInfinite, - /// Failed to create required Gamma distribution(s). - FailedToCreateGamma, - /// Failed to create required Beta distribition(s). - FailedToCreateBeta, - /// `size < 2`. - SizeTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::AlphaTooShort | Error::SizeTooSmall => { - "less than 2 dimensions in Dirichlet distribution" - } - Error::AlphaTooSmall => "alpha is not positive in Dirichlet distribution", - Error::AlphaSubnormal => "alpha contains a subnormal value in Dirichlet distribution", - Error::AlphaInfinite => "alpha contains an infinite value in Dirichlet distribution", - Error::FailedToCreateGamma => { - "failed to create required Gamma distribution for Dirichlet distribution" - } - Error::FailedToCreateBeta => { - "failed to create required Beta distribition for Dirichlet distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. - /// - /// Requires `alpha.len() >= 2`, and each value in `alpha` must be positive, - /// finite and not subnormal. - #[inline] - pub fn new(alpha: [F; N]) -> Result, Error> { - if N < 2 { - return Err(Error::AlphaTooShort); - } - for &ai in alpha.iter() { - if !(ai > F::zero()) { - // This also catches nan. - return Err(Error::AlphaTooSmall); - } - if ai.is_infinite() { - return Err(Error::AlphaInfinite); - } - if !ai.is_normal() { - return Err(Error::AlphaSubnormal); - } - } - - if alpha.iter().all(|&x| x <= NumCast::from(0.1).unwrap()) { - // Use the Beta method when all the alphas are less than 0.1 This - // threshold provides a reasonable compromise between using the faster - // Gamma method for as wide a range as possible while ensuring that - // the probability of generating nans is negligibly small. - let dist = DirichletFromBeta::new(alpha).map_err(|_| Error::FailedToCreateBeta)?; - Ok(Dirichlet { - repr: DirichletRepr::FromBeta(dist), - }) - } else { - let dist = DirichletFromGamma::new(alpha).map_err(|_| Error::FailedToCreateGamma)?; - Ok(Dirichlet { - repr: DirichletRepr::FromGamma(dist), - }) - } - } -} - -impl Distribution<[F; N]> for Dirichlet -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> [F; N] { - match &self.repr { - DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), - DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_dirichlet() { - let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap(); - let mut rng = crate::test::rng(221); - let samples = d.sample(&mut rng); - assert!(samples.into_iter().all(|x: f64| x > 0.0)); - } - - #[test] - #[should_panic] - fn test_dirichlet_invalid_length() { - Dirichlet::new([0.5]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_zero() { - Dirichlet::new([0.1, 0.0, 0.3]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_negative() { - Dirichlet::new([0.1, -1.5, 0.3]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_nan() { - Dirichlet::new([0.5, f64::NAN, 0.25]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_subnormal() { - Dirichlet::new([0.5, 1.5e-321, 0.25]).unwrap(); - } - - #[test] - #[should_panic] - fn test_dirichlet_alpha_inf() { - Dirichlet::new([0.5, f64::INFINITY, 0.25]).unwrap(); - } - - #[test] - fn dirichlet_distributions_can_be_compared() { - assert_eq!(Dirichlet::new([1.0, 2.0]), Dirichlet::new([1.0, 2.0])); - } - - /// Check that the means of the components of n samples from - /// the Dirichlet distribution agree with the expected means - /// with a relative tolerance of rtol. - /// - /// This is a crude statistical test, but it will catch egregious - /// mistakes. It will also also fail if any samples contain nan. - fn check_dirichlet_means(alpha: [f64; N], n: i32, rtol: f64, seed: u64) { - let d = Dirichlet::new(alpha).unwrap(); - let mut rng = crate::test::rng(seed); - let mut sums = [0.0; N]; - for _ in 0..n { - let samples = d.sample(&mut rng); - for i in 0..N { - sums[i] += samples[i]; - } - } - let sample_mean = sums.map(|x| x / n as f64); - let alpha_sum: f64 = alpha.iter().sum(); - let expected_mean = alpha.map(|x| x / alpha_sum); - for i in 0..N { - assert_almost_eq!(sample_mean[i], expected_mean[i], rtol); - } - } - - #[test] - fn test_dirichlet_means() { - // Check the means of 20000 samples for several different alphas. - let n = 20000; - let rtol = 2e-2; - let seed = 1317624576693539401; - check_dirichlet_means([0.5, 0.25], n, rtol, seed); - check_dirichlet_means([123.0, 75.0], n, rtol, seed); - check_dirichlet_means([2.0, 2.5, 5.0, 7.0], n, rtol, seed); - check_dirichlet_means([0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5], n, rtol, seed); - } - - #[test] - fn test_dirichlet_means_very_small_alpha() { - // With values of alpha that are all 0.001, check that the means of the - // components of 10000 samples are within 1% of the expected means. - // With the sampling method based on gamma variates, this test would - // fail, with about 10% of the samples containing nan. - let alpha = [0.001; 3]; - let n = 10000; - let rtol = 1e-2; - let seed = 1317624576693539401; - check_dirichlet_means(alpha, n, rtol, seed); - } - - #[test] - fn test_dirichlet_means_small_alpha() { - // With values of alpha that are all less than 0.1, check that the - // means of the components of 150000 samples are within 0.1% of the - // expected means. - let alpha = [0.05, 0.025, 0.075, 0.05]; - let n = 150000; - let rtol = 1e-3; - let seed = 1317624576693539401; - check_dirichlet_means(alpha, n, rtol, seed); - } -} diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs deleted file mode 100644 index 6d61015a8c1..00000000000 --- a/rand_distr/src/exponential.rs +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The exponential distribution `Exp(λ)`. - -use crate::utils::ziggurat; -use crate::{ziggurat_tables, Distribution}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The standard exponential distribution `Exp(1)`. -/// -/// This is equivalent to `Exp::new(1.0)` or sampling with -/// `-rng.gen::().ln()`, but faster. -/// -/// See [`Exp`](crate::Exp) for the general exponential distribution. -/// -/// # Plot -/// -/// The following plot illustrates the exponential distribution with `λ = 1`. -/// -/// ![Exponential distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/exponential_exp1.svg) -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Exp1; -/// -/// let val: f64 = rand::rng().sample(Exp1); -/// println!("{}", val); -/// ``` -/// -/// # Notes -/// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. The exact -/// description in the paper was adjusted to use tables for the exponential -/// distribution rather than normal. -/// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Exp1; - -impl Distribution for Exp1 { - #[inline] - fn sample(&self, rng: &mut R) -> f32 { - // TODO: use optimal 32-bit implementation - let x: f64 = self.sample(rng); - x as f32 - } -} - -// This could be done via `-rng.gen::().ln()` but that is slower. -impl Distribution for Exp1 { - #[inline] - fn sample(&self, rng: &mut R) -> f64 { - #[inline] - fn pdf(x: f64) -> f64 { - (-x).exp() - } - #[inline] - fn zero_case(rng: &mut R, _u: f64) -> f64 { - ziggurat_tables::ZIG_EXP_R - rng.random::().ln() - } - - ziggurat( - rng, - false, - &ziggurat_tables::ZIG_EXP_X, - &ziggurat_tables::ZIG_EXP_F, - pdf, - zero_case, - ) - } -} - -/// The [exponential distribution](https://en.wikipedia.org/wiki/Exponential_distribution) `Exp(λ)`. -/// -/// The exponential distribution is a continuous probability distribution -/// with rate parameter `λ` (`lambda`). It describes the time between events -/// in a [`Poisson`](crate::Poisson) process, i.e. a process in which -/// events occur continuously and independently at a constant average rate. -/// -/// See [`Exp1`](crate::Exp1) for an optimised implementation for `λ = 1`. -/// -/// # Density function -/// -/// `f(x) = λ * exp(-λ * x)` for `x > 0`, when `λ > 0`. -/// -/// For `λ = 0`, all samples yield infinity (because a Poisson process -/// with rate 0 has no events). -/// -/// # Plot -/// -/// The following plot illustrates the exponential distribution with -/// various values of `λ`. -/// The `λ` parameter controls the rate of decay as `x` approaches infinity, -/// and the mean of the distribution is `1/λ`. -/// -/// ![Exponential distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/exponential.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Exp, Distribution}; -/// -/// let exp = Exp::new(2.0).unwrap(); -/// let v = exp.sample(&mut rand::rng()); -/// println!("{} is from a Exp(2) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Exp -where - F: Float, - Exp1: Distribution, -{ - /// `lambda` stored as `1/lambda`, since this is what we scale by. - lambda_inverse: F, -} - -/// Error type returned from [`Exp::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `lambda < 0` or `nan`. - LambdaTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::LambdaTooSmall => "lambda is negative or NaN in exponential distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Exp -where - F: Float, - Exp1: Distribution, -{ - /// Construct a new `Exp` with the given shape parameter - /// `lambda`. - /// - /// # Remarks - /// - /// For custom types `N` implementing the [`Float`] trait, - /// the case `lambda = 0` is handled as follows: each sample corresponds - /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types - /// yield infinity, since `1 / 0 = infinity`. - #[inline] - pub fn new(lambda: F) -> Result, Error> { - if !(lambda >= F::zero()) { - return Err(Error::LambdaTooSmall); - } - Ok(Exp { - lambda_inverse: F::one() / lambda, - }) - } -} - -impl Distribution for Exp -where - F: Float, - Exp1: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - rng.sample(Exp1) * self.lambda_inverse - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_exp() { - let exp = Exp::new(10.0).unwrap(); - let mut rng = crate::test::rng(221); - for _ in 0..1000 { - assert!(exp.sample(&mut rng) >= 0.0); - } - } - #[test] - fn test_zero() { - let d = Exp::new(0.0).unwrap(); - assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity()); - } - #[test] - #[should_panic] - fn test_exp_invalid_lambda_neg() { - Exp::new(-10.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_exp_invalid_lambda_nan() { - Exp::new(f64::nan()).unwrap(); - } - - #[test] - fn exponential_distributions_can_be_compared() { - assert_eq!(Exp::new(1.0), Exp::new(1.0)); - } -} diff --git a/rand_distr/src/fisher_f.rs b/rand_distr/src/fisher_f.rs deleted file mode 100644 index 9c2c13cf64f..00000000000 --- a/rand_distr/src/fisher_f.rs +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Fisher F-distribution. - -use crate::{ChiSquared, Distribution, Exp1, Open01, StandardNormal}; -use core::fmt; -use num_traits::Float; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`. -/// -/// This distribution is equivalent to the ratio of two normalised -/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / -/// (χ²(n)/n)`. -/// -/// # Plot -/// -/// The plot shows the F-distribution with various values of `m` and `n`. -/// -/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{FisherF, Distribution}; -/// -/// let f = FisherF::new(2.0, 32.0).unwrap(); -/// let v = f.sample(&mut rand::rng()); -/// println!("{} is from an F(2, 32) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - numer: ChiSquared, - denom: ChiSquared, - // denom_dof / numer_dof so that this can just be a straight - // multiplication, rather than a division. - dof_ratio: F, -} - -/// Error type returned from [`FisherF::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum Error { - /// `m <= 0` or `nan`. - MTooSmall, - /// `n <= 0` or `nan`. - NTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::MTooSmall => "m is not positive in Fisher F distribution", - Error::NTooSmall => "n is not positive in Fisher F distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new `FisherF` distribution, with the given parameter. - pub fn new(m: F, n: F) -> Result, Error> { - let zero = F::zero(); - if !(m > zero) { - return Err(Error::MTooSmall); - } - if !(n > zero) { - return Err(Error::NTooSmall); - } - - Ok(FisherF { - numer: ChiSquared::new(m).unwrap(), - denom: ChiSquared::new(n).unwrap(), - dof_ratio: n / m, - }) - } -} -impl Distribution for FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_f() { - let f = FisherF::new(2.0, 32.0).unwrap(); - let mut rng = crate::test::rng(204); - for _ in 0..1000 { - f.sample(&mut rng); - } - } - - #[test] - fn fisher_f_distributions_can_be_compared() { - assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/frechet.rs b/rand_distr/src/frechet.rs deleted file mode 100644 index feecd603fb5..00000000000 --- a/rand_distr/src/frechet.rs +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Fréchet distribution `Fréchet(μ, σ, α)`. - -use crate::{Distribution, OpenClosed01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [Fréchet distribution](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distribution) `Fréchet(α, μ, σ)`. -/// -/// The Fréchet distribution is a continuous probability distribution -/// with location parameter `μ` (`mu`), scale parameter `σ` (`sigma`), -/// and shape parameter `α` (`alpha`). It describes the distribution -/// of the maximum (or minimum) of a number of random variables. -/// It is also known as the Type II extreme value distribution. -/// -/// # Density function -/// -/// `f(x) = [(x - μ) / σ]^(-1 - α) exp[-(x - μ) / σ]^(-α) α / σ` -/// -/// # Plot -/// -/// The plot shows the Fréchet distribution with various values of `μ`, `σ`, and `α`. -/// Note how the location parameter `μ` shifts the distribution along the x-axis, -/// the scale parameter `σ` stretches or compresses the distribution along the x-axis, -/// and the shape parameter `α` changes the tail behavior. -/// -/// ![Fréchet distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/frechet.svg) -/// -/// # Example -/// -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Frechet; -/// -/// let val: f64 = rand::rng().sample(Frechet::new(0.0, 1.0, 1.0).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Frechet -where - F: Float, - OpenClosed01: Distribution, -{ - location: F, - scale: F, - shape: F, -} - -/// Error type returned from [`Frechet::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// location is infinite or NaN - LocationNotFinite, - /// scale is not finite positive number - ScaleNotPositive, - /// shape is not finite positive number - ShapeNotPositive, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::LocationNotFinite => "location is not finite in Frechet distribution", - Error::ScaleNotPositive => "scale is not positive and finite in Frechet distribution", - Error::ShapeNotPositive => "shape is not positive and finite in Frechet distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Frechet -where - F: Float, - OpenClosed01: Distribution, -{ - /// Construct a new `Frechet` distribution with given `location`, `scale`, and `shape`. - pub fn new(location: F, scale: F, shape: F) -> Result, Error> { - if scale <= F::zero() || scale.is_infinite() || scale.is_nan() { - return Err(Error::ScaleNotPositive); - } - if shape <= F::zero() || shape.is_infinite() || shape.is_nan() { - return Err(Error::ShapeNotPositive); - } - if location.is_infinite() || location.is_nan() { - return Err(Error::LocationNotFinite); - } - Ok(Frechet { - location, - scale, - shape, - }) - } -} - -impl Distribution for Frechet -where - F: Float, - OpenClosed01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let x: F = rng.sample(OpenClosed01); - self.location + self.scale * (-x.ln()).powf(-self.shape.recip()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic] - fn test_zero_scale() { - Frechet::new(0.0, 0.0, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_scale() { - Frechet::new(0.0, f64::INFINITY, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_scale() { - Frechet::new(0.0, f64::NAN, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_zero_shape() { - Frechet::new(0.0, 1.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_shape() { - Frechet::new(0.0, 1.0, f64::INFINITY).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_shape() { - Frechet::new(0.0, 1.0, f64::NAN).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_location() { - Frechet::new(f64::INFINITY, 1.0, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_location() { - Frechet::new(f64::NAN, 1.0, 1.0).unwrap(); - } - - #[test] - fn test_sample_against_cdf() { - fn quantile_function(x: f64) -> f64 { - (-x.ln()).recip() - } - let location = 0.0; - let scale = 1.0; - let shape = 1.0; - let iterations = 100_000; - let increment = 1.0 / iterations as f64; - let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; - let mut quantiles = [0.0; 9]; - for (i, p) in probabilities.iter().enumerate() { - quantiles[i] = quantile_function(*p); - } - let mut proportions = [0.0; 9]; - let d = Frechet::new(location, scale, shape).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..iterations { - let replicate = d.sample(&mut rng); - for (i, q) in quantiles.iter().enumerate() { - if replicate < *q { - proportions[i] += increment; - } - } - } - assert!(proportions - .iter() - .zip(&probabilities) - .all(|(p_hat, p)| (p_hat - p).abs() < 0.003)) - } - - #[test] - fn frechet_distributions_can_be_compared() { - assert_eq!(Frechet::new(1.0, 2.0, 3.0), Frechet::new(1.0, 2.0, 3.0)); - } -} diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs deleted file mode 100644 index 0fc6b756df3..00000000000 --- a/rand_distr/src/gamma.rs +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Gamma distribution. - -use self::GammaRepr::*; - -use crate::{Distribution, Exp, Exp1, Open01, StandardNormal}; -use core::fmt; -use num_traits::Float; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// The [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution) `Gamma(k, θ)`. -/// -/// The Gamma distribution is a continuous probability distribution -/// with shape parameter `k > 0` (number of events) and -/// scale parameter `θ > 0` (mean waiting time between events). -/// It describes the time until `k` events occur in a Poisson -/// process with rate `1/θ`. It is the generalization of the -/// [`Exponential`](crate::Exp) distribution. -/// -/// # Density function -/// -/// `f(x) = x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k)` for `x > 0`, -/// where `Γ` is the [gamma function](https://en.wikipedia.org/wiki/Gamma_function). -/// -/// # Plot -/// -/// The following plot illustrates the Gamma distribution with -/// various values of `k` and `θ`. -/// Curves with `θ = 1` are more saturated, while corresponding -/// curves with `θ = 2` have a lighter color. -/// -/// ![Gamma distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gamma.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Gamma}; -/// -/// let gamma = Gamma::new(2.0, 5.0).unwrap(); -/// let v = gamma.sample(&mut rand::rng()); -/// println!("{} is from a Gamma(2, 5) distribution", v); -/// ``` -/// -/// # Notes -/// -/// The algorithm used is that described by Marsaglia & Tsang 2000[^1], -/// falling back to directly sampling from an Exponential for `shape -/// == 1`, and using the boosting technique described in that paper for -/// `shape < 1`. -/// -/// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for -/// Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3 -/// (September 2000), 363-372. -/// DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414) -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct Gamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: GammaRepr, -} - -/// Error type returned from [`Gamma::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `shape <= 0` or `nan`. - ShapeTooSmall, - /// `scale <= 0` or `nan`. - ScaleTooSmall, - /// `1 / scale == 0`. - ScaleTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ShapeTooSmall => "shape is not positive in gamma distribution", - Error::ScaleTooSmall => "scale is not positive in gamma distribution", - Error::ScaleTooLarge => "scale is infinity in gamma distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -enum GammaRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - Large(GammaLargeShape), - One(Exp), - Small(GammaSmallShape), -} - -// These two helpers could be made public, but saving the -// match-on-Gamma-enum branch from using them directly (e.g. if one -// knows that the shape is always > 1) doesn't appear to be much -// faster. - -/// Gamma distribution where the shape parameter is less than 1. -/// -/// Note, samples from this require a compulsory floating-point `pow` -/// call, which makes it significantly slower than sampling from a -/// gamma distribution where the shape parameter is greater than or -/// equal to 1. -/// -/// See `Gamma` for sampling from a Gamma distribution with general -/// shape parameters. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -struct GammaSmallShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - inv_shape: F, - large_shape: GammaLargeShape, -} - -/// Gamma distribution where the shape parameter is larger than 1. -/// -/// See `Gamma` for sampling from a Gamma distribution with general -/// shape parameters. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -struct GammaLargeShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - scale: F, - c: F, - d: F, -} - -impl Gamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct an object representing the `Gamma(shape, scale)` - /// distribution. - #[inline] - pub fn new(shape: F, scale: F) -> Result, Error> { - if !(shape > F::zero()) { - return Err(Error::ShapeTooSmall); - } - if !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - - let repr = if shape == F::one() { - One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?) - } else if shape < F::one() { - Small(GammaSmallShape::new_raw(shape, scale)) - } else { - Large(GammaLargeShape::new_raw(shape, scale)) - }; - Ok(Gamma { repr }) - } -} - -impl GammaSmallShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn new_raw(shape: F, scale: F) -> GammaSmallShape { - GammaSmallShape { - inv_shape: F::one() / shape, - large_shape: GammaLargeShape::new_raw(shape + F::one(), scale), - } - } -} - -impl GammaLargeShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn new_raw(shape: F, scale: F) -> GammaLargeShape { - let d = shape - F::from(1. / 3.).unwrap(); - GammaLargeShape { - scale, - c: F::one() / (F::from(9.).unwrap() * d).sqrt(), - d, - } - } -} - -impl Distribution for Gamma -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - Small(ref g) => g.sample(rng), - One(ref g) => g.sample(rng), - Large(ref g) => g.sample(rng), - } - } -} -impl Distribution for GammaSmallShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let u: F = rng.sample(Open01); - - self.large_shape.sample(rng) * u.powf(self.inv_shape) - } -} -impl Distribution for GammaLargeShape -where - F: Float, - StandardNormal: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - // Marsaglia & Tsang method, 2000 - loop { - let x: F = rng.sample(StandardNormal); - let v_cbrt = F::one() + self.c * x; - if v_cbrt <= F::zero() { - // a^3 <= 0 iff a <= 0 - continue; - } - - let v = v_cbrt * v_cbrt * v_cbrt; - let u: F = rng.sample(Open01); - - let x_sqr = x * x; - if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr - || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln()) - { - return self.d * v * self.scale; - } - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn gamma_distributions_can_be_compared() { - assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs deleted file mode 100644 index 74d30a4459a..00000000000 --- a/rand_distr/src/geometric.rs +++ /dev/null @@ -1,267 +0,0 @@ -//! The geometric distribution `Geometric(p)`. - -use crate::Distribution; -use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; -use rand::Rng; - -/// The [geometric distribution](https://en.wikipedia.org/wiki/Geometric_distribution) `Geometric(p)`. -/// -/// This is the probability distribution of the number of failures -/// (bounded to `[0, u64::MAX]`) before the first success in a -/// series of [`Bernoulli`](crate::Bernoulli) trials, where the -/// probability of success on each trial is `p`. -/// -/// This is the discrete analogue of the [exponential distribution](crate::Exp). -/// -/// See [`StandardGeometric`](crate::StandardGeometric) for an optimised -/// implementation for `p = 0.5`. -/// -/// # Density function -/// -/// `f(k) = (1 - p)^k p` for `k >= 0`. -/// -/// # Plot -/// -/// The following plot illustrates the geometric distribution for various -/// values of `p`. Note how higher `p` values shift the distribution to -/// the left, and the mean of the distribution is `1/p`. -/// -/// ![Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/geometric.svg) -/// -/// # Example -/// ``` -/// use rand_distr::{Geometric, Distribution}; -/// -/// let geo = Geometric::new(0.25).unwrap(); -/// let v = geo.sample(&mut rand::rng()); -/// println!("{} is from a Geometric(0.25) distribution", v); -/// ``` -#[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Geometric { - p: f64, - pi: f64, - k: u64, -} - -/// Error type returned from [`Geometric::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `p < 0 || p > 1` or `nan` - InvalidProbability, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::InvalidProbability => { - "p is NaN or outside the interval [0, 1] in geometric distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Geometric { - /// Construct a new `Geometric` with the given shape parameter `p` - /// (probability of success on each trial). - pub fn new(p: f64) -> Result { - if !p.is_finite() || !(0.0..=1.0).contains(&p) { - Err(Error::InvalidProbability) - } else if p == 0.0 || p >= 2.0 / 3.0 { - Ok(Geometric { p, pi: p, k: 0 }) - } else { - let (pi, k) = { - // choose smallest k such that pi = (1 - p)^(2^k) <= 0.5 - let mut k = 1; - let mut pi = (1.0 - p).powi(2); - while pi > 0.5 { - k += 1; - pi = pi * pi; - } - (pi, k) - }; - - Ok(Geometric { p, pi, k }) - } - } -} - -impl Distribution for Geometric { - fn sample(&self, rng: &mut R) -> u64 { - if self.p >= 2.0 / 3.0 { - // use the trivial algorithm: - let mut failures = 0; - loop { - let u = rng.random::(); - if u <= self.p { - break; - } - failures += 1; - } - return failures; - } - - if self.p == 0.0 { - return u64::MAX; - } - - let Geometric { p, pi, k } = *self; - - // Based on the algorithm presented in section 3 of - // Karl Bringmann and Tobias Friedrich (July 2013) - Exact and Efficient - // Generation of Geometric Random Variates and Random Graphs, published - // in International Colloquium on Automata, Languages and Programming - // (pp.267-278) - // https://people.mpi-inf.mpg.de/~kbringma/paper/2013ICALP-1.pdf - - // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k: - let d = { - let mut failures = 0; - while rng.random::() < pi { - failures += 1; - } - failures - }; - - // Use rejection sampling for the remainder M from Geo(p) % 2^k: - // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M - // NOTE: The paper suggests using bitwise sampling here, which is - // currently unsupported, but should improve performance by requiring - // fewer iterations on average. ~ October 28, 2020 - let m = loop { - let m = rng.random::() & ((1 << k) - 1); - let p_reject = if m <= i32::MAX as u64 { - (1.0 - p).powi(m as i32) - } else { - (1.0 - p).powf(m as f64) - }; - - let u = rng.random::(); - if u < p_reject { - break m; - } - }; - - (d << k) + m - } -} - -/// The standard geometric distribution `Geometric(0.5)`. -/// -/// This is equivalent to `Geometric::new(0.5)`, but faster. -/// -/// See [`Geometric`](crate::Geometric) for the general geometric distribution. -/// -/// # Plot -/// -/// The following plot illustrates the standard geometric distribution. -/// -/// ![Standard Geometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_geometric.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::StandardGeometric; -/// -/// let v = StandardGeometric.sample(&mut rand::rng()); -/// println!("{} is from a Geometric(0.5) distribution", v); -/// ``` -/// -/// # Notes -/// Implemented via iterated -/// [`Rng::gen::().leading_zeros()`](Rng::gen::().leading_zeros()). -#[derive(Copy, Clone, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct StandardGeometric; - -impl Distribution for StandardGeometric { - fn sample(&self, rng: &mut R) -> u64 { - let mut result = 0; - loop { - let x = rng.random::().leading_zeros() as u64; - result += x; - if x < 64 { - break; - } - } - result - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_geo_invalid_p() { - assert!(Geometric::new(f64::NAN).is_err()); - assert!(Geometric::new(f64::INFINITY).is_err()); - assert!(Geometric::new(f64::NEG_INFINITY).is_err()); - - assert!(Geometric::new(-0.5).is_err()); - assert!(Geometric::new(0.0).is_ok()); - assert!(Geometric::new(1.0).is_ok()); - assert!(Geometric::new(2.0).is_err()); - } - - fn test_geo_mean_and_variance(p: f64, rng: &mut R) { - let distr = Geometric::new(p).unwrap(); - - let expected_mean = (1.0 - p) / p; - let expected_variance = (1.0 - p) / (p * p); - - let mut results = [0.0; 10000]; - for i in results.iter_mut() { - *i = distr.sample(rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 40.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn test_geometric() { - let mut rng = crate::test::rng(12345); - - test_geo_mean_and_variance(0.10, &mut rng); - test_geo_mean_and_variance(0.25, &mut rng); - test_geo_mean_and_variance(0.50, &mut rng); - test_geo_mean_and_variance(0.75, &mut rng); - test_geo_mean_and_variance(0.90, &mut rng); - } - - #[test] - fn test_standard_geometric() { - let mut rng = crate::test::rng(654321); - - let distr = StandardGeometric; - let expected_mean = 1.0; - let expected_variance = 2.0; - - let mut results = [0.0; 1000]; - for i in results.iter_mut() { - *i = distr.sample(&mut rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 50.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn geometric_distributions_can_be_compared() { - assert_eq!(Geometric::new(1.0), Geometric::new(1.0)); - } -} diff --git a/rand_distr/src/gumbel.rs b/rand_distr/src/gumbel.rs deleted file mode 100644 index f420a52df84..00000000000 --- a/rand_distr/src/gumbel.rs +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Gumbel distribution `Gumbel(μ, β)`. - -use crate::{Distribution, OpenClosed01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution) `Gumbel(μ, β)`. -/// -/// The Gumbel distribution is a continuous probability distribution -/// with location parameter `μ` (`mu`) and scale parameter `β` (`beta`). -/// It is used to model the distribution of the maximum (or minimum) -/// of a number of samples of various distributions. -/// -/// # Density function -/// -/// `f(x) = exp(-(z + exp(-z))) / β`, where `z = (x - μ) / β`. -/// -/// # Plot -/// -/// The following plot illustrates the Gumbel distribution with various values of `μ` and `β`. -/// Note how the location parameter `μ` shifts the distribution along the x-axis, -/// and the scale parameter `β` changes the density around `μ`. -/// Note also the asymptotic behavior of the distribution towards the right. -/// -/// ![Gumbel distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gumbel.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Gumbel; -/// -/// let val: f64 = rand::rng().sample(Gumbel::new(0.0, 1.0).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Gumbel -where - F: Float, - OpenClosed01: Distribution, -{ - location: F, - scale: F, -} - -/// Error type returned from [`Gumbel::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// location is infinite or NaN - LocationNotFinite, - /// scale is not finite positive number - ScaleNotPositive, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleNotPositive => "scale is not positive and finite in Gumbel distribution", - Error::LocationNotFinite => "location is not finite in Gumbel distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Gumbel -where - F: Float, - OpenClosed01: Distribution, -{ - /// Construct a new `Gumbel` distribution with given `location` and `scale`. - pub fn new(location: F, scale: F) -> Result, Error> { - if scale <= F::zero() || scale.is_infinite() || scale.is_nan() { - return Err(Error::ScaleNotPositive); - } - if location.is_infinite() || location.is_nan() { - return Err(Error::LocationNotFinite); - } - Ok(Gumbel { location, scale }) - } -} - -impl Distribution for Gumbel -where - F: Float, - OpenClosed01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let x: F = rng.sample(OpenClosed01); - self.location - self.scale * (-x.ln()).ln() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic] - fn test_zero_scale() { - Gumbel::new(0.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_scale() { - Gumbel::new(0.0, f64::INFINITY).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_scale() { - Gumbel::new(0.0, f64::NAN).unwrap(); - } - - #[test] - #[should_panic] - fn test_infinite_location() { - Gumbel::new(f64::INFINITY, 1.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_nan_location() { - Gumbel::new(f64::NAN, 1.0).unwrap(); - } - - #[test] - fn test_sample_against_cdf() { - fn neg_log_log(x: f64) -> f64 { - -(-x.ln()).ln() - } - let location = 0.0; - let scale = 1.0; - let iterations = 100_000; - let increment = 1.0 / iterations as f64; - let probabilities = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]; - let mut quantiles = [0.0; 9]; - for (i, p) in probabilities.iter().enumerate() { - quantiles[i] = neg_log_log(*p); - } - let mut proportions = [0.0; 9]; - let d = Gumbel::new(location, scale).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..iterations { - let replicate = d.sample(&mut rng); - for (i, q) in quantiles.iter().enumerate() { - if replicate < *q { - proportions[i] += increment; - } - } - } - assert!(proportions - .iter() - .zip(&probabilities) - .all(|(p_hat, p)| (p_hat - p).abs() < 0.003)) - } - - #[test] - fn gumbel_distributions_can_be_compared() { - assert_eq!(Gumbel::new(1.0, 2.0), Gumbel::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs deleted file mode 100644 index f446357530b..00000000000 --- a/rand_distr/src/hypergeometric.rs +++ /dev/null @@ -1,514 +0,0 @@ -//! The hypergeometric distribution `Hypergeometric(N, K, n)`. - -use crate::Distribution; -use core::fmt; -#[allow(unused_imports)] -use num_traits::Float; -use rand::distr::uniform::Uniform; -use rand::Rng; - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -enum SamplingMethod { - InverseTransform { - initial_p: f64, - initial_x: i64, - }, - RejectionAcceptance { - m: f64, - a: f64, - lambda_l: f64, - lambda_r: f64, - x_l: f64, - x_r: f64, - p1: f64, - p2: f64, - p3: f64, - }, -} - -/// The [hypergeometric distribution](https://en.wikipedia.org/wiki/Hypergeometric_distribution) `Hypergeometric(N, K, n)`. -/// -/// This is the distribution of successes in samples of size `n` drawn without -/// replacement from a population of size `N` containing `K` success states. -/// -/// See the [binomial distribution](crate::Binomial) for the analogous distribution -/// for sampling with replacement. It is a good approximation when the population -/// size is much larger than the sample size. -/// -/// # Density function -/// -/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, -/// where `binomial(a, b) = a! / (b! * (a - b)!)`. -/// -/// # Plot -/// -/// The following plot of the hypergeometric distribution illustrates the probability of drawing -/// `k` successes in `n = 10` draws from a population of `N = 50` items, of which either `K = 12` -/// or `K = 35` are successes. -/// -/// ![Hypergeometric distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/hypergeometric.svg) -/// -/// # Example -/// ``` -/// use rand_distr::{Distribution, Hypergeometric}; -/// -/// let hypergeo = Hypergeometric::new(60, 24, 7).unwrap(); -/// let v = hypergeo.sample(&mut rand::rng()); -/// println!("{} is from a hypergeometric distribution", v); -/// ``` -#[derive(Copy, Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Hypergeometric { - n1: u64, - n2: u64, - k: u64, - offset_x: i64, - sign_x: i64, - sampling_method: SamplingMethod, -} - -/// Error type returned from [`Hypergeometric::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `total_population_size` is too large, causing floating point underflow. - PopulationTooLarge, - /// `population_with_feature > total_population_size`. - ProbabilityTooLarge, - /// `sample_size > total_population_size`. - SampleSizeTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::PopulationTooLarge => { - "total_population_size is too large causing underflow in geometric distribution" - } - Error::ProbabilityTooLarge => { - "population_with_feature > total_population_size in geometric distribution" - } - Error::SampleSizeTooLarge => { - "sample_size > total_population_size in geometric distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -// evaluate fact(numerator.0)*fact(numerator.1) / fact(denominator.0)*fact(denominator.1) -fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64, u64)) -> f64 { - let min_top = u64::min(numerator.0, numerator.1); - let min_bottom = u64::min(denominator.0, denominator.1); - // the factorial of this will cancel out: - let min_all = u64::min(min_top, min_bottom); - - let max_top = u64::max(numerator.0, numerator.1); - let max_bottom = u64::max(denominator.0, denominator.1); - let max_all = u64::max(max_top, max_bottom); - - let mut result = 1.0; - for i in (min_all + 1)..=max_all { - if i <= min_top { - result *= i as f64; - } - - if i <= min_bottom { - result /= i as f64; - } - - if i <= max_top { - result *= i as f64; - } - - if i <= max_bottom { - result /= i as f64; - } - } - - result -} - -const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi)) - -fn ln_of_factorial(v: f64) -> f64 { - // the paper calls for ln(v!), but also wants to pass in fractions, - // so we need to use Stirling's approximation to fill in the gaps: - - // shift v by 3, because Stirling is bad for small values - let v_3 = v + 3.0; - let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3); - // make the correction for the shift - ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln() -} - -impl Hypergeometric { - /// Constructs a new `Hypergeometric` with the shape parameters - /// `N = total_population_size`, - /// `K = population_with_feature`, - /// `n = sample_size`. - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - pub fn new( - total_population_size: u64, - population_with_feature: u64, - sample_size: u64, - ) -> Result { - if population_with_feature > total_population_size { - return Err(Error::ProbabilityTooLarge); - } - - if sample_size > total_population_size { - return Err(Error::SampleSizeTooLarge); - } - - // set-up constants as function of original parameters - let n = total_population_size; - let (mut sign_x, mut offset_x) = (1, 0); - let (n1, n2) = { - // switch around success and failure states if necessary to ensure n1 <= n2 - let population_without_feature = n - population_with_feature; - if population_with_feature > population_without_feature { - sign_x = -1; - offset_x = sample_size as i64; - (population_without_feature, population_with_feature) - } else { - (population_with_feature, population_without_feature) - } - }; - // when sampling more than half the total population, take the smaller - // group as sampled instead (we can then return n1-x instead). - // - // Note: the boundary condition given in the paper is `sample_size < n / 2`; - // we're deviating here, because when n is even, it doesn't matter whether - // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching - // when `k == n/2`, we'd actually be taking the _larger_ group as sampled. - let k = if sample_size <= n / 2 { - sample_size - } else { - offset_x += n1 as i64 * sign_x; - sign_x *= -1; - n - sample_size - }; - - // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`, - // where `M` is the mode of the distribution. - // Use algorithm HIN for the remaining parameter space. - // - // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer - // generation of hypergeometric random variates. - // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145 - // https://www.researchgate.net/publication/233212638 - const HIN_THRESHOLD: f64 = 10.0; - let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor(); - let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD { - let (initial_p, initial_x) = if k < n2 { - ( - fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), - 0, - ) - } else { - ( - fraction_of_products_of_factorials((n1, k), (n, k - n2)), - (k - n2) as i64, - ) - }; - - if initial_p <= 0.0 || !initial_p.is_finite() { - return Err(Error::PopulationTooLarge); - } - - SamplingMethod::InverseTransform { - initial_p, - initial_x, - } - } else { - let a = ln_of_factorial(m) - + ln_of_factorial(n1 as f64 - m) - + ln_of_factorial(k as f64 - m) - + ln_of_factorial((n2 - k) as f64 + m); - - let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64; - let denominator = (n - 1) as f64 * n as f64 * n as f64; - let d = 1.5 * (numerator / denominator).sqrt() + 0.5; - - let x_l = m - d + 0.5; - let x_r = m + d + 0.5; - - let k_l = f64::exp( - a - ln_of_factorial(x_l) - - ln_of_factorial(n1 as f64 - x_l) - - ln_of_factorial(k as f64 - x_l) - - ln_of_factorial((n2 - k) as f64 + x_l), - ); - let k_r = f64::exp( - a - ln_of_factorial(x_r - 1.0) - - ln_of_factorial(n1 as f64 - x_r + 1.0) - - ln_of_factorial(k as f64 - x_r + 1.0) - - ln_of_factorial((n2 - k) as f64 + x_r - 1.0), - ); - - let numerator = x_l * ((n2 - k) as f64 + x_l); - let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0); - let lambda_l = -((numerator / denominator).ln()); - - let numerator = (n1 as f64 - x_r + 1.0) * (k as f64 - x_r + 1.0); - let denominator = x_r * ((n2 - k) as f64 + x_r); - let lambda_r = -((numerator / denominator).ln()); - - // the paper literally gives `p2 + kL/lambdaL` where it (probably) - // should have been `p2 <- p1 + kL/lambdaL`; another print error?! - let p1 = 2.0 * d; - let p2 = p1 + k_l / lambda_l; - let p3 = p2 + k_r / lambda_r; - - SamplingMethod::RejectionAcceptance { - m, - a, - lambda_l, - lambda_r, - x_l, - x_r, - p1, - p2, - p3, - } - }; - - Ok(Hypergeometric { - n1, - n2, - k, - offset_x, - sign_x, - sampling_method, - }) - } -} - -impl Distribution for Hypergeometric { - #[allow(clippy::many_single_char_names)] // Same names as in the reference. - fn sample(&self, rng: &mut R) -> u64 { - use SamplingMethod::*; - - let Hypergeometric { - n1, - n2, - k, - sign_x, - offset_x, - sampling_method, - } = *self; - let x = match sampling_method { - InverseTransform { - initial_p: mut p, - initial_x: mut x, - } => { - let mut u = rng.random::(); - - // the paper erroneously uses `until n < p`, which doesn't make any sense - while u > p && x < k as i64 { - u -= p; - p *= ((n1 as i64 - x) * (k as i64 - x)) as f64; - p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64; - x += 1; - } - x - } - RejectionAcceptance { - m, - a, - lambda_l, - lambda_r, - x_l, - x_r, - p1, - p2, - p3, - } => { - let distr_region_select = Uniform::new(0.0, p3).unwrap(); - loop { - let (y, v) = loop { - let u = distr_region_select.sample(rng); - let v = rng.random::(); // for the accept/reject decision - - if u <= p1 { - // Region 1, central bell - let y = (x_l + u).floor(); - break (y, v); - } else if u <= p2 { - // Region 2, left exponential tail - let y = (x_l + v.ln() / lambda_l).floor(); - if y as i64 >= i64::max(0, k as i64 - n2 as i64) { - let v = v * (u - p1) * lambda_l; - break (y, v); - } - } else { - // Region 3, right exponential tail - let y = (x_r - v.ln() / lambda_r).floor(); - if y as u64 <= u64::min(n1, k) { - let v = v * (u - p2) * lambda_r; - break (y, v); - } - } - }; - - // Step 4: Acceptance/Rejection Comparison - if m < 100.0 || y <= 50.0 { - // Step 4.1: evaluate f(y) via recursive relationship - let mut f = 1.0; - if m < y { - for i in (m as u64 + 1)..=(y as u64) { - f *= (n1 - i + 1) as f64 * (k - i + 1) as f64; - f /= i as f64 * (n2 - k + i) as f64; - } - } else { - for i in (y as u64 + 1)..=(m as u64) { - f *= i as f64 * (n2 - k + i) as f64; - f /= (n1 - i + 1) as f64 * (k - i + 1) as f64; - } - } - - if v <= f { - break y as i64; - } - } else { - // Step 4.2: Squeezing - let y1 = y + 1.0; - let ym = y - m; - let yn = n1 as f64 - y + 1.0; - let yk = k as f64 - y + 1.0; - let nk = n2 as f64 - k as f64 + y1; - let r = -ym / y1; - let s = ym / yn; - let t = ym / yk; - let e = -ym / nk; - let g = yn * yk / (y1 * nk) - 1.0; - let dg = if g < 0.0 { 1.0 + g } else { 1.0 }; - let gu = g * (1.0 + g * (-0.5 + g / 3.0)); - let gl = gu - g.powi(4) / (4.0 * dg); - let xm = m + 0.5; - let xn = n1 as f64 - m + 0.5; - let xk = k as f64 - m + 0.5; - let nm = n2 as f64 - k as f64 + xm; - let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) - + xn * s * (1.0 + s * (-0.5 + s / 3.0)) - + xk * t * (1.0 + t * (-0.5 + t / 3.0)) - + nm * e * (1.0 + e * (-0.5 + e / 3.0)) - + y * gu - - m * gl - + 0.0034; - let av = v.ln(); - if av > ub { - continue; - } - let dr = if r < 0.0 { - xm * r.powi(4) / (1.0 + r) - } else { - xm * r.powi(4) - }; - let ds = if s < 0.0 { - xn * s.powi(4) / (1.0 + s) - } else { - xn * s.powi(4) - }; - let dt = if t < 0.0 { - xk * t.powi(4) / (1.0 + t) - } else { - xk * t.powi(4) - }; - let de = if e < 0.0 { - nm * e.powi(4) / (1.0 + e) - } else { - nm * e.powi(4) - }; - - if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 { - break y as i64; - } - - // Step 4.3: Final Acceptance/Rejection Test - let av_critical = a - - ln_of_factorial(y) - - ln_of_factorial(n1 as f64 - y) - - ln_of_factorial(k as f64 - y) - - ln_of_factorial((n2 - k) as f64 + y); - if v.ln() <= av_critical { - break y as i64; - } - } - } - } - }; - - (offset_x + sign_x * x) as u64 - } -} - -#[cfg(test)] -mod test { - - use super::*; - - #[test] - fn test_hypergeometric_invalid_params() { - assert!(Hypergeometric::new(100, 101, 5).is_err()); - assert!(Hypergeometric::new(100, 10, 101).is_err()); - assert!(Hypergeometric::new(100, 101, 101).is_err()); - assert!(Hypergeometric::new(100, 10, 5).is_ok()); - } - - fn test_hypergeometric_mean_and_variance(n: u64, k: u64, s: u64, rng: &mut R) { - let distr = Hypergeometric::new(n, k, s).unwrap(); - - let expected_mean = s as f64 * k as f64 / n as f64; - let expected_variance = { - let numerator = (s * k * (n - k) * (n - s)) as f64; - let denominator = (n * n * (n - 1)) as f64; - numerator / denominator - }; - - let mut results = [0.0; 1000]; - for i in results.iter_mut() { - *i = distr.sample(rng) as f64; - } - - let mean = results.iter().sum::() / results.len() as f64; - assert!((mean - expected_mean).abs() < expected_mean / 50.0); - - let variance = - results.iter().map(|x| (x - mean) * (x - mean)).sum::() / results.len() as f64; - assert!((variance - expected_variance).abs() < expected_variance / 10.0); - } - - #[test] - fn test_hypergeometric() { - let mut rng = crate::test::rng(737); - - // exercise algorithm HIN: - test_hypergeometric_mean_and_variance(500, 400, 30, &mut rng); - test_hypergeometric_mean_and_variance(250, 200, 230, &mut rng); - test_hypergeometric_mean_and_variance(100, 20, 6, &mut rng); - test_hypergeometric_mean_and_variance(50, 10, 47, &mut rng); - - // exercise algorithm H2PE - test_hypergeometric_mean_and_variance(5000, 2500, 500, &mut rng); - test_hypergeometric_mean_and_variance(10100, 10000, 1000, &mut rng); - test_hypergeometric_mean_and_variance(100100, 100, 10000, &mut rng); - } - - #[test] - fn hypergeometric_distributions_can_be_compared() { - assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3)); - } - - #[test] - fn stirling() { - let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; - for &v in test.iter() { - let ln_fac = ln_of_factorial(v); - assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4); - } - } -} diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs deleted file mode 100644 index 354c2e05986..00000000000 --- a/rand_distr/src/inverse_gaussian.rs +++ /dev/null @@ -1,143 +0,0 @@ -//! The inverse Gaussian distribution `IG(μ, λ)`. - -use crate::{Distribution, StandardNormal, StandardUniform}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// Error type returned from [`InverseGaussian::new`] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Error { - /// `mean <= 0` or `nan`. - MeanNegativeOrNull, - /// `shape <= 0` or `nan`. - ShapeNegativeOrNull, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::MeanNegativeOrNull => "mean <= 0 or is NaN in inverse Gaussian distribution", - Error::ShapeNegativeOrNull => "shape <= 0 or is NaN in inverse Gaussian distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -/// The [inverse Gaussian distribution](https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution) `IG(μ, λ)`. -/// -/// This is a continuous probability distribution with mean parameter `μ` (`mu`) -/// and shape parameter `λ` (`lambda`), defined for `x > 0`. -/// It is also known as the Wald distribution. -/// -/// # Plot -/// -/// The following plot shows the inverse Gaussian distribution -/// with various values of `μ` and `λ`. -/// -/// ![Inverse Gaussian distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/inverse_gaussian.svg) -/// -/// # Example -/// ``` -/// use rand_distr::{InverseGaussian, Distribution}; -/// -/// let inv_gauss = InverseGaussian::new(1.0, 2.0).unwrap(); -/// let v = inv_gauss.sample(&mut rand::rng()); -/// println!("{} is from a inverse Gaussian(1, 2) distribution", v); -/// ``` -#[derive(Debug, Clone, Copy, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct InverseGaussian -where - F: Float, - StandardNormal: Distribution, - StandardUniform: Distribution, -{ - mean: F, - shape: F, -} - -impl InverseGaussian -where - F: Float, - StandardNormal: Distribution, - StandardUniform: Distribution, -{ - /// Construct a new `InverseGaussian` distribution with the given mean and - /// shape. - pub fn new(mean: F, shape: F) -> Result, Error> { - let zero = F::zero(); - if !(mean > zero) { - return Err(Error::MeanNegativeOrNull); - } - - if !(shape > zero) { - return Err(Error::ShapeNegativeOrNull); - } - - Ok(Self { mean, shape }) - } -} - -impl Distribution for InverseGaussian -where - F: Float, - StandardNormal: Distribution, - StandardUniform: Distribution, -{ - #[allow(clippy::many_single_char_names)] - fn sample(&self, rng: &mut R) -> F - where - R: Rng + ?Sized, - { - let mu = self.mean; - let l = self.shape; - - let v: F = rng.sample(StandardNormal); - let y = mu * v * v; - - let mu_2l = mu / (F::from(2.).unwrap() * l); - - let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt()); - - let u: F = rng.random(); - - if u <= mu / (mu + x) { - return x; - } - - mu * mu / x - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_inverse_gaussian() { - let inv_gauss = InverseGaussian::new(1.0, 1.0).unwrap(); - let mut rng = crate::test::rng(210); - for _ in 0..1000 { - inv_gauss.sample(&mut rng); - } - } - - #[test] - fn test_inverse_gaussian_invalid_param() { - assert!(InverseGaussian::new(-1.0, 1.0).is_err()); - assert!(InverseGaussian::new(-1.0, -1.0).is_err()); - assert!(InverseGaussian::new(1.0, -1.0).is_err()); - assert!(InverseGaussian::new(1.0, 1.0).is_ok()); - } - - #[test] - fn inverse_gaussian_distributions_can_be_compared() { - assert_eq!( - InverseGaussian::new(1.0, 2.0), - InverseGaussian::new(1.0, 2.0) - ); - } -} diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs deleted file mode 100644 index ef1109b7d6f..00000000000 --- a/rand_distr/src/lib.rs +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#![doc( - html_logo_url = "https://www.rust-lang.org/logos/rust-logo-128x128-blk.png", - html_favicon_url = "https://www.rust-lang.org/favicon.ico", - html_root_url = "https://rust-random.github.io/rand/" -)] -#![forbid(unsafe_code)] -#![deny(missing_docs)] -#![deny(missing_debug_implementations)] -#![allow( - clippy::excessive_precision, - clippy::float_cmp, - clippy::unreadable_literal -)] -#![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose -#![no_std] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] - -//! Generating random samples from probability distributions. -//! -//! ## Re-exports -//! -//! This crate is a super-set of the [`rand::distr`] module. See the -//! [`rand::distr`] module documentation for an overview of the core -//! [`Distribution`] trait and implementations. -//! -//! The following are re-exported: -//! -//! - The [`Distribution`] trait and [`Iter`] helper type -//! - The [`StandardUniform`], [`Alphanumeric`], [`Uniform`], [`OpenClosed01`], -//! [`Open01`], [`Bernoulli`] distributions -//! - The [`weighted`] module -//! -//! ## Distributions -//! -//! This crate provides the following probability distributions: -//! -//! - Related to real-valued quantities that grow linearly -//! (e.g. errors, offsets): -//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive -//! - [`SkewNormal`] distribution -//! - [`Cauchy`] distribution -//! - Related to Bernoulli trials (yes/no events, with a given probability): -//! - [`Binomial`] distribution -//! - [`Geometric`] distribution -//! - [`Hypergeometric`] distribution -//! - Related to positive real-valued quantities that grow exponentially -//! (e.g. prices, incomes, populations): -//! - [`LogNormal`] distribution -//! - Related to the occurrence of independent events at a given rate: -//! - [`Pareto`] distribution -//! - [`Poisson`] distribution -//! - [`Exp`]onential distribution, and [`Exp1`] as a primitive -//! - [`Weibull`] distribution -//! - [`Gumbel`] distribution -//! - [`Frechet`] distribution -//! - [`Zeta`] distribution -//! - [`Zipf`] distribution -//! - Gamma and derived distributions: -//! - [`Gamma`] distribution -//! - [`ChiSquared`] distribution -//! - [`StudentT`] distribution -//! - [`FisherF`] distribution -//! - Triangular distribution: -//! - [`Beta`] distribution -//! - [`Triangular`] distribution -//! - Multivariate probability distributions -//! - [`Dirichlet`] distribution -//! - [`UnitSphere`] distribution -//! - [`UnitBall`] distribution -//! - [`UnitCircle`] distribution -//! - [`UnitDisc`] distribution -//! - Misc. distributions -//! - [`InverseGaussian`] distribution -//! - [`NormalInverseGaussian`] distribution - -#[cfg(feature = "alloc")] -extern crate alloc; - -#[cfg(feature = "std")] -extern crate std; - -// This is used for doc links: -#[allow(unused)] -use rand::Rng; - -pub use rand::distr::{ - uniform, Alphanumeric, Bernoulli, BernoulliError, Distribution, Iter, Open01, OpenClosed01, - StandardUniform, Uniform, -}; - -pub use self::beta::{Beta, Error as BetaError}; -pub use self::binomial::{Binomial, Error as BinomialError}; -pub use self::cauchy::{Cauchy, Error as CauchyError}; -pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError}; -#[cfg(feature = "alloc")] -pub use self::dirichlet::{Dirichlet, Error as DirichletError}; -pub use self::exponential::{Error as ExpError, Exp, Exp1}; -pub use self::fisher_f::{Error as FisherFError, FisherF}; -pub use self::frechet::{Error as FrechetError, Frechet}; -pub use self::gamma::{Error as GammaError, Gamma}; -pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric}; -pub use self::gumbel::{Error as GumbelError, Gumbel}; -pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric}; -pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian}; -pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; -pub use self::normal_inverse_gaussian::{ - Error as NormalInverseGaussianError, NormalInverseGaussian, -}; -pub use self::pareto::{Error as ParetoError, Pareto}; -pub use self::pert::{Pert, PertBuilder, PertError}; -pub use self::poisson::{Error as PoissonError, Poisson}; -pub use self::skew_normal::{Error as SkewNormalError, SkewNormal}; -pub use self::triangular::{Triangular, TriangularError}; -pub use self::unit_ball::UnitBall; -pub use self::unit_circle::UnitCircle; -pub use self::unit_disc::UnitDisc; -pub use self::unit_sphere::UnitSphere; -pub use self::weibull::{Error as WeibullError, Weibull}; -pub use self::zeta::{Error as ZetaError, Zeta}; -pub use self::zipf::{Error as ZipfError, Zipf}; -pub use student_t::StudentT; - -pub use num_traits; - -#[cfg(feature = "alloc")] -pub mod weighted; - -#[cfg(test)] -#[macro_use] -mod test { - // Notes on testing - // - // Testing random number distributions correctly is hard. The following - // testing is desired: - // - // - Construction: test initialisation with a few valid parameter sets. - // - Erroneous usage: test that incorrect usage generates an error. - // - Vector: test that usage with fixed inputs (including RNG) generates a - // fixed output sequence on all platforms. - // - Correctness at fixed points (optional): using a specific mock RNG, - // check that specific values are sampled (e.g. end-points and median of - // distribution). - // - Correctness of PDF (extra): generate a histogram of samples within a - // certain range, and check this approximates the PDF. These tests are - // expected to be expensive, and should be behind a feature-gate. - // - // TODO: Vector and correctness tests are largely absent so far. - // NOTE: Some distributions have tests checking only that samples can be - // generated. This is redundant with vector and correctness tests. - - /// Construct a deterministic RNG with the given seed - pub fn rng(seed: u64) -> impl rand::RngCore { - // For tests, we want a statistically good, fast, reproducible RNG. - // PCG32 will do fine, and will be easy to embed if we ever need to. - const INC: u64 = 11634580027462260723; - rand_pcg::Pcg32::new(seed, INC) - } - - /// Assert that two numbers are almost equal to each other. - /// - /// On panic, this macro will print the values of the expressions with their - /// debug representations. - macro_rules! assert_almost_eq { - ($a:expr, $b:expr, $prec:expr) => { - let diff = ($a - $b).abs(); - assert!( - diff <= $prec, - "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ - (left: `{}`, right: `{}`)", - diff, - $prec, - $a, - $b - ); - }; - } -} - -mod beta; -mod binomial; -mod cauchy; -mod chi_squared; -mod dirichlet; -mod exponential; -mod fisher_f; -mod frechet; -mod gamma; -mod geometric; -mod gumbel; -mod hypergeometric; -mod inverse_gaussian; -mod normal; -mod normal_inverse_gaussian; -mod pareto; -mod pert; -pub(crate) mod poisson; -mod skew_normal; -mod student_t; -mod triangular; -mod unit_ball; -mod unit_circle; -mod unit_disc; -mod unit_sphere; -mod utils; -mod weibull; -mod zeta; -mod ziggurat_tables; -mod zipf; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs deleted file mode 100644 index 330c1ec2d6f..00000000000 --- a/rand_distr/src/normal.rs +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Normal and derived distributions. - -use crate::utils::ziggurat; -use crate::{ziggurat_tables, Distribution, Open01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The standard Normal distribution `N(0, 1)`. -/// -/// This is equivalent to `Normal::new(0.0, 1.0)`, but faster. -/// -/// See [`Normal`](crate::Normal) for the general Normal distribution. -/// -/// # Plot -/// -/// The following diagram shows the standard Normal distribution. -/// -/// ![Standard Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/standard_normal.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::StandardNormal; -/// -/// let val: f64 = rand::rng().sample(StandardNormal); -/// println!("{}", val); -/// ``` -/// -/// # Notes -/// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. -/// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct StandardNormal; - -impl Distribution for StandardNormal { - #[inline] - fn sample(&self, rng: &mut R) -> f32 { - // TODO: use optimal 32-bit implementation - let x: f64 = self.sample(rng); - x as f32 - } -} - -impl Distribution for StandardNormal { - fn sample(&self, rng: &mut R) -> f64 { - #[inline] - fn pdf(x: f64) -> f64 { - (-x * x / 2.0).exp() - } - #[inline] - fn zero_case(rng: &mut R, u: f64) -> f64 { - // compute a random number in the tail by hand - - // strange initial conditions, because the loop is not - // do-while, so the condition should be true on the first - // run, they get overwritten anyway (0 < 1, so these are - // good). - let mut x = 1.0f64; - let mut y = 0.0f64; - - while -2.0 * y < x * x { - let x_: f64 = rng.sample(Open01); - let y_: f64 = rng.sample(Open01); - - x = x_.ln() / ziggurat_tables::ZIG_NORM_R; - y = y_.ln(); - } - - if u < 0.0 { - x - ziggurat_tables::ZIG_NORM_R - } else { - ziggurat_tables::ZIG_NORM_R - x - } - } - - ziggurat( - rng, - true, // this is symmetric - &ziggurat_tables::ZIG_NORM_X, - &ziggurat_tables::ZIG_NORM_F, - pdf, - zero_case, - ) - } -} - -/// The [Normal distribution](https://en.wikipedia.org/wiki/Normal_distribution) `N(μ, σ²)`. -/// -/// The Normal distribution, also known as the Gaussian distribution or -/// bell curve, is a continuous probability distribution with mean -/// `μ` (`mu`) and standard deviation `σ` (`sigma`). -/// It is used to model continuous data that tend to cluster around a mean. -/// The Normal distribution is symmetric and characterized by its bell-shaped curve. -/// -/// See [`StandardNormal`](crate::StandardNormal) for an -/// optimised implementation for `μ = 0` and `σ = 1`. -/// -/// # Density function -/// -/// `f(x) = (1 / sqrt(2π σ²)) * exp(-((x - μ)² / (2σ²)))` -/// -/// # Plot -/// -/// The following diagram shows the Normal distribution with various values of `μ` -/// and `σ`. -/// The blue curve is the [`StandardNormal`](crate::StandardNormal) distribution, `N(0, 1)`. -/// -/// ![Normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Normal, Distribution}; -/// -/// // mean 2, standard deviation 3 -/// let normal = Normal::new(2.0, 3.0).unwrap(); -/// let v = normal.sample(&mut rand::rng()); -/// println!("{} is from a N(2, 9) distribution", v) -/// ``` -/// -/// # Notes -/// -/// Implemented via the ZIGNOR variant[^1] of the Ziggurat method. -/// -/// [^1]: Jurgen A. Doornik (2005). [*An Improved Ziggurat Method to -/// Generate Normal Random Samples*]( -/// https://www.doornik.com/research/ziggurat.pdf). -/// Nuffield College, Oxford -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Normal -where - F: Float, - StandardNormal: Distribution, -{ - mean: F, - std_dev: F, -} - -/// Error type returned from [`Normal::new`] and [`LogNormal::new`](crate::LogNormal::new). -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// The mean value is too small (log-normal samples must be positive) - MeanTooSmall, - /// The standard deviation or other dispersion parameter is not finite. - BadVariance, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution", - Error::BadVariance => "variation parameter is non-finite in (log)normal distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Normal -where - F: Float, - StandardNormal: Distribution, -{ - /// Construct, from mean and standard deviation - /// - /// Parameters: - /// - /// - mean (`μ`, unrestricted) - /// - standard deviation (`σ`, must be finite) - #[inline] - pub fn new(mean: F, std_dev: F) -> Result, Error> { - if !std_dev.is_finite() { - return Err(Error::BadVariance); - } - Ok(Normal { mean, std_dev }) - } - - /// Construct, from mean and coefficient of variation - /// - /// Parameters: - /// - /// - mean (`μ`, unrestricted) - /// - coefficient of variation (`cv = abs(σ / μ)`) - #[inline] - pub fn from_mean_cv(mean: F, cv: F) -> Result, Error> { - if !cv.is_finite() || cv < F::zero() { - return Err(Error::BadVariance); - } - let std_dev = cv * mean; - Ok(Normal { mean, std_dev }) - } - - /// Sample from a z-score - /// - /// This may be useful for generating correlated samples `x1` and `x2` - /// from two different distributions, as follows. - /// ``` - /// # use rand::prelude::*; - /// # use rand_distr::{Normal, StandardNormal}; - /// let mut rng = rand::rng(); - /// let z = StandardNormal.sample(&mut rng); - /// let x1 = Normal::new(0.0, 1.0).unwrap().from_zscore(z); - /// let x2 = Normal::new(2.0, -3.0).unwrap().from_zscore(z); - /// ``` - #[inline] - pub fn from_zscore(&self, zscore: F) -> F { - self.mean + self.std_dev * zscore - } - - /// Returns the mean (`μ`) of the distribution. - pub fn mean(&self) -> F { - self.mean - } - - /// Returns the standard deviation (`σ`) of the distribution. - pub fn std_dev(&self) -> F { - self.std_dev - } -} - -impl Distribution for Normal -where - F: Float, - StandardNormal: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - self.from_zscore(rng.sample(StandardNormal)) - } -} - -/// The [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) `ln N(μ, σ²)`. -/// -/// This is the distribution of the random variable `X = exp(Y)` where `Y` is -/// normally distributed with mean `μ` and variance `σ²`. In other words, if -/// `X` is log-normal distributed, then `ln(X)` is `N(μ, σ²)` distributed. -/// -/// # Plot -/// -/// The following diagram shows the log-normal distribution with various values -/// of `μ` and `σ`. -/// -/// ![Log-normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/log_normal.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{LogNormal, Distribution}; -/// -/// // mean 2, standard deviation 3 -/// let log_normal = LogNormal::new(2.0, 3.0).unwrap(); -/// let v = log_normal.sample(&mut rand::rng()); -/// println!("{} is from an ln N(2, 9) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct LogNormal -where - F: Float, - StandardNormal: Distribution, -{ - norm: Normal, -} - -impl LogNormal -where - F: Float, - StandardNormal: Distribution, -{ - /// Construct, from (log-space) mean and standard deviation - /// - /// Parameters are the "standard" log-space measures (these are the mean - /// and standard deviation of the logarithm of samples): - /// - /// - `mu` (`μ`, unrestricted) is the mean of the underlying distribution - /// - `sigma` (`σ`, must be finite) is the standard deviation of the - /// underlying Normal distribution - #[inline] - pub fn new(mu: F, sigma: F) -> Result, Error> { - let norm = Normal::new(mu, sigma)?; - Ok(LogNormal { norm }) - } - - /// Construct, from (linear-space) mean and coefficient of variation - /// - /// Parameters are linear-space measures: - /// - /// - mean (`μ > 0`) is the (real) mean of the distribution - /// - coefficient of variation (`cv = σ / μ`, requiring `cv ≥ 0`) is a - /// standardized measure of dispersion - /// - /// As a special exception, `μ = 0, cv = 0` is allowed (samples are `-inf`). - #[inline] - pub fn from_mean_cv(mean: F, cv: F) -> Result, Error> { - if cv == F::zero() { - let mu = mean.ln(); - let norm = Normal::new(mu, F::zero()).unwrap(); - return Ok(LogNormal { norm }); - } - if !(mean > F::zero()) { - return Err(Error::MeanTooSmall); - } - if !(cv >= F::zero()) { - return Err(Error::BadVariance); - } - - // Using X ~ lognormal(μ, σ), CV² = Var(X) / E(X)² - // E(X) = exp(μ + σ² / 2) = exp(μ) × exp(σ² / 2) - // Var(X) = exp(2μ + σ²)(exp(σ²) - 1) = E(X)² × (exp(σ²) - 1) - // but Var(X) = (CV × E(X))² so CV² = exp(σ²) - 1 - // thus σ² = log(CV² + 1) - // and exp(μ) = E(X) / exp(σ² / 2) = E(X) / sqrt(CV² + 1) - let a = F::one() + cv * cv; // e - let mu = F::from(0.5).unwrap() * (mean * mean / a).ln(); - let sigma = a.ln().sqrt(); - let norm = Normal::new(mu, sigma)?; - Ok(LogNormal { norm }) - } - - /// Sample from a z-score - /// - /// This may be useful for generating correlated samples `x1` and `x2` - /// from two different distributions, as follows. - /// ``` - /// # use rand::prelude::*; - /// # use rand_distr::{LogNormal, StandardNormal}; - /// let mut rng = rand::rng(); - /// let z = StandardNormal.sample(&mut rng); - /// let x1 = LogNormal::from_mean_cv(3.0, 1.0).unwrap().from_zscore(z); - /// let x2 = LogNormal::from_mean_cv(2.0, 4.0).unwrap().from_zscore(z); - /// ``` - #[inline] - pub fn from_zscore(&self, zscore: F) -> F { - self.norm.from_zscore(zscore).exp() - } -} - -impl Distribution for LogNormal -where - F: Float, - StandardNormal: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - self.norm.sample(rng).exp() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_normal() { - let norm = Normal::new(10.0, 10.0).unwrap(); - let mut rng = crate::test::rng(210); - for _ in 0..1000 { - norm.sample(&mut rng); - } - } - #[test] - fn test_normal_cv() { - let norm = Normal::from_mean_cv(1024.0, 1.0 / 256.0).unwrap(); - assert_eq!((norm.mean, norm.std_dev), (1024.0, 4.0)); - } - #[test] - fn test_normal_invalid_sd() { - assert!(Normal::from_mean_cv(10.0, -1.0).is_err()); - } - - #[test] - fn test_log_normal() { - let lnorm = LogNormal::new(10.0, 10.0).unwrap(); - let mut rng = crate::test::rng(211); - for _ in 0..1000 { - lnorm.sample(&mut rng); - } - } - #[test] - fn test_log_normal_cv() { - let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap(); - assert_eq!( - (lnorm.norm.mean, lnorm.norm.std_dev), - (f64::NEG_INFINITY, 0.0) - ); - - let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap(); - assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0)); - - let e = core::f64::consts::E; - let lnorm = LogNormal::from_mean_cv(e.sqrt(), (e - 1.0).sqrt()).unwrap(); - assert_almost_eq!(lnorm.norm.mean, 0.0, 2e-16); - assert_almost_eq!(lnorm.norm.std_dev, 1.0, 2e-16); - - let lnorm = LogNormal::from_mean_cv(e.powf(1.5), (e - 1.0).sqrt()).unwrap(); - assert_almost_eq!(lnorm.norm.mean, 1.0, 1e-15); - assert_eq!(lnorm.norm.std_dev, 1.0); - } - #[test] - fn test_log_normal_invalid_sd() { - assert!(LogNormal::from_mean_cv(-1.0, 1.0).is_err()); - assert!(LogNormal::from_mean_cv(0.0, 1.0).is_err()); - assert!(LogNormal::from_mean_cv(1.0, -1.0).is_err()); - } - - #[test] - fn normal_distributions_can_be_compared() { - assert_eq!(Normal::new(1.0, 2.0), Normal::new(1.0, 2.0)); - } - - #[test] - fn log_normal_distributions_can_be_compared() { - assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs deleted file mode 100644 index 6ad2e58fe65..00000000000 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ /dev/null @@ -1,137 +0,0 @@ -use crate::{Distribution, InverseGaussian, StandardNormal, StandardUniform}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// Error type returned from [`NormalInverseGaussian::new`] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Error { - /// `alpha <= 0` or `nan`. - AlphaNegativeOrNull, - /// `|beta| >= alpha` or `nan`. - AbsoluteBetaNotLessThanAlpha, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::AlphaNegativeOrNull => { - "alpha <= 0 or is NaN in normal inverse Gaussian distribution" - } - Error::AbsoluteBetaNotLessThanAlpha => { - "|beta| >= alpha or is NaN in normal inverse Gaussian distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -/// The [normal-inverse Gaussian distribution](https://en.wikipedia.org/wiki/Normal-inverse_Gaussian_distribution) `NIG(α, β)`. -/// -/// This is a continuous probability distribution with two parameters, -/// `α` (`alpha`) and `β` (`beta`), defined in `(-∞, ∞)`. -/// It is also known as the normal-Wald distribution. -/// -/// # Plot -/// -/// The following plot shows the normal-inverse Gaussian distribution with various values of `α` and `β`. -/// -/// ![Normal-inverse Gaussian distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/normal_inverse_gaussian.svg) -/// -/// # Example -/// ``` -/// use rand_distr::{NormalInverseGaussian, Distribution}; -/// -/// let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); -/// let v = norm_inv_gauss.sample(&mut rand::rng()); -/// println!("{} is from a normal-inverse Gaussian(2, 1) distribution", v); -/// ``` -#[derive(Debug, Clone, Copy, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct NormalInverseGaussian -where - F: Float, - StandardNormal: Distribution, - StandardUniform: Distribution, -{ - beta: F, - inverse_gaussian: InverseGaussian, -} - -impl NormalInverseGaussian -where - F: Float, - StandardNormal: Distribution, - StandardUniform: Distribution, -{ - /// Construct a new `NormalInverseGaussian` distribution with the given alpha (tail heaviness) and - /// beta (asymmetry) parameters. - pub fn new(alpha: F, beta: F) -> Result, Error> { - if !(alpha > F::zero()) { - return Err(Error::AlphaNegativeOrNull); - } - - if !(beta.abs() < alpha) { - return Err(Error::AbsoluteBetaNotLessThanAlpha); - } - - let gamma = (alpha * alpha - beta * beta).sqrt(); - - let mu = F::one() / gamma; - - let inverse_gaussian = InverseGaussian::new(mu, F::one()).unwrap(); - - Ok(Self { - beta, - inverse_gaussian, - }) - } -} - -impl Distribution for NormalInverseGaussian -where - F: Float, - StandardNormal: Distribution, - StandardUniform: Distribution, -{ - fn sample(&self, rng: &mut R) -> F - where - R: Rng + ?Sized, - { - let inv_gauss = rng.sample(self.inverse_gaussian); - - self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_normal_inverse_gaussian() { - let norm_inv_gauss = NormalInverseGaussian::new(2.0, 1.0).unwrap(); - let mut rng = crate::test::rng(210); - for _ in 0..1000 { - norm_inv_gauss.sample(&mut rng); - } - } - - #[test] - fn test_normal_inverse_gaussian_invalid_param() { - assert!(NormalInverseGaussian::new(-1.0, 1.0).is_err()); - assert!(NormalInverseGaussian::new(-1.0, -1.0).is_err()); - assert!(NormalInverseGaussian::new(1.0, 2.0).is_err()); - assert!(NormalInverseGaussian::new(2.0, 1.0).is_ok()); - } - - #[test] - fn normal_inverse_gaussian_distributions_can_be_compared() { - assert_eq!( - NormalInverseGaussian::new(1.0, 2.0), - NormalInverseGaussian::new(1.0, 2.0) - ); - } -} diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs deleted file mode 100644 index 7334ccd5f15..00000000000 --- a/rand_distr/src/pareto.rs +++ /dev/null @@ -1,164 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Pareto distribution `Pareto(xₘ, α)`. - -use crate::{Distribution, OpenClosed01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [Pareto distribution](https://en.wikipedia.org/wiki/Pareto_distribution) `Pareto(xₘ, α)`. -/// -/// The Pareto distribution is a continuous probability distribution with -/// scale parameter `xₘ` ( or `k`) and shape parameter `α`. -/// -/// # Plot -/// -/// The following plot shows the Pareto distribution with various values of -/// `xₘ` and `α`. -/// Note how the shape parameter `α` corresponds to the height of the jump -/// in density at `x = xₘ`, and to the rate of decay in the tail. -/// -/// ![Pareto distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pareto.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Pareto; -/// -/// let val: f64 = rand::rng().sample(Pareto::new(1., 2.).unwrap()); -/// println!("{}", val); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Pareto -where - F: Float, - OpenClosed01: Distribution, -{ - scale: F, - inv_neg_shape: F, -} - -/// Error type returned from [`Pareto::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `scale <= 0` or `nan`. - ScaleTooSmall, - /// `shape <= 0` or `nan`. - ShapeTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => "scale is not positive in Pareto distribution", - Error::ShapeTooSmall => "shape is not positive in Pareto distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Pareto -where - F: Float, - OpenClosed01: Distribution, -{ - /// Construct a new Pareto distribution with given `scale` and `shape`. - /// - /// In the literature, `scale` is commonly written as xm or k and - /// `shape` is often written as α. - pub fn new(scale: F, shape: F) -> Result, Error> { - let zero = F::zero(); - - if !(scale > zero) { - return Err(Error::ScaleTooSmall); - } - if !(shape > zero) { - return Err(Error::ShapeTooSmall); - } - Ok(Pareto { - scale, - inv_neg_shape: F::from(-1.0).unwrap() / shape, - }) - } -} - -impl Distribution for Pareto -where - F: Float, - OpenClosed01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let u: F = OpenClosed01.sample(rng); - self.scale * u.powf(self.inv_neg_shape) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use core::fmt::{Debug, Display, LowerExp}; - - #[test] - #[should_panic] - fn invalid() { - Pareto::new(0., 0.).unwrap(); - } - - #[test] - fn sample() { - let scale = 1.0; - let shape = 2.0; - let d = Pareto::new(scale, shape).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= scale); - } - } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, - thresh: F, - expected: &[F], - ) { - let mut rng = crate::test::rng(213); - for v in expected { - let x = rng.sample(&distr); - assert_almost_eq!(x, *v, thresh); - } - } - - test_samples( - Pareto::new(1f32, 1.0).unwrap(), - 1e-6, - &[1.0423688, 2.1235929, 4.132709, 1.4679428], - ); - test_samples( - Pareto::new(2.0, 0.5).unwrap(), - 1e-14, - &[ - 9.019295276219136, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ], - ); - } - - #[test] - fn pareto_distributions_can_be_compared() { - assert_eq!(Pareto::new(1.0, 2.0), Pareto::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs deleted file mode 100644 index 5c247a3d1e8..00000000000 --- a/rand_distr/src/pert.rs +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. -//! The PERT distribution. - -use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [PERT distribution](https://en.wikipedia.org/wiki/PERT_distribution) `PERT(min, max, mode, shape)`. -/// -/// Similar to the [`Triangular`] distribution, the PERT distribution is -/// parameterised by a range and a mode within that range. Unlike the -/// [`Triangular`] distribution, the probability density function of the PERT -/// distribution is smooth, with a configurable weighting around the mode. -/// -/// # Plot -/// -/// The following plot shows the PERT distribution with `min = -1`, `max = 1`, -/// and various values of `mode` and `shape`. -/// -/// ![PERT distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pert.svg) -/// -/// # Example -/// -/// ```rust -/// use rand_distr::{Pert, Distribution}; -/// -/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap(); -/// let v = d.sample(&mut rand::rng()); -/// println!("{} is from a PERT distribution", v); -/// ``` -/// -/// [`Triangular`]: crate::Triangular -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Pert -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - min: F, - range: F, - beta: Beta, -} - -/// Error type returned from [`Pert`] constructors. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum PertError { - /// `max < min` or `min` or `max` is NaN. - RangeTooSmall, - /// `mode < min` or `mode > max` or `mode` is NaN. - ModeRange, - /// `shape < 0` or `shape` is NaN - ShapeTooSmall, -} - -impl fmt::Display for PertError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - PertError::RangeTooSmall => "requirement min < max is not met in PERT distribution", - PertError::ModeRange => "mode is outside [min, max] in PERT distribution", - PertError::ShapeTooSmall => "shape < 0 or is NaN in PERT distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for PertError {} - -impl Pert -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Construct a PERT distribution with defined `min`, `max` - /// - /// # Example - /// - /// ``` - /// use rand_distr::Pert; - /// let pert_dist = Pert::new(0.0, 10.0) - /// .with_shape(3.5) - /// .with_mean(3.0) - /// .unwrap(); - /// # let _unused: Pert = pert_dist; - /// ``` - #[allow(clippy::new_ret_no_self)] - #[inline] - pub fn new(min: F, max: F) -> PertBuilder { - let shape = F::from(4.0).unwrap(); - PertBuilder { min, max, shape } - } -} - -/// Struct used to build a [`Pert`] -#[derive(Debug)] -pub struct PertBuilder { - min: F, - max: F, - shape: F, -} - -impl PertBuilder -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Set the shape parameter - /// - /// If not specified, this defaults to 4. - #[inline] - pub fn with_shape(mut self, shape: F) -> PertBuilder { - self.shape = shape; - self - } - - /// Specify the mean - #[inline] - pub fn with_mean(self, mean: F) -> Result, PertError> { - let two = F::from(2.0).unwrap(); - let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape; - self.with_mode(mode) - } - - /// Specify the mode - #[inline] - pub fn with_mode(self, mode: F) -> Result, PertError> { - if !(self.max > self.min) { - return Err(PertError::RangeTooSmall); - } - if !(mode >= self.min && self.max >= mode) { - return Err(PertError::ModeRange); - } - if !(self.shape >= F::from(0.).unwrap()) { - return Err(PertError::ShapeTooSmall); - } - - let (min, max, shape) = (self.min, self.max, self.shape); - let range = max - min; - let v = F::from(1.0).unwrap() + shape * (mode - min) / range; - let w = F::from(1.0).unwrap() + shape * (max - mode) / range; - let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?; - Ok(Pert { min, range, beta }) - } -} - -impl Distribution for Pert -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - self.beta.sample(rng) * self.range + self.min - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_pert() { - for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] { - let _distr = Pert::new(min, max).with_mode(mode).unwrap(); - // TODO: test correctness - } - - for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { - assert!(Pert::new(min, max).with_mode(mode).is_err()); - } - } - - #[test] - fn distributions_can_be_compared() { - let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0); - let p1 = Pert::new(min, max).with_mode(mode).unwrap(); - let mean = (min + shape * mode + max) / (shape + 2.0); - let p2 = Pert::new(min, max).with_mean(mean).unwrap(); - assert_eq!(p1, p2); - } - - #[test] - fn mode_almost_half_range() { - assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok()); - } - - #[test] - fn almost_symmetric_about_zero() { - let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON); - assert!(distr.is_ok()); - } - - #[test] - fn almost_symmetric() { - let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON); - assert!(distr.is_ok()); - } -} diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs deleted file mode 100644 index 3e4421259bd..00000000000 --- a/rand_distr/src/poisson.rs +++ /dev/null @@ -1,305 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2016-2017 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Poisson distribution `Poisson(λ)`. - -use crate::{Cauchy, Distribution, StandardUniform}; -use core::fmt; -use num_traits::{Float, FloatConst}; -use rand::Rng; - -/// The [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution) `Poisson(λ)`. -/// -/// The Poisson distribution is a discrete probability distribution with -/// rate parameter `λ` (`lambda`). It models the number of events occurring in a fixed -/// interval of time or space. -/// -/// This distribution has density function: -/// `f(k) = λ^k * exp(-λ) / k!` for `k >= 0`. -/// -/// # Plot -/// -/// The following plot shows the Poisson distribution with various values of `λ`. -/// Note how the expected number of events increases with `λ`. -/// -/// ![Poisson distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/poisson.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Poisson, Distribution}; -/// -/// let poi = Poisson::new(2.0).unwrap(); -/// let v: f64 = poi.sample(&mut rand::rng()); -/// println!("{} is from a Poisson(2) distribution", v); -/// ``` -/// -/// # Integer vs FP return type -/// -/// This implementation uses floating-point (FP) logic internally. -/// -/// Due to the parameter limit λ < [Self::MAX_LAMBDA], it -/// statistically impossible to sample a value larger [`u64::MAX`]. As such, it -/// is reasonable to cast generated samples to `u64` using `as`: -/// `distr.sample(&mut rng) as u64` (and memory safe since Rust 1.45). -/// Similarly, when `λ < 4.2e9` it can be safely assumed that samples are less -/// than `u32::MAX`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Poisson(Method) -where - F: Float + FloatConst, - StandardUniform: Distribution; - -/// Error type returned from [`Poisson::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `lambda <= 0` - ShapeTooSmall, - /// `lambda = ∞` or `lambda = nan` - NonFinite, - /// `lambda` is too large, see [Poisson::MAX_LAMBDA] - ShapeTooLarge, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ShapeTooSmall => "lambda is not positive in Poisson distribution", - Error::NonFinite => "lambda is infinite or nan in Poisson distribution", - Error::ShapeTooLarge => { - "lambda is too large in Poisson distribution, see Poisson::MAX_LAMBDA" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub(crate) struct KnuthMethod { - exp_lambda: F, -} - -impl KnuthMethod { - pub(crate) fn new(lambda: F) -> Self { - KnuthMethod { - exp_lambda: (-lambda).exp(), - } - } -} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -struct RejectionMethod { - lambda: F, - log_lambda: F, - sqrt_2lambda: F, - magic_val: F, -} - -impl RejectionMethod { - pub(crate) fn new(lambda: F) -> Self { - let log_lambda = lambda.ln(); - let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt(); - let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda); - RejectionMethod { - lambda, - log_lambda, - sqrt_2lambda, - magic_val, - } - } -} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -enum Method { - Knuth(KnuthMethod), - Rejection(RejectionMethod), -} - -impl Poisson -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - /// Construct a new `Poisson` with the given shape parameter - /// `lambda`. - /// - /// The maximum allowed lambda is [MAX_LAMBDA](Self::MAX_LAMBDA). - pub fn new(lambda: F) -> Result, Error> { - if !lambda.is_finite() { - return Err(Error::NonFinite); - } - if !(lambda > F::zero()) { - return Err(Error::ShapeTooSmall); - } - - // Use the Knuth method only for low expected values - let method = if lambda < F::from(12.0).unwrap() { - Method::Knuth(KnuthMethod::new(lambda)) - } else { - if lambda > F::from(Self::MAX_LAMBDA).unwrap() { - return Err(Error::ShapeTooLarge); - } - Method::Rejection(RejectionMethod::new(lambda)) - }; - - Ok(Poisson(method)) - } - - /// The maximum supported value of `lambda` - /// - /// This value was selected such that - /// `MAX_LAMBDA + 1e6 * sqrt(MAX_LAMBDA) < 2^64 - 1`, - /// thus ensuring that the probability of sampling a value larger than - /// `u64::MAX` is less than 1e-1000. - /// - /// Applying this limit also solves - /// [#1312](https://github.com/rust-random/rand/issues/1312). - pub const MAX_LAMBDA: f64 = 1.844e19; -} - -impl Distribution for KnuthMethod -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let mut result = F::one(); - let mut p = rng.random::(); - while p > self.exp_lambda { - p = p * rng.random::(); - result = result + F::one(); - } - result - F::one() - } -} - -impl Distribution for RejectionMethod -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - // The algorithm from Numerical Recipes in C - - // we use the Cauchy distribution as the comparison distribution - // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(F::zero(), F::one()).unwrap(); - let mut result; - - loop { - let mut comp_dev; - - loop { - // draw from the Cauchy distribution - comp_dev = rng.sample(cauchy); - // shift the peak of the comparison distribution - result = self.sqrt_2lambda * comp_dev + self.lambda; - // repeat the drawing until we are in the range of possible values - if result >= F::zero() { - break; - } - } - // now the result is a random variable greater than 0 with Cauchy distribution - // the result should be an integer value - result = result.floor(); - - // this is the ratio of the Poisson distribution to the comparison distribution - // the magic value scales the distribution function to a range of approximately 0-1 - // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 - // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = F::from(0.9).unwrap() - * (F::one() + comp_dev * comp_dev) - * (result * self.log_lambda - - crate::utils::log_gamma(F::one() + result) - - self.magic_val) - .exp(); - - // check with uniform random value - if below the threshold, we are within the target distribution - if rng.random::() <= check { - break; - } - } - result - } -} - -impl Distribution for Poisson -where - F: Float + FloatConst, - StandardUniform: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - match &self.0 { - Method::Knuth(method) => method.sample(rng), - Method::Rejection(method) => method.sample(rng), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn test_poisson_avg_gen(lambda: F, tol: F) - where - StandardUniform: Distribution, - { - let poisson = Poisson::new(lambda).unwrap(); - let mut rng = crate::test::rng(123); - let mut sum = F::zero(); - for _ in 0..1000 { - sum = sum + poisson.sample(&mut rng); - } - let avg = sum / F::from(1000.0).unwrap(); - assert!((avg - lambda).abs() < tol); - } - - #[test] - fn test_poisson_avg() { - test_poisson_avg_gen::(10.0, 0.1); - test_poisson_avg_gen::(15.0, 0.1); - - test_poisson_avg_gen::(10.0, 0.1); - test_poisson_avg_gen::(15.0, 0.1); - - // Small lambda will use Knuth's method with exp_lambda == 1.0 - test_poisson_avg_gen::(0.00000000000000005, 0.1); - test_poisson_avg_gen::(0.00000000000000005, 0.1); - } - - #[test] - #[should_panic] - fn test_poisson_invalid_lambda_zero() { - Poisson::new(0.0).unwrap(); - } - - #[test] - #[should_panic] - fn test_poisson_invalid_lambda_infinity() { - Poisson::new(f64::INFINITY).unwrap(); - } - - #[test] - #[should_panic] - fn test_poisson_invalid_lambda_neg() { - Poisson::new(-10.0).unwrap(); - } - - #[test] - fn poisson_distributions_can_be_compared() { - assert_eq!(Poisson::new(1.0), Poisson::new(1.0)); - } -} diff --git a/rand_distr/src/skew_normal.rs b/rand_distr/src/skew_normal.rs deleted file mode 100644 index 1be2311a6b5..00000000000 --- a/rand_distr/src/skew_normal.rs +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Skew Normal distribution `SN(ξ, ω, α)`. - -use crate::{Distribution, StandardNormal}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [skew normal distribution](https://en.wikipedia.org/wiki/Skew_normal_distribution) `SN(ξ, ω, α)`. -/// -/// The skew normal distribution is a generalization of the -/// [`Normal`](crate::Normal) distribution to allow for non-zero skewness. -/// It has location parameter `ξ` (`xi`), scale parameter `ω` (`omega`), -/// and shape parameter `α` (`alpha`). -/// -/// The `ξ` and `ω` parameters correspond to the mean `μ` and standard -/// deviation `σ` of the normal distribution, respectively. -/// The `α` parameter controls the skewness. -/// -/// # Density function -/// -/// It has the density function, for `scale > 0`, -/// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)` -/// where `phi` and `Phi` are the density and distribution of a standard normal variable. -/// -/// # Plot -/// -/// The following plot shows the skew normal distribution with `location = 0`, `scale = 1` -/// (corresponding to the [`standard normal distribution`](crate::StandardNormal)), and -/// various values of `shape`. -/// -/// ![Skew normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/skew_normal.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{SkewNormal, Distribution}; -/// -/// // location 2, scale 3, shape 1 -/// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap(); -/// let v = skew_normal.sample(&mut rand::rng()); -/// println!("{} is from a SN(2, 3, 1) distribution", v) -/// ``` -/// -/// # Implementation details -/// -/// We are using the algorithm from [A Method to Simulate the Skew Normal Distribution]. -/// -/// [skew normal distribution]: https://en.wikipedia.org/wiki/Skew_normal_distribution -/// [`Normal`]: struct.Normal.html -/// [A Method to Simulate the Skew Normal Distribution]: https://dx.doi.org/10.4236/am.2014.513201 -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct SkewNormal -where - F: Float, - StandardNormal: Distribution, -{ - location: F, - scale: F, - shape: F, -} - -/// Error type returned from [`SkewNormal::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// The scale parameter is not finite or it is less or equal to zero. - ScaleTooSmall, - /// The shape parameter is not finite. - BadShape, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => { - "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution" - } - Error::BadShape => "shape parameter is non-finite in skew normal distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl SkewNormal -where - F: Float, - StandardNormal: Distribution, -{ - /// Construct, from location, scale and shape. - /// - /// Parameters: - /// - /// - location (unrestricted) - /// - scale (must be finite and larger than zero) - /// - shape (must be finite) - #[inline] - pub fn new(location: F, scale: F, shape: F) -> Result, Error> { - if !scale.is_finite() || !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - if !shape.is_finite() { - return Err(Error::BadShape); - } - Ok(SkewNormal { - location, - scale, - shape, - }) - } - - /// Returns the location of the distribution. - pub fn location(&self) -> F { - self.location - } - - /// Returns the scale of the distribution. - pub fn scale(&self) -> F { - self.scale - } - - /// Returns the shape of the distribution. - pub fn shape(&self) -> F { - self.shape - } -} - -impl Distribution for SkewNormal -where - F: Float, - StandardNormal: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let linear_map = |x: F| -> F { x * self.scale + self.location }; - let u_1: F = rng.sample(StandardNormal); - if self.shape == F::zero() { - linear_map(u_1) - } else { - let u_2 = rng.sample(StandardNormal); - let (u, v) = (u_1.max(u_2), u_1.min(u_2)); - if self.shape == -F::one() { - linear_map(v) - } else if self.shape == F::one() { - linear_map(u) - } else { - let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v) - / ((F::one() + self.shape * self.shape).sqrt() - * F::from(core::f64::consts::SQRT_2).unwrap()); - linear_map(normalized) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn test_samples>(distr: D, zero: F, expected: &[F]) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - #[test] - #[should_panic] - fn invalid_scale_nan() { - SkewNormal::new(0.0, f64::NAN, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_scale_zero() { - SkewNormal::new(0.0, 0.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_scale_negative() { - SkewNormal::new(0.0, -1.0, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_scale_infinite() { - SkewNormal::new(0.0, f64::INFINITY, 0.0).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_shape_nan() { - SkewNormal::new(0.0, 1.0, f64::NAN).unwrap(); - } - - #[test] - #[should_panic] - fn invalid_shape_infinite() { - SkewNormal::new(0.0, 1.0, f64::INFINITY).unwrap(); - } - - #[test] - fn valid_location_nan() { - SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap(); - } - - #[test] - fn skew_normal_value_stability() { - test_samples( - SkewNormal::new(0.0, 1.0, 0.0).unwrap(), - 0f32, - &[-0.11844189, 0.781378, 0.06563994, -1.1932899], - ); - test_samples( - SkewNormal::new(0.0, 1.0, 0.0).unwrap(), - 0f64, - &[ - -0.11844188827977231, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ], - ); - test_samples( - SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(), - 0f64, - &[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY], - ); - test_samples( - SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(), - 0f64, - &[ - f64::NEG_INFINITY, - f64::NEG_INFINITY, - f64::NEG_INFINITY, - f64::NEG_INFINITY, - ], - ); - } - - #[test] - fn skew_normal_value_location_nan() { - let skew_normal = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap(); - let mut rng = crate::test::rng(213); - let mut buf = [0.0; 4]; - for x in &mut buf { - *x = rng.sample(skew_normal); - } - for value in buf.iter() { - assert!(value.is_nan()); - } - } - - #[test] - fn skew_normal_distributions_can_be_compared() { - assert_eq!( - SkewNormal::new(1.0, 2.0, 3.0), - SkewNormal::new(1.0, 2.0, 3.0) - ); - } -} diff --git a/rand_distr/src/student_t.rs b/rand_distr/src/student_t.rs deleted file mode 100644 index b0d7d078ae2..00000000000 --- a/rand_distr/src/student_t.rs +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Student's t-distribution. - -use crate::{ChiSquared, ChiSquaredError}; -use crate::{Distribution, Exp1, Open01, StandardNormal}; -use num_traits::Float; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`. -/// -/// The t-distribution is a continuous probability distribution -/// parameterized by degrees of freedom `ν` (`nu`), which -/// arises when estimating the mean of a normally-distributed -/// population in situations where the sample size is small and -/// the population's standard deviation is unknown. -/// It is widely used in hypothesis testing. -/// -/// For `ν = 1`, this is equivalent to the standard -/// [`Cauchy`](crate::Cauchy) distribution, -/// and as `ν` diverges to infinity, `t(ν)` converges to -/// [`StandardNormal`](crate::StandardNormal). -/// -/// # Plot -/// -/// The plot shows the t-distribution with various degrees of freedom. -/// -/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{StudentT, Distribution}; -/// -/// let t = StudentT::new(11.0).unwrap(); -/// let v = t.sample(&mut rand::rng()); -/// println!("{} is from a t(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - chi: ChiSquared, - dof: F, -} - -impl StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new Student t-distribution with `ν` (nu) - /// degrees of freedom. - pub fn new(nu: F) -> Result, ChiSquaredError> { - Ok(StudentT { - chi: ChiSquared::new(nu)?, - dof: nu, - }) - } -} -impl Distribution for StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let norm: F = rng.sample(StandardNormal); - norm * (self.dof / self.chi.sample(rng)).sqrt() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_t() { - let t = StudentT::new(11.0).unwrap(); - let mut rng = crate::test::rng(205); - for _ in 0..1000 { - t.sample(&mut rng); - } - } - - #[test] - fn student_t_distributions_can_be_compared() { - assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); - } -} diff --git a/rand_distr/src/triangular.rs b/rand_distr/src/triangular.rs deleted file mode 100644 index 05a46e57ecf..00000000000 --- a/rand_distr/src/triangular.rs +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. -//! The triangular distribution. - -use crate::{Distribution, StandardUniform}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [triangular distribution](https://en.wikipedia.org/wiki/Triangular_distribution) `Triangular(min, max, mode)`. -/// -/// A continuous probability distribution parameterised by a range, and a mode -/// (most likely value) within that range. -/// -/// The probability density function is triangular. For a similar distribution -/// with a smooth PDF, see the [`Pert`] distribution. -/// -/// # Plot -/// -/// The following plot shows the triangular distribution with various values of -/// `min`, `max`, and `mode`. -/// -/// ![Triangular distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/triangular.svg) -/// -/// # Example -/// -/// ```rust -/// use rand_distr::{Triangular, Distribution}; -/// -/// let d = Triangular::new(0., 5., 2.5).unwrap(); -/// let v = d.sample(&mut rand::rng()); -/// println!("{} is from a triangular distribution", v); -/// ``` -/// -/// [`Pert`]: crate::Pert -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Triangular -where - F: Float, - StandardUniform: Distribution, -{ - min: F, - max: F, - mode: F, -} - -/// Error type returned from [`Triangular::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum TriangularError { - /// `max < min` or `min` or `max` is NaN. - RangeTooSmall, - /// `mode < min` or `mode > max` or `mode` is NaN. - ModeRange, -} - -impl fmt::Display for TriangularError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - TriangularError::RangeTooSmall => { - "requirement min <= max is not met in triangular distribution" - } - TriangularError::ModeRange => "mode is outside [min, max] in triangular distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for TriangularError {} - -impl Triangular -where - F: Float, - StandardUniform: Distribution, -{ - /// Set up the Triangular distribution with defined `min`, `max` and `mode`. - #[inline] - pub fn new(min: F, max: F, mode: F) -> Result, TriangularError> { - if !(max >= min) { - return Err(TriangularError::RangeTooSmall); - } - if !(mode >= min && max >= mode) { - return Err(TriangularError::ModeRange); - } - Ok(Triangular { min, max, mode }) - } -} - -impl Distribution for Triangular -where - F: Float, - StandardUniform: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - let f: F = rng.sample(StandardUniform); - let diff_mode_min = self.mode - self.min; - let range = self.max - self.min; - let f_range = f * range; - if f_range < diff_mode_min { - self.min + (f_range * diff_mode_min).sqrt() - } else { - self.max - ((range - f_range) * (self.max - self.mode)).sqrt() - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use rand::{rngs::mock, Rng}; - - #[test] - fn test_triangular() { - let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0); - assert_eq!(half_rng.random::(), 0.5); - for &(min, max, mode, median) in &[ - (-1., 1., 0., 0.), - (1., 2., 1., 2. - 0.5f64.sqrt()), - (5., 25., 25., 5. + 200f64.sqrt()), - (1e-5, 1e5, 1e-3, 1e5 - 4999999949.5f64.sqrt()), - (0., 1., 0.9, 0.45f64.sqrt()), - (-4., -0.5, -2., -4.0 + 3.5f64.sqrt()), - ] { - #[cfg(feature = "std")] - std::println!("{} {} {} {}", min, max, mode, median); - let distr = Triangular::new(min, max, mode).unwrap(); - // Test correct value at median: - assert_eq!(distr.sample(&mut half_rng), median); - } - - for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { - assert!(Triangular::new(min, max, mode).is_err()); - } - } - - #[test] - fn triangular_distributions_can_be_compared() { - assert_eq!( - Triangular::new(1.0, 3.0, 2.0), - Triangular::new(1.0, 3.0, 2.0) - ); - } -} diff --git a/rand_distr/src/unit_ball.rs b/rand_distr/src/unit_ball.rs deleted file mode 100644 index 514fc30812a..00000000000 --- a/rand_distr/src/unit_ball.rs +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use num_traits::Float; -use rand::Rng; - -/// Samples uniformly from the volume of the unit ball in three dimensions. -/// -/// Implemented via rejection sampling. -/// -/// For a distribution that samples only from the surface of the unit ball, -/// see [`UnitSphere`](crate::UnitSphere). -/// -/// For a similar distribution in two dimensions, see [`UnitDisc`](crate::UnitDisc). -/// -/// # Plot -/// -/// The following plot shows the unit ball in three dimensions. -/// This distribution samples individual points from the entire volume -/// of the ball. -/// -/// ![Unit ball](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_ball.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitBall, Distribution}; -/// -/// let v: [f64; 3] = UnitBall.sample(&mut rand::rng()); -/// println!("{:?} is from the unit ball.", v) -/// ``` -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitBall; - -impl Distribution<[F; 3]> for UnitBall { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - let mut x1; - let mut x2; - let mut x3; - loop { - x1 = uniform.sample(rng); - x2 = uniform.sample(rng); - x3 = uniform.sample(rng); - if x1 * x1 + x2 * x2 + x3 * x3 <= F::from(1.).unwrap() { - break; - } - } - [x1, x2, x3] - } -} diff --git a/rand_distr/src/unit_circle.rs b/rand_distr/src/unit_circle.rs deleted file mode 100644 index d25d829f5a5..00000000000 --- a/rand_distr/src/unit_circle.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use num_traits::Float; -use rand::Rng; - -/// Samples uniformly from the circumference of the unit circle in two dimensions. -/// -/// Implemented via a method by von Neumann[^1]. -/// -/// For a distribution that also samples from the interior of the unit circle, -/// see [`UnitDisc`](crate::UnitDisc). -/// -/// For a similar distribution in three dimensions, see [`UnitSphere`](crate::UnitSphere). -/// -/// # Plot -/// -/// The following plot shows the unit circle. -/// -/// ![Unit circle](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_circle.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitCircle, Distribution}; -/// -/// let v: [f64; 2] = UnitCircle.sample(&mut rand::rng()); -/// println!("{:?} is from the unit circle.", v) -/// ``` -/// -/// [^1]: von Neumann, J. (1951) [*Various Techniques Used in Connection with -/// Random Digits.*](https://mcnp.lanl.gov/pdf_files/nbs_vonneumann.pdf) -/// NBS Appl. Math. Ser., No. 12. Washington, DC: U.S. Government Printing -/// Office, pp. 36-38. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitCircle; - -impl Distribution<[F; 2]> for UnitCircle { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - let mut x1; - let mut x2; - let mut sum; - loop { - x1 = uniform.sample(rng); - x2 = uniform.sample(rng); - sum = x1 * x1 + x2 * x2; - if sum < F::from(1.).unwrap() { - break; - } - } - let diff = x1 * x1 - x2 * x2; - [diff / sum, F::from(2.).unwrap() * x1 * x2 / sum] - } -} - -#[cfg(test)] -mod tests { - use super::UnitCircle; - use crate::Distribution; - - #[test] - fn norm() { - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let x: [f64; 2] = UnitCircle.sample(&mut rng); - assert_almost_eq!(x[0] * x[0] + x[1] * x[1], 1., 1e-15); - } - } -} diff --git a/rand_distr/src/unit_disc.rs b/rand_distr/src/unit_disc.rs deleted file mode 100644 index c95fd1d6c83..00000000000 --- a/rand_distr/src/unit_disc.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use num_traits::Float; -use rand::Rng; - -/// Samples uniformly from the unit disc in two dimensions. -/// -/// Implemented via rejection sampling. -/// -/// For a distribution that samples only from the circumference of the unit disc, -/// see [`UnitCircle`](crate::UnitCircle). -/// -/// For a similar distribution in three dimensions, see [`UnitBall`](crate::UnitBall). -/// -/// # Plot -/// -/// The following plot shows the unit disc. -/// This distribution samples individual points from the entire area of the disc. -/// -/// ![Unit disc](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_disc.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitDisc, Distribution}; -/// -/// let v: [f64; 2] = UnitDisc.sample(&mut rand::rng()); -/// println!("{:?} is from the unit Disc.", v) -/// ``` -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitDisc; - -impl Distribution<[F; 2]> for UnitDisc { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 2] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - let mut x1; - let mut x2; - loop { - x1 = uniform.sample(rng); - x2 = uniform.sample(rng); - if x1 * x1 + x2 * x2 <= F::from(1.).unwrap() { - break; - } - } - [x1, x2] - } -} diff --git a/rand_distr/src/unit_sphere.rs b/rand_distr/src/unit_sphere.rs deleted file mode 100644 index 1d531924efb..00000000000 --- a/rand_distr/src/unit_sphere.rs +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2018-2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use num_traits::Float; -use rand::Rng; - -/// Samples uniformly from the surface of the unit sphere in three dimensions. -/// -/// Implemented via a method by Marsaglia[^1]. -/// -/// For a distribution that also samples from the interior of the sphere, -/// see [`UnitBall`](crate::UnitBall). -/// -/// For a similar distribution in two dimensions, see [`UnitCircle`](crate::UnitCircle). -/// -/// # Plot -/// -/// The following plot shows the unit sphere as a wireframe. -/// The wireframe is meant to illustrate that this distribution samples -/// from the surface of the sphere only, not from the interior. -/// -/// ![Unit sphere](https://raw.githubusercontent.com/rust-random/charts/main/charts/unit_sphere.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{UnitSphere, Distribution}; -/// -/// let v: [f64; 3] = UnitSphere.sample(&mut rand::rng()); -/// println!("{:?} is from the unit sphere surface.", v) -/// ``` -/// -/// [^1]: Marsaglia, George (1972). [*Choosing a Point from the Surface of a -/// Sphere.*](https://doi.org/10.1214/aoms/1177692644) -/// Ann. Math. Statist. 43, no. 2, 645--646. -#[derive(Clone, Copy, Debug)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct UnitSphere; - -impl Distribution<[F; 3]> for UnitSphere { - #[inline] - fn sample(&self, rng: &mut R) -> [F; 3] { - let uniform = Uniform::new(F::from(-1.).unwrap(), F::from(1.).unwrap()).unwrap(); - loop { - let (x1, x2) = (uniform.sample(rng), uniform.sample(rng)); - let sum = x1 * x1 + x2 * x2; - if sum >= F::from(1.).unwrap() { - continue; - } - let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); - return [ - x1 * factor, - x2 * factor, - F::from(1.).unwrap() - F::from(2.).unwrap() * sum, - ]; - } - } -} - -#[cfg(test)] -mod tests { - use super::UnitSphere; - use crate::Distribution; - - #[test] - fn norm() { - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let x: [f64; 3] = UnitSphere.sample(&mut rng); - assert_almost_eq!(x[0] * x[0] + x[1] * x[1] + x[2] * x[2], 1., 1e-15); - } - } -} diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs deleted file mode 100644 index f0cf2a1005a..00000000000 --- a/rand_distr/src/utils.rs +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Math helper functions - -use crate::ziggurat_tables; -use num_traits::Float; -use rand::distr::hidden_export::IntoFloat; -use rand::Rng; - -/// Calculates ln(gamma(x)) (natural logarithm of the gamma -/// function) using the Lanczos approximation. -/// -/// The approximation expresses the gamma function as: -/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)` -/// `g` is an arbitrary constant; we use the approximation with `g=5`. -/// -/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides: -/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)` -/// -/// `Ag(z)` is an infinite series with coefficients that can be calculated -/// ahead of time - we use just the first 6 terms, which is good enough -/// for most purposes. -pub(crate) fn log_gamma(x: F) -> F { - // precalculated 6 coefficients for the first 6 terms of the series - let coefficients: [F; 6] = [ - F::from(76.18009172947146).unwrap(), - F::from(-86.50532032941677).unwrap(), - F::from(24.01409824083091).unwrap(), - F::from(-1.231739572450155).unwrap(), - F::from(0.1208650973866179e-2).unwrap(), - F::from(-0.5395239384953e-5).unwrap(), - ]; - - // (x+0.5)*ln(x+g+0.5)-(x+g+0.5) - let tmp = x + F::from(5.5).unwrap(); - let log = (x + F::from(0.5).unwrap()) * tmp.ln() - tmp; - - // the first few terms of the series for Ag(x) - let mut a = F::from(1.000000000190015).unwrap(); - let mut denom = x; - for &coeff in &coefficients { - denom = denom + F::one(); - a = a + (coeff / denom); - } - - // get everything together - // a is Ag(x) - // 2.5066... is sqrt(2pi) - log + (F::from(2.5066282746310005).unwrap() * a / x).ln() -} - -/// Sample a random number using the Ziggurat method (specifically the -/// ZIGNOR variant from Doornik 2005). Most of the arguments are -/// directly from the paper: -/// -/// * `rng`: source of randomness -/// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0. -/// * `X`: the $x_i$ abscissae. -/// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$) -/// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$ -/// * `pdf`: the probability density function -/// * `zero_case`: manual sampling from the tail when we chose the -/// bottom box (i.e. i == 0) -#[inline(always)] // Forced inlining improves the perf by 25-50% -pub(crate) fn ziggurat( - rng: &mut R, - symmetric: bool, - x_tab: ziggurat_tables::ZigTable, - f_tab: ziggurat_tables::ZigTable, - mut pdf: P, - mut zero_case: Z, -) -> f64 -where - P: FnMut(f64) -> f64, - Z: FnMut(&mut R, f64) -> f64, -{ - loop { - // As an optimisation we re-implement the conversion to a f64. - // From the remaining 12 most significant bits we use 8 to construct `i`. - // This saves us generating a whole extra random number, while the added - // precision of using 64 bits for f64 does not buy us much. - let bits = rng.next_u64(); - let i = bits as usize & 0xff; - - let u = if symmetric { - // Convert to a value in the range [2,4) and subtract to get [-1,1) - // We can't convert to an open range directly, that would require - // subtracting `3.0 - EPSILON`, which is not representable. - // It is possible with an extra step, but an open range does not - // seem necessary for the ziggurat algorithm anyway. - (bits >> 12).into_float_with_exponent(1) - 3.0 - } else { - // Convert to a value in the range [1,2) and subtract to get (0,1) - (bits >> 12).into_float_with_exponent(0) - (1.0 - f64::EPSILON / 2.0) - }; - let x = u * x_tab[i]; - - let test_x = if symmetric { x.abs() } else { x }; - - // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i]) - if test_x < x_tab[i + 1] { - return x; - } - if i == 0 { - return zero_case(rng, u); - } - // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 - if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.random::() < pdf(x) { - return x; - } - } -} diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs deleted file mode 100644 index 1a9faf46c22..00000000000 --- a/rand_distr/src/weibull.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Weibull distribution `Weibull(λ, k)` - -use crate::{Distribution, OpenClosed01}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution) `Weibull(λ, k)`. -/// -/// This is a family of continuous probability distributions with -/// scale parameter `λ` (`lambda`) and shape parameter `k`. It is used -/// to model reliability data, life data, and accelerated life testing data. -/// -/// # Density function -/// -/// `f(x; λ, k) = (k / λ) * (x / λ)^(k - 1) * exp(-(x / λ)^k)` for `x >= 0`. -/// -/// # Plot -/// -/// The following plot shows the Weibull distribution with various values of `λ` and `k`. -/// -/// ![Weibull distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/weibull.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Weibull; -/// -/// let val: f64 = rand::rng().sample(Weibull::new(1., 10.).unwrap()); -/// println!("{}", val); -/// ``` -/// -/// # Numerics -/// -/// For small `k` like `< 0.005`, even with `f64` a significant number of samples will be so small that they underflow to `0.0` -/// or so big they overflow to `inf`. This is a limitation of the floating point representation and not specific to this implementation. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Weibull -where - F: Float, - OpenClosed01: Distribution, -{ - inv_shape: F, - scale: F, -} - -/// Error type returned from [`Weibull::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `scale <= 0` or `nan`. - ScaleTooSmall, - /// `shape <= 0` or `nan`. - ShapeTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::ScaleTooSmall => "scale is not positive in Weibull distribution", - Error::ShapeTooSmall => "shape is not positive in Weibull distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Weibull -where - F: Float, - OpenClosed01: Distribution, -{ - /// Construct a new `Weibull` distribution with given `scale` and `shape`. - pub fn new(scale: F, shape: F) -> Result, Error> { - if !(scale > F::zero()) { - return Err(Error::ScaleTooSmall); - } - if !(shape > F::zero()) { - return Err(Error::ShapeTooSmall); - } - Ok(Weibull { - inv_shape: F::from(1.).unwrap() / shape, - scale, - }) - } -} - -impl Distribution for Weibull -where - F: Float, - OpenClosed01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let x: F = rng.sample(OpenClosed01); - self.scale * (-x.ln()).powf(self.inv_shape) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic] - fn invalid() { - Weibull::new(0., 0.).unwrap(); - } - - #[test] - fn sample() { - let scale = 1.0; - let shape = 2.0; - let d = Weibull::new(scale, shape).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 0.); - } - } - - #[test] - fn value_stability() { - fn test_samples>( - distr: D, - zero: F, - expected: &[F], - ) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - test_samples( - Weibull::new(1.0, 1.0).unwrap(), - 0f32, - &[0.041495778, 0.7531094, 1.4189332, 0.38386202], - ); - test_samples( - Weibull::new(2.0, 0.5).unwrap(), - 0f64, - &[ - 1.1343478702739669, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ], - ); - } - - #[test] - fn weibull_distributions_can_be_compared() { - assert_eq!(Weibull::new(1.0, 2.0), Weibull::new(1.0, 2.0)); - } -} diff --git a/rand_distr/src/weighted/mod.rs b/rand_distr/src/weighted/mod.rs deleted file mode 100644 index 1c54e48e69c..00000000000 --- a/rand_distr/src/weighted/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! Weighted (index) sampling -//! -//! This module is a superset of [`rand::distr::weighted`]. -//! -//! Multiple implementations of weighted index sampling are provided: -//! -//! - [`WeightedIndex`] (a re-export from [`rand`]) supports fast construction -//! and `O(log N)` sampling over `N` weights. -//! It also supports updating weights with `O(N)` time. -//! - [`WeightedAliasIndex`] supports `O(1)` sampling, but due to high -//! construction time many samples are required to outperform [`WeightedIndex`]. -//! - [`WeightedTreeIndex`] supports `O(log N)` sampling and -//! update/insertion/removal of weights with `O(log N)` time. - -mod weighted_alias; -mod weighted_tree; - -pub use rand::distr::weighted::*; -pub use weighted_alias::*; -pub use weighted_tree::*; diff --git a/rand_distr/src/weighted/weighted_alias.rs b/rand_distr/src/weighted/weighted_alias.rs deleted file mode 100644 index 862f2b70b33..00000000000 --- a/rand_distr/src/weighted/weighted_alias.rs +++ /dev/null @@ -1,539 +0,0 @@ -// Copyright 2019 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! This module contains an implementation of alias method for sampling random -//! indices with probabilities proportional to a collection of weights. - -use super::Error; -use crate::{uniform::SampleUniform, Distribution, Uniform}; -use alloc::{boxed::Box, vec, vec::Vec}; -use core::fmt; -use core::iter::Sum; -use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// A distribution using weighted sampling to pick a discretely selected item. -/// -/// Sampling a [`WeightedAliasIndex`] distribution returns the index of a randomly -/// selected element from the vector used to create the [`WeightedAliasIndex`]. -/// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which a implementation of -/// [`AliasableWeight`] exists. -/// -/// # Performance -/// -/// Given that `n` is the number of items in the vector used to create an -/// [`WeightedAliasIndex`], it will require `O(n)` amount of memory. -/// More specifically it takes up some constant amount of memory plus -/// the vector used to create it and a [`Vec`] with capacity `n`. -/// -/// Time complexity for the creation of a [`WeightedAliasIndex`] is `O(n)`. -/// Sampling is `O(1)`, it makes a call to [`Uniform::sample`] and a call -/// to [`Uniform::sample`]. -/// -/// # Example -/// -/// ``` -/// use rand_distr::weighted::WeightedAliasIndex; -/// use rand::prelude::*; -/// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 1, 1]; -/// let dist = WeightedAliasIndex::new(weights).unwrap(); -/// let mut rng = rand::rng(); -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0), ('b', 3), ('c', 7)]; -/// let dist2 = WeightedAliasIndex::new(items.iter().map(|item| item.1).collect()).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); -/// } -/// ``` -/// -/// [`WeightedAliasIndex`]: WeightedAliasIndex -/// [`Vec`]: Vec -/// [`Uniform::sample`]: Distribution::sample -/// [`Uniform::sample`]: Distribution::sample -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr( - feature = "serde", - serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) -)] -#[cfg_attr( - feature = "serde", - serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) -)] -pub struct WeightedAliasIndex { - aliases: Box<[u32]>, - no_alias_odds: Box<[W]>, - uniform_index: Uniform, - uniform_within_weight_sum: Uniform, -} - -impl WeightedAliasIndex { - /// Creates a new [`WeightedAliasIndex`]. - /// - /// Error cases: - /// - [`Error::InvalidInput`] when `weights.len()` is zero or greater than `u32::MAX`. - /// - [`Error::InvalidWeight`] when a weight is not-a-number, - /// negative or greater than `max = W::MAX / weights.len()`. - /// - [`Error::InsufficientNonZero`] when the sum of all weights is zero. - pub fn new(weights: Vec) -> Result { - let n = weights.len(); - if n == 0 || n > u32::MAX as usize { - return Err(Error::InvalidInput); - } - let n = n as u32; - - let max_weight_size = W::try_from_u32_lossy(n) - .map(|n| W::MAX / n) - .unwrap_or(W::ZERO); - if !weights - .iter() - .all(|&w| W::ZERO <= w && w <= max_weight_size) - { - return Err(Error::InvalidWeight); - } - - // The sum of weights will represent 100% of no alias odds. - let weight_sum = AliasableWeight::sum(weights.as_slice()); - // Prevent floating point overflow due to rounding errors. - let weight_sum = if weight_sum > W::MAX { - W::MAX - } else { - weight_sum - }; - if weight_sum == W::ZERO { - return Err(Error::InsufficientNonZero); - } - - // `weight_sum` would have been zero if `try_from_lossy` causes an error here. - let n_converted = W::try_from_u32_lossy(n).unwrap(); - - let mut no_alias_odds = weights.into_boxed_slice(); - for odds in no_alias_odds.iter_mut() { - *odds *= n_converted; - // Prevent floating point overflow due to rounding errors. - *odds = if *odds > W::MAX { W::MAX } else { *odds }; - } - - /// This struct is designed to contain three data structures at once, - /// sharing the same memory. More precisely it contains two linked lists - /// and an alias map, which will be the output of this method. To keep - /// the three data structures from getting in each other's way, it must - /// be ensured that a single index is only ever in one of them at the - /// same time. - struct Aliases { - aliases: Box<[u32]>, - smalls_head: u32, - bigs_head: u32, - } - - impl Aliases { - fn new(size: u32) -> Self { - Aliases { - aliases: vec![0; size as usize].into_boxed_slice(), - smalls_head: u32::MAX, - bigs_head: u32::MAX, - } - } - - fn push_small(&mut self, idx: u32) { - self.aliases[idx as usize] = self.smalls_head; - self.smalls_head = idx; - } - - fn push_big(&mut self, idx: u32) { - self.aliases[idx as usize] = self.bigs_head; - self.bigs_head = idx; - } - - fn pop_small(&mut self) -> u32 { - let popped = self.smalls_head; - self.smalls_head = self.aliases[popped as usize]; - popped - } - - fn pop_big(&mut self) -> u32 { - let popped = self.bigs_head; - self.bigs_head = self.aliases[popped as usize]; - popped - } - - fn smalls_is_empty(&self) -> bool { - self.smalls_head == u32::MAX - } - - fn bigs_is_empty(&self) -> bool { - self.bigs_head == u32::MAX - } - - fn set_alias(&mut self, idx: u32, alias: u32) { - self.aliases[idx as usize] = alias; - } - } - - let mut aliases = Aliases::new(n); - - // Split indices into those with small weights and those with big weights. - for (index, &odds) in no_alias_odds.iter().enumerate() { - if odds < weight_sum { - aliases.push_small(index as u32); - } else { - aliases.push_big(index as u32); - } - } - - // Build the alias map by finding an alias with big weight for each index with - // small weight. - while !aliases.smalls_is_empty() && !aliases.bigs_is_empty() { - let s = aliases.pop_small(); - let b = aliases.pop_big(); - - aliases.set_alias(s, b); - no_alias_odds[b as usize] = - no_alias_odds[b as usize] - weight_sum + no_alias_odds[s as usize]; - - if no_alias_odds[b as usize] < weight_sum { - aliases.push_small(b); - } else { - aliases.push_big(b); - } - } - - // The remaining indices should have no alias odds of about 100%. This is due to - // numeric accuracy. Otherwise they would be exactly 100%. - while !aliases.smalls_is_empty() { - no_alias_odds[aliases.pop_small() as usize] = weight_sum; - } - while !aliases.bigs_is_empty() { - no_alias_odds[aliases.pop_big() as usize] = weight_sum; - } - - // Prepare distributions for sampling. Creating them beforehand improves - // sampling performance. - let uniform_index = Uniform::new(0, n).unwrap(); - let uniform_within_weight_sum = Uniform::new(W::ZERO, weight_sum).unwrap(); - - Ok(Self { - aliases: aliases.aliases, - no_alias_odds, - uniform_index, - uniform_within_weight_sum, - }) - } -} - -impl Distribution for WeightedAliasIndex { - fn sample(&self, rng: &mut R) -> usize { - let candidate = rng.sample(self.uniform_index); - if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] { - candidate as usize - } else { - self.aliases[candidate as usize] as usize - } - } -} - -impl fmt::Debug for WeightedAliasIndex -where - W: fmt::Debug, - Uniform: fmt::Debug, -{ - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("WeightedAliasIndex") - .field("aliases", &self.aliases) - .field("no_alias_odds", &self.no_alias_odds) - .field("uniform_index", &self.uniform_index) - .field("uniform_within_weight_sum", &self.uniform_within_weight_sum) - .finish() - } -} - -impl Clone for WeightedAliasIndex -where - Uniform: Clone, -{ - fn clone(&self) -> Self { - Self { - aliases: self.aliases.clone(), - no_alias_odds: self.no_alias_odds.clone(), - uniform_index: self.uniform_index, - uniform_within_weight_sum: self.uniform_within_weight_sum.clone(), - } - } -} - -/// Weight bound for [`WeightedAliasIndex`] -/// -/// Currently no guarantees on the correctness of [`WeightedAliasIndex`] are -/// given for custom implementations of this trait. -pub trait AliasableWeight: - Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Mul - + MulAssign - + Div - + DivAssign - + Sum -{ - /// Maximum number representable by `Self`. - const MAX: Self; - - /// Element of `Self` equivalent to 0. - const ZERO: Self; - - /// Produce an instance of `Self` from a `u32` value, or return `None` if - /// out of range. Loss of precision (where `Self` is a floating point type) - /// is acceptable. - fn try_from_u32_lossy(n: u32) -> Option; - - /// Sums all values in slice `values`. - fn sum(values: &[Self]) -> Self { - values.iter().copied().sum() - } -} - -macro_rules! impl_weight_for_float { - ($T: ident) => { - impl AliasableWeight for $T { - const MAX: Self = $T::MAX; - const ZERO: Self = 0.0; - - fn try_from_u32_lossy(n: u32) -> Option { - Some(n as $T) - } - - fn sum(values: &[Self]) -> Self { - pairwise_sum(values) - } - } - }; -} - -/// In comparison to naive accumulation, the pairwise sum algorithm reduces -/// rounding errors when there are many floating point values. -fn pairwise_sum(values: &[T]) -> T { - if values.len() <= 32 { - values.iter().copied().sum() - } else { - let mid = values.len() / 2; - let (a, b) = values.split_at(mid); - pairwise_sum(a) + pairwise_sum(b) - } -} - -macro_rules! impl_weight_for_int { - ($T: ident) => { - impl AliasableWeight for $T { - const MAX: Self = $T::MAX; - const ZERO: Self = 0; - - fn try_from_u32_lossy(n: u32) -> Option { - let n_converted = n as Self; - if n_converted >= Self::ZERO && n_converted as u32 == n { - Some(n_converted) - } else { - None - } - } - } - }; -} - -impl_weight_for_float!(f64); -impl_weight_for_float!(f32); -impl_weight_for_int!(usize); -impl_weight_for_int!(u128); -impl_weight_for_int!(u64); -impl_weight_for_int!(u32); -impl_weight_for_int!(u16); -impl_weight_for_int!(u8); -impl_weight_for_int!(i128); -impl_weight_for_int!(i64); -impl_weight_for_int!(i32); -impl_weight_for_int!(i16); -impl_weight_for_int!(i8); - -#[cfg(test)] -mod test { - use super::*; - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_f32() { - test_weighted_index(f32::into); - - // Floating point special cases - assert_eq!( - WeightedAliasIndex::new(vec![f32::INFINITY]).unwrap_err(), - Error::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![-0_f32]).unwrap_err(), - Error::InsufficientNonZero - ); - assert_eq!( - WeightedAliasIndex::new(vec![-1_f32]).unwrap_err(), - Error::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![f32::NEG_INFINITY]).unwrap_err(), - Error::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![f32::NAN]).unwrap_err(), - Error::InvalidWeight - ); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_u128() { - test_weighted_index(|x: u128| x as f64); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_i128() { - test_weighted_index(|x: i128| x as f64); - - // Signed integer special cases - assert_eq!( - WeightedAliasIndex::new(vec![-1_i128]).unwrap_err(), - Error::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![i128::MIN]).unwrap_err(), - Error::InvalidWeight - ); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_u8() { - test_weighted_index(u8::into); - } - - #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weighted_index_i8() { - test_weighted_index(i8::into); - - // Signed integer special cases - assert_eq!( - WeightedAliasIndex::new(vec![-1_i8]).unwrap_err(), - Error::InvalidWeight - ); - assert_eq!( - WeightedAliasIndex::new(vec![i8::MIN]).unwrap_err(), - Error::InvalidWeight - ); - } - - fn test_weighted_index f64>(w_to_f64: F) - where - WeightedAliasIndex: fmt::Debug, - { - const NUM_WEIGHTS: u32 = 10; - const ZERO_WEIGHT_INDEX: u32 = 3; - const NUM_SAMPLES: u32 = 15000; - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - - let weights = { - let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize); - let random_weight_distribution = Uniform::new_inclusive( - W::ZERO, - W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), - ) - .unwrap(); - for _ in 0..NUM_WEIGHTS { - weights.push(rng.sample(&random_weight_distribution)); - } - weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO; - weights - }; - let weight_sum = weights.iter().copied().sum::(); - let expected_counts = weights - .iter() - .map(|&w| w_to_f64(w) / w_to_f64(weight_sum) * NUM_SAMPLES as f64) - .collect::>(); - let weight_distribution = WeightedAliasIndex::new(weights).unwrap(); - - let mut counts = vec![0; NUM_WEIGHTS as usize]; - for _ in 0..NUM_SAMPLES { - counts[rng.sample(&weight_distribution)] += 1; - } - - assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0); - for (count, expected_count) in counts.into_iter().zip(expected_counts) { - let difference = (count as f64 - expected_count).abs(); - let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1; - assert!(difference <= max_allowed_difference); - } - - assert_eq!( - WeightedAliasIndex::::new(vec![]).unwrap_err(), - Error::InvalidInput - ); - assert_eq!( - WeightedAliasIndex::new(vec![W::ZERO]).unwrap_err(), - Error::InsufficientNonZero - ); - assert_eq!( - WeightedAliasIndex::new(vec![W::MAX, W::MAX]).unwrap_err(), - Error::InvalidWeight - ); - } - - #[test] - fn value_stability() { - fn test_samples( - weights: Vec, - buf: &mut [usize], - expected: &[usize], - ) { - assert_eq!(buf.len(), expected.len()); - let distr = WeightedAliasIndex::new(weights).unwrap(); - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - for r in buf.iter_mut() { - *r = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - let mut buf = [0; 10]; - test_samples( - vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], - &mut buf, - &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7], - ); - test_samples( - vec![0.7f32, 0.1, 0.1, 0.1], - &mut buf, - &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3], - ); - test_samples( - vec![1.0f64, 0.999, 0.998, 0.997], - &mut buf, - &[2, 1, 2, 3, 2, 1, 3, 2, 1, 1], - ); - } -} diff --git a/rand_distr/src/weighted/weighted_tree.rs b/rand_distr/src/weighted/weighted_tree.rs deleted file mode 100644 index dd315aa5f8f..00000000000 --- a/rand_distr/src/weighted/weighted_tree.rs +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright 2024 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! This module contains an implementation of a tree structure for sampling random -//! indices with probabilities proportional to a collection of weights. - -use core::ops::SubAssign; - -use super::{Error, Weight}; -use crate::Distribution; -use alloc::vec::Vec; -use rand::distr::uniform::{SampleBorrow, SampleUniform}; -use rand::Rng; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -/// A distribution using weighted sampling to pick a discretely selected item. -/// -/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly -/// selected element from the vector used to create the [`WeightedTreeIndex`]. -/// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which an implementation of -/// [`Weight`] exists. -/// -/// # Key differences -/// -/// The main distinction between [`WeightedTreeIndex`] and [`WeightedIndex`] -/// lies in the internal representation of weights. In [`WeightedTreeIndex`], -/// weights are structured as a tree, which is optimized for frequent updates of the weights. -/// -/// # Caution: Floating point types -/// -/// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), -/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types -/// are susceptible to numerical rounding errors. Since operations on floating point weights are -/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable -/// deviations from the expected behavior. -/// -/// Ideally, use fixed point or integer types whenever possible. -/// -/// # Performance -/// -/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. -/// -/// Time complexity for the operations of a [`WeightedTreeIndex`] are: -/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time. -/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. -/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. -/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. -/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. -/// -/// # Example -/// -/// ``` -/// use rand_distr::weighted::WeightedTreeIndex; -/// use rand::prelude::*; -/// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 0]; -/// let mut dist = WeightedTreeIndex::new(&weights).unwrap(); -/// dist.push(1).unwrap(); -/// dist.update(1, 1).unwrap(); -/// let mut rng = rand::rng(); -/// let mut samples = [0; 3]; -/// for _ in 0..100 { -/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// let i = dist.sample(&mut rng); -/// samples[i] += 1; -/// } -/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); -/// ``` -/// -/// [`WeightedTreeIndex`]: WeightedTreeIndex -/// [`WeightedIndex`]: super::WeightedIndex -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr( - feature = "serde", - serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) -)] -#[cfg_attr( - feature = "serde", - serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) -)] -#[derive(Clone, Default, Debug, PartialEq)] -pub struct WeightedTreeIndex< - W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign + Weight, -> { - subtotals: Vec, -} - -impl + Weight> - WeightedTreeIndex -{ - /// Creates a new [`WeightedTreeIndex`] from a slice of weights. - /// - /// Error cases: - /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`Error::Overflow`] when the sum of all weights overflows. - pub fn new(weights: I) -> Result - where - I: IntoIterator, - I::Item: SampleBorrow, - { - let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); - for weight in subtotals.iter() { - if !(*weight >= W::ZERO) { - return Err(Error::InvalidWeight); - } - } - let n = subtotals.len(); - for i in (1..n).rev() { - let w = subtotals[i].clone(); - let parent = (i - 1) / 2; - subtotals[parent] - .checked_add_assign(&w) - .map_err(|()| Error::Overflow)?; - } - Ok(Self { subtotals }) - } - - /// Returns `true` if the tree contains no weights. - pub fn is_empty(&self) -> bool { - self.subtotals.is_empty() - } - - /// Returns the number of weights. - pub fn len(&self) -> usize { - self.subtotals.len() - } - - /// Returns `true` if we can sample. - /// - /// This is the case if the total weight of the tree is greater than zero. - pub fn is_valid(&self) -> bool { - if let Some(weight) = self.subtotals.first() { - *weight > W::ZERO - } else { - false - } - } - - /// Gets the weight at an index. - pub fn get(&self, index: usize) -> W { - let left_index = 2 * index + 1; - let right_index = 2 * index + 2; - let mut w = self.subtotals[index].clone(); - w -= self.subtotal(left_index); - w -= self.subtotal(right_index); - w - } - - /// Removes the last weight and returns it, or [`None`] if it is empty. - pub fn pop(&mut self) -> Option { - self.subtotals.pop().map(|weight| { - let mut index = self.len(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] -= weight.clone(); - } - weight - }) - } - - /// Appends a new weight at the end. - /// - /// Error cases: - /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`Error::Overflow`] when the sum of all weights overflows. - pub fn push(&mut self, weight: W) -> Result<(), Error> { - if !(weight >= W::ZERO) { - return Err(Error::InvalidWeight); - } - if let Some(total) = self.subtotals.first() { - let mut total = total.clone(); - if total.checked_add_assign(&weight).is_err() { - return Err(Error::Overflow); - } - } - let mut index = self.len(); - self.subtotals.push(weight.clone()); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index].checked_add_assign(&weight).unwrap(); - } - Ok(()) - } - - /// Updates the weight at an index. - /// - /// Error cases: - /// - [`Error::InvalidWeight`] when a weight is not-a-number or negative. - /// - [`Error::Overflow`] when the sum of all weights overflows. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> { - if !(weight >= W::ZERO) { - return Err(Error::InvalidWeight); - } - let old_weight = self.get(index); - if weight > old_weight { - let mut difference = weight; - difference -= old_weight; - if let Some(total) = self.subtotals.first() { - let mut total = total.clone(); - if total.checked_add_assign(&difference).is_err() { - return Err(Error::Overflow); - } - } - self.subtotals[index] - .checked_add_assign(&difference) - .unwrap(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] - .checked_add_assign(&difference) - .unwrap(); - } - } else if weight < old_weight { - let mut difference = old_weight; - difference -= weight; - self.subtotals[index] -= difference.clone(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] -= difference.clone(); - } - } - Ok(()) - } - - fn subtotal(&self, index: usize) -> W { - if index < self.subtotals.len() { - self.subtotals[index].clone() - } else { - W::ZERO - } - } -} - -impl + Weight> - WeightedTreeIndex -{ - /// Samples a randomly selected index from the weighted distribution. - /// - /// Returns an error if there are no elements or all weights are zero. This - /// is unlike [`Distribution::sample`], which panics in those cases. - pub fn try_sample(&self, rng: &mut R) -> Result { - let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO); - if total_weight == W::ZERO { - return Err(Error::InsufficientNonZero); - } - let mut target_weight = rng.random_range(W::ZERO..total_weight); - let mut index = 0; - loop { - // Maybe descend into the left sub tree. - let left_index = 2 * index + 1; - let left_subtotal = self.subtotal(left_index); - if target_weight < left_subtotal { - index = left_index; - continue; - } - target_weight -= left_subtotal; - - // Maybe descend into the right sub tree. - let right_index = 2 * index + 2; - let right_subtotal = self.subtotal(right_index); - if target_weight < right_subtotal { - index = right_index; - continue; - } - target_weight -= right_subtotal; - - // Otherwise we found the index with the target weight. - break; - } - assert!(target_weight >= W::ZERO); - assert!(target_weight < self.get(index)); - Ok(index) - } -} - -/// Samples a randomly selected index from the weighted distribution. -/// -/// Caution: This method panics if there are no elements or all weights are zero. However, -/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] -/// returns `true`. -impl + Weight> Distribution - for WeightedTreeIndex -{ - #[track_caller] - fn sample(&self, rng: &mut R) -> usize { - self.try_sample(rng).unwrap() - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_no_item_error() { - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - #[allow(clippy::needless_borrows_for_generic_args)] - let tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!( - tree.try_sample(&mut rng).unwrap_err(), - Error::InsufficientNonZero - ); - } - - #[test] - fn test_overflow_error() { - assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow)); - let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap(); - assert_eq!(tree.push(3), Err(Error::Overflow)); - assert_eq!(tree.update(1, 4), Err(Error::Overflow)); - tree.update(1, 2).unwrap(); - } - - #[test] - fn test_all_weights_zero_error() { - let tree = WeightedTreeIndex::::new([0.0, 0.0]).unwrap(); - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - assert_eq!( - tree.try_sample(&mut rng).unwrap_err(), - Error::InsufficientNonZero - ); - } - - #[test] - fn test_invalid_weight_error() { - assert_eq!( - WeightedTreeIndex::::new([1, -1]).unwrap_err(), - Error::InvalidWeight - ); - #[allow(clippy::needless_borrows_for_generic_args)] - let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight); - tree.push(1).unwrap(); - assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight); - } - - #[test] - fn test_tree_modifications() { - let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap(); - tree.push(3).unwrap(); - tree.push(5).unwrap(); - tree.update(0, 0).unwrap(); - assert_eq!(tree.pop(), Some(5)); - let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap(); - assert_eq!(tree, expected); - } - - #[test] - #[allow(clippy::needless_range_loop)] - fn test_sample_counts_match_probabilities() { - let start = 1; - let end = 3; - let samples = 20; - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - let weights: Vec = (0..end).map(|_| rng.random()).collect(); - let mut tree = WeightedTreeIndex::new(weights).unwrap(); - let mut total_weight = 0.0; - let mut weights = alloc::vec![0.0; end]; - for i in 0..end { - tree.update(i, i as f64).unwrap(); - weights[i] = i as f64; - total_weight += i as f64; - } - for i in 0..start { - tree.update(i, 0.0).unwrap(); - weights[i] = 0.0; - total_weight -= i as f64; - } - let mut counts = alloc::vec![0_usize; end]; - for _ in 0..samples { - let i = tree.sample(&mut rng); - counts[i] += 1; - } - for i in 0..start { - assert_eq!(counts[i], 0); - } - for i in start..end { - let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; - assert!(diff.abs() < 0.05); - } - } -} diff --git a/rand_distr/src/zeta.rs b/rand_distr/src/zeta.rs deleted file mode 100644 index f93f167d7c3..00000000000 --- a/rand_distr/src/zeta.rs +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Zeta distribution. - -use crate::{Distribution, StandardUniform}; -use core::fmt; -use num_traits::Float; -use rand::{distr::OpenClosed01, Rng}; - -/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`. -/// -/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) -/// is a discrete probability distribution with parameter `s`. -/// It is a special case of the [`Zipf`](crate::Zipf) distribution with `n = ∞`. -/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution. -/// -/// # Density function -/// -/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the -/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function). -/// -/// # Plot -/// -/// The following plot illustrates the zeta distribution for various values of `s`. -/// -/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Zeta; -/// -/// let val: f64 = rand::rng().sample(Zeta::new(1.5).unwrap()); -/// println!("{}", val); -/// ``` -/// -/// # Integer vs FP return type -/// -/// This implementation uses floating-point (FP) logic internally, which can -/// potentially generate very large samples (exceeding e.g. `u64::MAX`). -/// -/// It is *safe* to cast such results to an integer type using `as` -/// (e.g. `distr.sample(&mut rng) as u64`), since such casts are saturating -/// (e.g. `2f64.powi(64) as u64 == u64::MAX`). It is up to the user to -/// determine whether this potential loss of accuracy is acceptable -/// (this determination may depend on the distribution's parameters). -/// -/// # Notes -/// -/// The zeta distribution has no upper limit. Sampled values may be infinite. -/// In particular, a value of infinity might be returned for the following -/// reasons: -/// 1. it is the best representation in the type `F` of the actual sample. -/// 2. to prevent infinite loops for very small `s`. -/// -/// # Implementation details -/// -/// We are using the algorithm from -/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8), -/// Section 6.1, page 551. -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct Zeta -where - F: Float, - StandardUniform: Distribution, - OpenClosed01: Distribution, -{ - s_minus_1: F, - b: F, -} - -/// Error type returned from [`Zeta::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `s <= 1` or `nan`. - STooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::STooSmall => "s <= 1 or is NaN in Zeta distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Zeta -where - F: Float, - StandardUniform: Distribution, - OpenClosed01: Distribution, -{ - /// Construct a new `Zeta` distribution with given `s` parameter. - #[inline] - pub fn new(s: F) -> Result, Error> { - if !(s > F::one()) { - return Err(Error::STooSmall); - } - let s_minus_1 = s - F::one(); - let two = F::one() + F::one(); - Ok(Zeta { - s_minus_1, - b: two.powf(s_minus_1), - }) - } -} - -impl Distribution for Zeta -where - F: Float, - StandardUniform: Distribution, - OpenClosed01: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - loop { - let u = rng.sample(OpenClosed01); - let x = u.powf(-F::one() / self.s_minus_1).floor(); - debug_assert!(x >= F::one()); - if x.is_infinite() { - // For sufficiently small `s`, `x` will always be infinite, - // which is rejected, resulting in an infinite loop. We avoid - // this by always returning infinity instead. - return x; - } - - let t = (F::one() + F::one() / x).powf(self.s_minus_1); - - let v = rng.sample(StandardUniform); - if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { - return x; - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn test_samples>(distr: D, zero: F, expected: &[F]) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - #[test] - #[should_panic] - fn zeta_invalid() { - Zeta::new(1.).unwrap(); - } - - #[test] - #[should_panic] - fn zeta_nan() { - Zeta::new(f64::NAN).unwrap(); - } - - #[test] - fn zeta_sample() { - let a = 2.0; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_small_a() { - let a = 1. + 1e-15; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_value_stability() { - test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]); - test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]); - } - - #[test] - fn zeta_distributions_can_be_compared() { - assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); - } -} diff --git a/rand_distr/src/ziggurat_tables.rs b/rand_distr/src/ziggurat_tables.rs deleted file mode 100644 index f830a601bdd..00000000000 --- a/rand_distr/src/ziggurat_tables.rs +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// Tables for distributions which are sampled using the ziggurat -// algorithm. Autogenerated by `ziggurat_tables.py`. - -pub type ZigTable = &'static [f64; 257]; -pub const ZIG_NORM_R: f64 = 3.654152885361008796; -#[rustfmt::skip] -pub static ZIG_NORM_X: [f64; 257] = - [3.910757959537090045, 3.654152885361008796, 3.449278298560964462, 3.320244733839166074, - 3.224575052047029100, 3.147889289517149969, 3.083526132001233044, 3.027837791768635434, - 2.978603279880844834, 2.934366867207854224, 2.894121053612348060, 2.857138730872132548, - 2.822877396825325125, 2.790921174000785765, 2.760944005278822555, 2.732685359042827056, - 2.705933656121858100, 2.680514643284522158, 2.656283037575502437, 2.633116393630324570, - 2.610910518487548515, 2.589575986706995181, 2.569035452680536569, 2.549221550323460761, - 2.530075232158516929, 2.511544441625342294, 2.493583041269680667, 2.476149939669143318, - 2.459208374333311298, 2.442725318198956774, 2.426670984935725972, 2.411018413899685520, - 2.395743119780480601, 2.380822795170626005, 2.366237056715818632, 2.351967227377659952, - 2.337996148795031370, 2.324308018869623016, 2.310888250599850036, 2.297723348901329565, - 2.284800802722946056, 2.272108990226823888, 2.259637095172217780, 2.247375032945807760, - 2.235313384928327984, 2.223443340090905718, 2.211756642882544366, 2.200245546609647995, - 2.188902771624720689, 2.177721467738641614, 2.166695180352645966, 2.155817819875063268, - 2.145083634046203613, 2.134487182844320152, 2.124023315687815661, 2.113687150684933957, - 2.103474055713146829, 2.093379631137050279, 2.083399693996551783, 2.073530263516978778, - 2.063767547809956415, 2.054107931648864849, 2.044547965215732788, 2.035084353727808715, - 2.025713947862032960, 2.016433734904371722, 2.007240830558684852, 1.998132471356564244, - 1.989106007615571325, 1.980158896898598364, 1.971288697931769640, 1.962493064942461896, - 1.953769742382734043, 1.945116560006753925, 1.936531428273758904, 1.928012334050718257, - 1.919557336591228847, 1.911164563769282232, 1.902832208548446369, 1.894558525668710081, - 1.886341828534776388, 1.878180486290977669, 1.870072921069236838, 1.862017605397632281, - 1.854013059758148119, 1.846057850283119750, 1.838150586580728607, 1.830289919680666566, - 1.822474540091783224, 1.814703175964167636, 1.806974591348693426, 1.799287584547580199, - 1.791640986550010028, 1.784033659547276329, 1.776464495522344977, 1.768932414909077933, - 1.761436365316706665, 1.753975320315455111, 1.746548278279492994, 1.739154261283669012, - 1.731792314050707216, 1.724461502945775715, 1.717160915015540690, 1.709889657069006086, - 1.702646854797613907, 1.695431651932238548, 1.688243209434858727, 1.681080704722823338, - 1.673943330923760353, 1.666830296159286684, 1.659740822855789499, 1.652674147080648526, - 1.645629517902360339, 1.638606196773111146, 1.631603456932422036, 1.624620582830568427, - 1.617656869570534228, 1.610711622367333673, 1.603784156023583041, 1.596873794420261339, - 1.589979870021648534, 1.583101723393471438, 1.576238702733332886, 1.569390163412534456, - 1.562555467528439657, 1.555733983466554893, 1.548925085471535512, 1.542128153226347553, - 1.535342571438843118, 1.528567729435024614, 1.521803020758293101, 1.515047842773992404, - 1.508301596278571965, 1.501563685112706548, 1.494833515777718391, 1.488110497054654369, - 1.481394039625375747, 1.474683555695025516, 1.467978458615230908, 1.461278162507407830, - 1.454582081885523293, 1.447889631277669675, 1.441200224845798017, 1.434513276002946425, - 1.427828197027290358, 1.421144398672323117, 1.414461289772464658, 1.407778276843371534, - 1.401094763676202559, 1.394410150925071257, 1.387723835686884621, 1.381035211072741964, - 1.374343665770030531, 1.367648583594317957, 1.360949343030101844, 1.354245316759430606, - 1.347535871177359290, 1.340820365893152122, 1.334098153216083604, 1.327368577624624679, - 1.320630975217730096, 1.313884673146868964, 1.307128989027353860, 1.300363230327433728, - 1.293586693733517645, 1.286798664489786415, 1.279998415710333237, 1.273185207661843732, - 1.266358287014688333, 1.259516886060144225, 1.252660221891297887, 1.245787495544997903, - 1.238897891102027415, 1.231990574742445110, 1.225064693752808020, 1.218119375481726552, - 1.211153726239911244, 1.204166830140560140, 1.197157747875585931, 1.190125515422801650, - 1.183069142678760732, 1.175987612011489825, 1.168879876726833800, 1.161744859441574240, - 1.154581450355851802, 1.147388505416733873, 1.140164844363995789, 1.132909248648336975, - 1.125620459211294389, 1.118297174115062909, 1.110938046009249502, 1.103541679420268151, - 1.096106627847603487, 1.088631390649514197, 1.081114409698889389, 1.073554065787871714, - 1.065948674757506653, 1.058296483326006454, 1.050595664586207123, 1.042844313139370538, - 1.035040439828605274, 1.027181966030751292, 1.019266717460529215, 1.011292417434978441, - 1.003256679539591412, 0.995156999629943084, 0.986990747093846266, 0.978755155288937750, - 0.970447311058864615, 0.962064143217605250, 0.953602409875572654, 0.945058684462571130, - 0.936429340280896860, 0.927710533396234771, 0.918898183643734989, 0.909987953490768997, - 0.900975224455174528, 0.891855070726792376, 0.882622229578910122, 0.873271068082494550, - 0.863795545546826915, 0.854189171001560554, 0.844444954902423661, 0.834555354079518752, - 0.824512208745288633, 0.814306670128064347, 0.803929116982664893, 0.793369058833152785, - 0.782615023299588763, 0.771654424216739354, 0.760473406422083165, 0.749056662009581653, - 0.737387211425838629, 0.725446140901303549, 0.713212285182022732, 0.700661841097584448, - 0.687767892786257717, 0.674499822827436479, 0.660822574234205984, 0.646695714884388928, - 0.632072236375024632, 0.616896989996235545, 0.601104617743940417, 0.584616766093722262, - 0.567338257040473026, 0.549151702313026790, 0.529909720646495108, 0.509423329585933393, - 0.487443966121754335, 0.463634336771763245, 0.437518402186662658, 0.408389134588000746, - 0.375121332850465727, 0.335737519180459465, 0.286174591747260509, 0.215241895913273806, - 0.000000000000000000]; -#[rustfmt::skip] -pub static ZIG_NORM_F: [f64; 257] = - [0.000477467764586655, 0.001260285930498598, 0.002609072746106363, 0.004037972593371872, - 0.005522403299264754, 0.007050875471392110, 0.008616582769422917, 0.010214971439731100, - 0.011842757857943104, 0.013497450601780807, 0.015177088307982072, 0.016880083152595839, - 0.018605121275783350, 0.020351096230109354, 0.022117062707379922, 0.023902203305873237, - 0.025705804008632656, 0.027527235669693315, 0.029365939758230111, 0.031221417192023690, - 0.033093219458688698, 0.034980941461833073, 0.036884215688691151, 0.038802707404656918, - 0.040736110656078753, 0.042684144916619378, 0.044646552251446536, 0.046623094902089664, - 0.048613553216035145, 0.050617723861121788, 0.052635418276973649, 0.054666461325077916, - 0.056710690106399467, 0.058767952921137984, 0.060838108349751806, 0.062921024437977854, - 0.065016577971470438, 0.067124653828023989, 0.069245144397250269, 0.071377949059141965, - 0.073522973714240991, 0.075680130359194964, 0.077849336702372207, 0.080030515814947509, - 0.082223595813495684, 0.084428509570654661, 0.086645194450867782, 0.088873592068594229, - 0.091113648066700734, 0.093365311913026619, 0.095628536713353335, 0.097903279039215627, - 0.100189498769172020, 0.102487158942306270, 0.104796225622867056, 0.107116667775072880, - 0.109448457147210021, 0.111791568164245583, 0.114145977828255210, 0.116511665626037014, - 0.118888613443345698, 0.121276805485235437, 0.123676228202051403, 0.126086870220650349, - 0.128508722280473636, 0.130941777174128166, 0.133386029692162844, 0.135841476571757352, - 0.138308116449064322, 0.140785949814968309, 0.143274978974047118, 0.145775208006537926, - 0.148286642733128721, 0.150809290682410169, 0.153343161060837674, 0.155888264725064563, - 0.158444614156520225, 0.161012223438117663, 0.163591108232982951, 0.166181285765110071, - 0.168782774801850333, 0.171395595638155623, 0.174019770082499359, 0.176655321444406654, - 0.179302274523530397, 0.181960655600216487, 0.184630492427504539, 0.187311814224516926, - 0.190004651671193070, 0.192709036904328807, 0.195425003514885592, 0.198152586546538112, - 0.200891822495431333, 0.203642749311121501, 0.206405406398679298, 0.209179834621935651, - 0.211966076307852941, 0.214764175252008499, 0.217574176725178370, 0.220396127481011589, - 0.223230075764789593, 0.226076071323264877, 0.228934165415577484, 0.231804410825248525, - 0.234686861873252689, 0.237581574432173676, 0.240488605941449107, 0.243408015423711988, - 0.246339863502238771, 0.249284212419516704, 0.252241126056943765, 0.255210669955677150, - 0.258192911338648023, 0.261187919133763713, 0.264195763998317568, 0.267216518344631837, - 0.270250256366959984, 0.273297054069675804, 0.276356989296781264, 0.279430141762765316, - 0.282516593084849388, 0.285616426816658109, 0.288729728483353931, 0.291856585618280984, - 0.294997087801162572, 0.298151326697901342, 0.301319396102034120, 0.304501391977896274, - 0.307697412505553769, 0.310907558127563710, 0.314131931597630143, 0.317370638031222396, - 0.320623784958230129, 0.323891482377732021, 0.327173842814958593, 0.330470981380537099, - 0.333783015832108509, 0.337110066638412809, 0.340452257045945450, 0.343809713148291340, - 0.347182563958251478, 0.350570941482881204, 0.353974980801569250, 0.357394820147290515, - 0.360830600991175754, 0.364282468130549597, 0.367750569780596226, 0.371235057669821344, - 0.374736087139491414, 0.378253817247238111, 0.381788410875031348, 0.385340034841733958, - 0.388908860020464597, 0.392495061461010764, 0.396098818517547080, 0.399720314981931668, - 0.403359739222868885, 0.407017284331247953, 0.410693148271983222, 0.414387534042706784, - 0.418100649839684591, 0.421832709231353298, 0.425583931339900579, 0.429354541031341519, - 0.433144769114574058, 0.436954852549929273, 0.440785034667769915, 0.444635565397727750, - 0.448506701509214067, 0.452398706863882505, 0.456311852680773566, 0.460246417814923481, - 0.464202689050278838, 0.468180961407822172, 0.472181538469883255, 0.476204732721683788, - 0.480250865911249714, 0.484320269428911598, 0.488413284707712059, 0.492530263646148658, - 0.496671569054796314, 0.500837575128482149, 0.505028667945828791, 0.509245245998136142, - 0.513487720749743026, 0.517756517232200619, 0.522052074674794864, 0.526374847174186700, - 0.530725304406193921, 0.535103932383019565, 0.539511234259544614, 0.543947731192649941, - 0.548413963257921133, 0.552910490428519918, 0.557437893621486324, 0.561996775817277916, - 0.566587763258951771, 0.571211506738074970, 0.575868682975210544, 0.580559996103683473, - 0.585286179266300333, 0.590047996335791969, 0.594846243770991268, 0.599681752622167719, - 0.604555390700549533, 0.609468064928895381, 0.614420723892076803, 0.619414360609039205, - 0.624450015550274240, 0.629528779928128279, 0.634651799290960050, 0.639820277456438991, - 0.645035480824251883, 0.650298743114294586, 0.655611470583224665, 0.660975147780241357, - 0.666391343912380640, 0.671861719900766374, 0.677388036222513090, 0.682972161648791376, - 0.688616083008527058, 0.694321916130032579, 0.700091918140490099, 0.705928501336797409, - 0.711834248882358467, 0.717811932634901395, 0.723864533472881599, 0.729995264565802437, - 0.736207598131266683, 0.742505296344636245, 0.748892447223726720, 0.755373506511754500, - 0.761953346841546475, 0.768637315803334831, 0.775431304986138326, 0.782341832659861902, - 0.789376143571198563, 0.796542330428254619, 0.803849483176389490, 0.811307874318219935, - 0.818929191609414797, 0.826726833952094231, 0.834716292992930375, 0.842915653118441077, - 0.851346258465123684, 0.860033621203008636, 0.869008688043793165, 0.878309655816146839, - 0.887984660763399880, 0.898095921906304051, 0.908726440060562912, 0.919991505048360247, - 0.932060075968990209, 0.945198953453078028, 0.959879091812415930, 0.977101701282731328, - 1.000000000000000000]; -pub const ZIG_EXP_R: f64 = 7.697117470131050077; -#[rustfmt::skip] -pub static ZIG_EXP_X: [f64; 257] = - [8.697117470131052741, 7.697117470131050077, 6.941033629377212577, 6.478378493832569696, - 6.144164665772472667, 5.882144315795399869, 5.666410167454033697, 5.482890627526062488, - 5.323090505754398016, 5.181487281301500047, 5.054288489981304089, 4.938777085901250530, - 4.832939741025112035, 4.735242996601741083, 4.644491885420085175, 4.559737061707351380, - 4.480211746528421912, 4.405287693473573185, 4.334443680317273007, 4.267242480277365857, - 4.203313713735184365, 4.142340865664051464, 4.084051310408297830, 4.028208544647936762, - 3.974606066673788796, 3.923062500135489739, 3.873417670399509127, 3.825529418522336744, - 3.779270992411667862, 3.734528894039797375, 3.691201090237418825, 3.649195515760853770, - 3.608428813128909507, 3.568825265648337020, 3.530315889129343354, 3.492837654774059608, - 3.456332821132760191, 3.420748357251119920, 3.386035442460300970, 3.352149030900109405, - 3.319047470970748037, 3.286692171599068679, 3.255047308570449882, 3.224079565286264160, - 3.193757903212240290, 3.164053358025972873, 3.134938858084440394, 3.106389062339824481, - 3.078380215254090224, 3.050890016615455114, 3.023897504455676621, 2.997382949516130601, - 2.971327759921089662, 2.945714394895045718, 2.920526286512740821, 2.895747768600141825, - 2.871364012015536371, 2.847360965635188812, 2.823725302450035279, 2.800444370250737780, - 2.777506146439756574, 2.754899196562344610, 2.732612636194700073, 2.710636095867928752, - 2.688959688741803689, 2.667573980773266573, 2.646469963151809157, 2.625639026797788489, - 2.605072938740835564, 2.584763820214140750, 2.564704126316905253, 2.544886627111869970, - 2.525304390037828028, 2.505950763528594027, 2.486819361740209455, 2.467904050297364815, - 2.449198932978249754, 2.430698339264419694, 2.412396812688870629, 2.394289099921457886, - 2.376370140536140596, 2.358635057409337321, 2.341079147703034380, 2.323697874390196372, - 2.306486858283579799, 2.289441870532269441, 2.272558825553154804, 2.255833774367219213, - 2.239262898312909034, 2.222842503111036816, 2.206569013257663858, 2.190438966723220027, - 2.174449009937774679, 2.158595893043885994, 2.142876465399842001, 2.127287671317368289, - 2.111826546019042183, 2.096490211801715020, 2.081275874393225145, 2.066180819490575526, - 2.051202409468584786, 2.036338080248769611, 2.021585338318926173, 2.006941757894518563, - 1.992404978213576650, 1.977972700957360441, 1.963642687789548313, 1.949412758007184943, - 1.935280786297051359, 1.921244700591528076, 1.907302480018387536, 1.893452152939308242, - 1.879691795072211180, 1.866019527692827973, 1.852433515911175554, 1.838931967018879954, - 1.825513128903519799, 1.812175288526390649, 1.798916770460290859, 1.785735935484126014, - 1.772631179231305643, 1.759600930889074766, 1.746643651946074405, 1.733757834985571566, - 1.720942002521935299, 1.708194705878057773, 1.695514524101537912, 1.682900062917553896, - 1.670349953716452118, 1.657862852574172763, 1.645437439303723659, 1.633072416535991334, - 1.620766508828257901, 1.608518461798858379, 1.596327041286483395, 1.584191032532688892, - 1.572109239386229707, 1.560080483527888084, 1.548103603714513499, 1.536177455041032092, - 1.524300908219226258, 1.512472848872117082, 1.500692176842816750, 1.488957805516746058, - 1.477268661156133867, 1.465623682245745352, 1.454021818848793446, 1.442462031972012504, - 1.430943292938879674, 1.419464582769983219, 1.408024891569535697, 1.396623217917042137, - 1.385258568263121992, 1.373929956328490576, 1.362636402505086775, 1.351376933258335189, - 1.340150580529504643, 1.328956381137116560, 1.317793376176324749, 1.306660610415174117, - 1.295557131686601027, 1.284481990275012642, 1.273434238296241139, 1.262412929069615330, - 1.251417116480852521, 1.240445854334406572, 1.229498195693849105, 1.218573192208790124, - 1.207669893426761121, 1.196787346088403092, 1.185924593404202199, 1.175080674310911677, - 1.164254622705678921, 1.153445466655774743, 1.142652227581672841, 1.131873919411078511, - 1.121109547701330200, 1.110358108727411031, 1.099618588532597308, 1.088889961938546813, - 1.078171191511372307, 1.067461226479967662, 1.056759001602551429, 1.046063435977044209, - 1.035373431790528542, 1.024687873002617211, 1.014005623957096480, 1.003325527915696735, - 0.992646405507275897, 0.981967053085062602, 0.971286240983903260, 0.960602711668666509, - 0.949915177764075969, 0.939222319955262286, 0.928522784747210395, 0.917815182070044311, - 0.907098082715690257, 0.896370015589889935, 0.885629464761751528, 0.874874866291025066, - 0.864104604811004484, 0.853317009842373353, 0.842510351810368485, 0.831682837734273206, - 0.820832606554411814, 0.809957724057418282, 0.799056177355487174, 0.788125868869492430, - 0.777164609759129710, 0.766170112735434672, 0.755139984181982249, 0.744071715500508102, - 0.732962673584365398, 0.721810090308756203, 0.710611050909655040, 0.699362481103231959, - 0.688061132773747808, 0.676703568029522584, 0.665286141392677943, 0.653804979847664947, - 0.642255960424536365, 0.630634684933490286, 0.618936451394876075, 0.607156221620300030, - 0.595288584291502887, 0.583327712748769489, 0.571267316532588332, 0.559100585511540626, - 0.546820125163310577, 0.534417881237165604, 0.521885051592135052, 0.509211982443654398, - 0.496388045518671162, 0.483401491653461857, 0.470239275082169006, 0.456886840931420235, - 0.443327866073552401, 0.429543940225410703, 0.415514169600356364, 0.401214678896277765, - 0.386617977941119573, 0.371692145329917234, 0.356399760258393816, 0.340696481064849122, - 0.324529117016909452, 0.307832954674932158, 0.290527955491230394, 0.272513185478464703, - 0.253658363385912022, 0.233790483059674731, 0.212671510630966620, 0.189958689622431842, - 0.165127622564187282, 0.137304980940012589, 0.104838507565818778, 0.063852163815001570, - 0.000000000000000000]; -#[rustfmt::skip] -pub static ZIG_EXP_F: [f64; 257] = - [0.000167066692307963, 0.000454134353841497, 0.000967269282327174, 0.001536299780301573, - 0.002145967743718907, 0.002788798793574076, 0.003460264777836904, 0.004157295120833797, - 0.004877655983542396, 0.005619642207205489, 0.006381905937319183, 0.007163353183634991, - 0.007963077438017043, 0.008780314985808977, 0.009614413642502212, 0.010464810181029981, - 0.011331013597834600, 0.012212592426255378, 0.013109164931254991, 0.014020391403181943, - 0.014945968011691148, 0.015885621839973156, 0.016839106826039941, 0.017806200410911355, - 0.018786700744696024, 0.019780424338009740, 0.020787204072578114, 0.021806887504283581, - 0.022839335406385240, 0.023884420511558174, 0.024942026419731787, 0.026012046645134221, - 0.027094383780955803, 0.028188948763978646, 0.029295660224637411, 0.030414443910466622, - 0.031545232172893622, 0.032687963508959555, 0.033842582150874358, 0.035009037697397431, - 0.036187284781931443, 0.037377282772959382, 0.038578995503074871, 0.039792391023374139, - 0.041017441380414840, 0.042254122413316254, 0.043502413568888197, 0.044762297732943289, - 0.046033761076175184, 0.047316792913181561, 0.048611385573379504, 0.049917534282706379, - 0.051235237055126281, 0.052564494593071685, 0.053905310196046080, 0.055257689676697030, - 0.056621641283742870, 0.057997175631200659, 0.059384305633420280, 0.060783046445479660, - 0.062193415408541036, 0.063615431999807376, 0.065049117786753805, 0.066494496385339816, - 0.067951593421936643, 0.069420436498728783, 0.070901055162371843, 0.072393480875708752, - 0.073897746992364746, 0.075413888734058410, 0.076941943170480517, 0.078481949201606435, - 0.080033947542319905, 0.081597980709237419, 0.083174093009632397, 0.084762330532368146, - 0.086362741140756927, 0.087975374467270231, 0.089600281910032886, 0.091237516631040197, - 0.092887133556043569, 0.094549189376055873, 0.096223742550432825, 0.097910853311492213, - 0.099610583670637132, 0.101322997425953631, 0.103048160171257702, 0.104786139306570145, - 0.106537004050001632, 0.108300825451033755, 0.110077676405185357, 0.111867631670056283, - 0.113670767882744286, 0.115487163578633506, 0.117316899211555525, 0.119160057175327641, - 0.121016721826674792, 0.122886979509545108, 0.124770918580830933, 0.126668629437510671, - 0.128580204545228199, 0.130505738468330773, 0.132445327901387494, 0.134399071702213602, - 0.136367070926428829, 0.138349428863580176, 0.140346251074862399, 0.142357645432472146, - 0.144383722160634720, 0.146424593878344889, 0.148480375643866735, 0.150551185001039839, - 0.152637142027442801, 0.154738369384468027, 0.156854992369365148, 0.158987138969314129, - 0.161134939917591952, 0.163298528751901734, 0.165478041874935922, 0.167673618617250081, - 0.169885401302527550, 0.172113535315319977, 0.174358169171353411, 0.176619454590494829, - 0.178897546572478278, 0.181192603475496261, 0.183504787097767436, 0.185834262762197083, - 0.188181199404254262, 0.190545769663195363, 0.192928149976771296, 0.195328520679563189, - 0.197747066105098818, 0.200183974691911210, 0.202639439093708962, 0.205113656293837654, - 0.207606827724221982, 0.210119159388988230, 0.212650861992978224, 0.215202151075378628, - 0.217773247148700472, 0.220364375843359439, 0.222975768058120111, 0.225607660116683956, - 0.228260293930716618, 0.230933917169627356, 0.233628783437433291, 0.236345152457059560, - 0.239083290262449094, 0.241843469398877131, 0.244625969131892024, 0.247431075665327543, - 0.250259082368862240, 0.253110290015629402, 0.255985007030415324, 0.258883549749016173, - 0.261806242689362922, 0.264753418835062149, 0.267725419932044739, 0.270722596799059967, - 0.273745309652802915, 0.276793928448517301, 0.279868833236972869, 0.282970414538780746, - 0.286099073737076826, 0.289255223489677693, 0.292439288161892630, 0.295651704281261252, - 0.298892921015581847, 0.302163400675693528, 0.305463619244590256, 0.308794066934560185, - 0.312155248774179606, 0.315547685227128949, 0.318971912844957239, 0.322428484956089223, - 0.325917972393556354, 0.329440964264136438, 0.332998068761809096, 0.336589914028677717, - 0.340217149066780189, 0.343880444704502575, 0.347580494621637148, 0.351318016437483449, - 0.355093752866787626, 0.358908472948750001, 0.362762973354817997, 0.366658079781514379, - 0.370594648435146223, 0.374573567615902381, 0.378595759409581067, 0.382662181496010056, - 0.386773829084137932, 0.390931736984797384, 0.395136981833290435, 0.399390684475231350, - 0.403694012530530555, 0.408048183152032673, 0.412454465997161457, 0.416914186433003209, - 0.421428728997616908, 0.425999541143034677, 0.430628137288459167, 0.435316103215636907, - 0.440065100842354173, 0.444876873414548846, 0.449753251162755330, 0.454696157474615836, - 0.459707615642138023, 0.464789756250426511, 0.469944825283960310, 0.475175193037377708, - 0.480483363930454543, 0.485871987341885248, 0.491343869594032867, 0.496901987241549881, - 0.502549501841348056, 0.508289776410643213, 0.514126393814748894, 0.520063177368233931, - 0.526104213983620062, 0.532253880263043655, 0.538516872002862246, 0.544898237672440056, - 0.551403416540641733, 0.558038282262587892, 0.564809192912400615, 0.571723048664826150, - 0.578787358602845359, 0.586010318477268366, 0.593400901691733762, 0.600968966365232560, - 0.608725382079622346, 0.616682180915207878, 0.624852738703666200, 0.633251994214366398, - 0.641896716427266423, 0.650805833414571433, 0.660000841079000145, 0.669506316731925177, - 0.679350572264765806, 0.689566496117078431, 0.700192655082788606, 0.711274760805076456, - 0.722867659593572465, 0.735038092431424039, 0.747868621985195658, 0.761463388849896838, - 0.775956852040116218, 0.791527636972496285, 0.808421651523009044, 0.826993296643051101, - 0.847785500623990496, 0.871704332381204705, 0.900469929925747703, 0.938143680862176477, - 1.000000000000000000]; diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs deleted file mode 100644 index f2e80d37908..00000000000 --- a/rand_distr/src/zipf.rs +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2021 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -//! The Zipf distribution. - -use crate::{Distribution, StandardUniform}; -use core::fmt; -use num_traits::Float; -use rand::Rng; - -/// The Zipf (Zipfian) distribution `Zipf(n, s)`. -/// -/// The samples follow [Zipf's law](https://en.wikipedia.org/wiki/Zipf%27s_law): -/// The frequency of each sample from a finite set of size `n` is inversely -/// proportional to a power of its frequency rank (with exponent `s`). -/// -/// For large `n`, this converges to the [`Zeta`](crate::Zeta) distribution. -/// -/// For `s = 0`, this becomes a [`uniform`](crate::Uniform) distribution. -/// -/// # Plot -/// -/// The following plot illustrates the Zipf distribution for `n = 10` and -/// various values of `s`. -/// -/// ![Zipf distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zipf.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Zipf; -/// -/// let val: f64 = rand::rng().sample(Zipf::new(10.0, 1.5).unwrap()); -/// println!("{}", val); -/// ``` -/// -/// # Integer vs FP return type -/// -/// This implementation uses floating-point (FP) logic internally. It may be -/// expected that the samples are no greater than `n`, thus it is reasonable to -/// cast generated samples to any integer type which can also represent `n` -/// (e.g. `distr.sample(&mut rng) as u64`). -/// -/// # Implementation details -/// -/// Implemented via [rejection sampling](https://en.wikipedia.org/wiki/Rejection_sampling), -/// due to Jason Crease[1]. -/// -/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct Zipf -where - F: Float, - StandardUniform: Distribution, -{ - s: F, - t: F, - q: F, -} - -/// Error type returned from [`Zipf::new`]. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum Error { - /// `s < 0` or `nan`. - STooSmall, - /// `n < 1`. - NTooSmall, -} - -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - Error::STooSmall => "s < 0 or is NaN in Zipf distribution", - Error::NTooSmall => "n < 1 in Zipf distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for Error {} - -impl Zipf -where - F: Float, - StandardUniform: Distribution, -{ - /// Construct a new `Zipf` distribution for a set with `n` elements and a - /// frequency rank exponent `s`. - /// - /// The parameter `n` is typically integral, however we use type - ///
F: [Float]
in order to permit very large values - /// and since our implementation requires a floating-point type. - #[inline] - pub fn new(n: F, s: F) -> Result, Error> { - if !(s >= F::zero()) { - return Err(Error::STooSmall); - } - if n < F::one() { - return Err(Error::NTooSmall); - } - let q = if s != F::one() { - // Make sure to calculate the division only once. - F::one() / (F::one() - s) - } else { - // This value is never used. - F::zero() - }; - let t = if s != F::one() { - (n.powf(F::one() - s) - s) * q - } else { - F::one() + n.ln() - }; - debug_assert!(t > F::zero()); - Ok(Zipf { s, t, q }) - } - - /// Inverse cumulative density function - #[inline] - fn inv_cdf(&self, p: F) -> F { - let one = F::one(); - let pt = p * self.t; - if pt <= one { - pt - } else if self.s != one { - (pt * (one - self.s) + self.s).powf(self.q) - } else { - (pt - one).exp() - } - } -} - -impl Distribution for Zipf -where - F: Float, - StandardUniform: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - let one = F::one(); - loop { - let inv_b = self.inv_cdf(rng.sample(StandardUniform)); - let x = (inv_b + one).floor(); - let mut ratio = x.powf(-self.s); - if x > one { - ratio = ratio * inv_b.powf(self.s) - }; - - let y = rng.sample(StandardUniform); - if y < ratio { - return x; - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn test_samples>(distr: D, zero: F, expected: &[F]) { - let mut rng = crate::test::rng(213); - let mut buf = [zero; 4]; - for x in &mut buf { - *x = rng.sample(&distr); - } - assert_eq!(buf, expected); - } - - #[test] - #[should_panic] - fn zipf_s_too_small() { - Zipf::new(10., -1.).unwrap(); - } - - #[test] - #[should_panic] - fn zipf_n_too_small() { - Zipf::new(0., 1.).unwrap(); - } - - #[test] - #[should_panic] - fn zipf_nan() { - Zipf::new(10., f64::NAN).unwrap(); - } - - #[test] - fn zipf_sample() { - let d = Zipf::new(10., 0.5).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zipf_sample_s_1() { - let d = Zipf::new(10., 1.).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zipf_sample_s_0() { - let d = Zipf::new(10., 0.).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - // TODO: verify that this is a uniform distribution - } - - #[test] - fn zipf_sample_large_n() { - let d = Zipf::new(f64::MAX, 1.5).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - // TODO: verify that this is a zeta distribution - } - - #[test] - fn zipf_value_stability() { - test_samples(Zipf::new(10., 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]); - test_samples(Zipf::new(10., 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]); - } - - #[test] - fn zipf_distributions_can_be_compared() { - assert_eq!(Zipf::new(1.0, 2.0), Zipf::new(1.0, 2.0)); - } -} diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs deleted file mode 100644 index 330119b68f6..00000000000 --- a/rand_distr/tests/value_stability.rs +++ /dev/null @@ -1,553 +0,0 @@ -// Copyright 2018 Developers of the Rand project. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use average::assert_almost_eq; -use core::fmt::Debug; -use rand::Rng; -use rand_distr::*; - -fn get_rng(seed: u64) -> impl Rng { - // For tests, we want a statistically good, fast, reproducible RNG. - // PCG32 will do fine, and will be easy to embed if we ever need to. - const INC: u64 = 11634580027462260723; - rand_pcg::Pcg32::new(seed, INC) -} - -/// We only assert approximate equality since some platforms do not perform -/// identically (i686-unknown-linux-gnu and most notably x86_64-pc-windows-gnu). -trait ApproxEq { - fn assert_almost_eq(&self, rhs: &Self); -} - -impl ApproxEq for f32 { - fn assert_almost_eq(&self, rhs: &Self) { - assert_almost_eq!(self, rhs, 1e-6); - } -} -impl ApproxEq for f64 { - fn assert_almost_eq(&self, rhs: &Self) { - assert_almost_eq!(self, rhs, 1e-14); - } -} -impl ApproxEq for u64 { - fn assert_almost_eq(&self, rhs: &Self) { - assert_eq!(self, rhs); - } -} -impl ApproxEq for [T; 2] { - fn assert_almost_eq(&self, rhs: &Self) { - self[0].assert_almost_eq(&rhs[0]); - self[1].assert_almost_eq(&rhs[1]); - } -} -impl ApproxEq for [T; 3] { - fn assert_almost_eq(&self, rhs: &Self) { - self[0].assert_almost_eq(&rhs[0]); - self[1].assert_almost_eq(&rhs[1]); - self[2].assert_almost_eq(&rhs[2]); - } -} - -fn test_samples>(seed: u64, distr: D, expected: &[F]) { - let mut rng = get_rng(seed); - for val in expected { - let x = rng.sample(&distr); - x.assert_almost_eq(val); - } -} - -#[test] -fn binomial_stability() { - // We have multiple code paths: np < 10, p > 0.5 - test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); - test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); - test_samples( - 353, - Binomial::new(2000, 0.6).unwrap(), - &[1194, 1208, 1192, 1210], - ); -} - -#[test] -fn geometric_stability() { - test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]); - - test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]); - test_samples( - 464, - Geometric::new(0.05).unwrap(), - &[24, 51, 81, 67, 27, 11, 7, 6], - ); - test_samples( - 464, - Geometric::new(0.95).unwrap(), - &[0, 0, 0, 0, 1, 0, 0, 0], - ); - - // expect non-random behaviour for series of pre-determined trials - test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]); - test_samples(464, Geometric::new(1.0).unwrap(), &[0; 100][..]); -} - -#[test] -fn hypergeometric_stability() { - // We have multiple code paths based on the distribution's mode and sample_size - test_samples( - 7221, - Hypergeometric::new(99, 33, 8).unwrap(), - &[4, 3, 2, 2, 3, 2, 3, 1], - ); // Algorithm HIN - test_samples( - 7221, - Hypergeometric::new(100, 50, 50).unwrap(), - &[23, 27, 26, 27, 22, 25, 31, 25], - ); // Algorithm H2PE -} - -#[test] -fn unit_ball_stability() { - test_samples( - 2, - UnitBall, - &[ - [ - 0.018035709265959987f64, - -0.4348771383120438, - -0.07982762085055706, - ], - [ - 0.10588569388223945, - -0.4734350111375454, - -0.7392104908825501, - ], - [ - 0.11060237642041049, - -0.16065642822852677, - -0.8444043930440075, - ], - ], - ); -} - -#[test] -fn unit_circle_stability() { - test_samples( - 2, - UnitCircle, - &[ - [-0.9965658683520504f64, -0.08280380447614634], - [-0.9790853270389644, -0.20345004884984505], - [-0.8449189758898707, 0.5348943112253227], - ], - ); -} - -#[test] -fn unit_sphere_stability() { - test_samples( - 2, - UnitSphere, - &[ - [ - 0.03247542860231647f64, - -0.7830477442152738, - 0.6211131755296027, - ], - [ - -0.09978440840914075, - 0.9706650829833128, - -0.21875184231323952, - ], - [0.2735582468624679, 0.9435374242279655, -0.1868234852870203], - ], - ); -} - -#[test] -fn unit_disc_stability() { - test_samples( - 2, - UnitDisc, - &[ - [0.018035709265959987f64, -0.4348771383120438], - [-0.07982762085055706, 0.7765329819820659], - [0.21450745997299503, 0.7398636984333291], - ], - ); -} - -#[test] -fn pareto_stability() { - test_samples( - 213, - Pareto::new(1.0, 1.0).unwrap(), - &[1.0423688f32, 2.1235929, 4.132709, 1.4679428], - ); - test_samples( - 213, - Pareto::new(2.0, 0.5).unwrap(), - &[ - 9.019295276219136f64, - 4.3097126018270595, - 6.837815045397157, - 105.8826669383772, - ], - ); -} - -#[test] -fn poisson_stability() { - test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); - test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); - test_samples( - 223, - Poisson::new(27.0).unwrap(), - &[28.0f32, 32.0, 36.0, 36.0], - ); -} - -#[test] -fn triangular_stability() { - test_samples( - 860, - Triangular::new(2., 10., 3.).unwrap(), - &[ - 5.74373257511361f64, - 7.890059162791258f64, - 4.7256280652553455f64, - 2.9474808121184077f64, - 3.058301946314053f64, - ], - ); -} - -#[test] -fn normal_inverse_gaussian_stability() { - test_samples( - 213, - NormalInverseGaussian::new(2.0, 1.0).unwrap(), - &[0.6568966f32, 1.3744819, 2.216063, 0.11488572], - ); - test_samples( - 213, - NormalInverseGaussian::new(2.0, 1.0).unwrap(), - &[ - 0.6838707059642927f64, - 2.4447306460569784, - 0.2361045023235968, - 1.7774534624785319, - ], - ); -} - -#[test] -fn pert_stability() { - // mean = 4, var = 12/7 - test_samples( - 860, - Pert::new(2., 10.).with_mode(3.).unwrap(), - &[ - 4.908681667460367, - 4.014196196158352, - 2.6489397149197234, - 3.4569780580044727, - 4.242864311947118, - ], - ); -} - -#[test] -fn inverse_gaussian_stability() { - test_samples( - 213, - InverseGaussian::new(1.0, 3.0).unwrap(), - &[0.9339157f32, 1.108113, 0.50864697, 0.39849377], - ); - test_samples( - 213, - InverseGaussian::new(1.0, 3.0).unwrap(), - &[ - 1.0707604954722476f64, - 0.9628140605340697, - 0.4069687656468226, - 0.660283852985818, - ], - ); -} - -#[test] -fn gamma_stability() { - // Gamma has 3 cases: shape == 1, shape < 1, shape > 1 - test_samples( - 223, - Gamma::new(1.0, 5.0).unwrap(), - &[5.398085f32, 9.162783, 0.2300583, 1.7235851], - ); - test_samples( - 223, - Gamma::new(0.8, 5.0).unwrap(), - &[0.5051203f32, 0.9048302, 3.095812, 1.8566116], - ); - test_samples( - 223, - Gamma::new(1.1, 5.0).unwrap(), - &[ - 7.783878094584059f64, - 1.4939528171618057, - 8.638017638857592, - 3.0949337228829004, - ], - ); - - // ChiSquared has 2 cases: k == 1, k != 1 - test_samples( - 223, - ChiSquared::new(1.0).unwrap(), - &[ - 0.4893526200348249f64, - 1.635249736808788, - 0.5013580219361969, - 0.1457735613733489, - ], - ); - test_samples( - 223, - ChiSquared::new(0.1).unwrap(), - &[ - 0.014824404726978617f64, - 0.021602123937134326, - 0.0000003431429746851693, - 0.00000002291755769542258, - ], - ); - test_samples( - 223, - ChiSquared::new(10.0).unwrap(), - &[12.693656f32, 6.812016, 11.082001, 12.436167], - ); - - // FisherF has same special cases as ChiSquared on each param - test_samples( - 223, - FisherF::new(1.0, 13.5).unwrap(), - &[0.32283646f32, 0.048049655, 0.0788893, 1.817178], - ); - test_samples( - 223, - FisherF::new(1.0, 1.0).unwrap(), - &[0.29925257f32, 3.4392934, 9.567652, 0.020074], - ); - test_samples( - 223, - FisherF::new(0.7, 13.5).unwrap(), - &[ - 3.3196593155045124f64, - 0.3409169916262829, - 0.03377989856426519, - 0.00004041672861036937, - ], - ); - - // StudentT has same special cases as ChiSquared - test_samples( - 223, - StudentT::new(1.0).unwrap(), - &[0.54703987f32, -1.8545331, 3.093162, -0.14168274], - ); - test_samples( - 223, - StudentT::new(1.1).unwrap(), - &[ - 0.7729195887949754f64, - 1.2606210611616204, - -1.7553606501113175, - -2.377641221169782, - ], - ); - - // Beta has two special cases: - // - // 1. min(alpha, beta) <= 1 - // 2. min(alpha, beta) > 1 - test_samples( - 223, - Beta::new(1.0, 0.8).unwrap(), - &[ - 0.8300703726659456, - 0.8134131062097899, - 0.47912589330631555, - 0.25323238071138526, - ], - ); - test_samples( - 223, - Beta::new(3.0, 1.2).unwrap(), - &[ - 0.49563509121756827, - 0.9551305482256759, - 0.5151181353461637, - 0.7551732971235077, - ], - ); -} - -#[test] -fn exponential_stability() { - test_samples(223, Exp1, &[1.079617f32, 1.8325565, 0.04601166, 0.34471703]); - test_samples( - 223, - Exp1, - &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ], - ); - - test_samples( - 223, - Exp::new(2.0).unwrap(), - &[0.5398085f32, 0.91627824, 0.02300583, 0.17235851], - ); - test_samples( - 223, - Exp::new(1.0).unwrap(), - &[ - 1.0796170642388276f64, - 1.8325565304274, - 0.04601166186842716, - 0.3447170217100157, - ], - ); -} - -#[test] -fn normal_stability() { - test_samples( - 213, - StandardNormal, - &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899], - ); - test_samples( - 213, - StandardNormal, - &[ - -0.11844188827977231f64, - 0.7813779637772346, - 0.06563993969580051, - -1.1932899004186373, - ], - ); - - test_samples( - 213, - Normal::new(0.0, 1.0).unwrap(), - &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899], - ); - test_samples( - 213, - Normal::new(2.0, 0.5).unwrap(), - &[ - 1.940779055860114f64, - 2.3906889818886174, - 2.0328199698479, - 1.4033550497906813, - ], - ); - - test_samples( - 213, - LogNormal::new(0.0, 1.0).unwrap(), - &[0.88830346f32, 2.1844804, 1.0678421, 0.30322206], - ); - test_samples( - 213, - LogNormal::new(2.0, 0.5).unwrap(), - &[ - 6.964174338639032f64, - 10.921015733601452, - 7.6355881556915906, - 4.068828213584092, - ], - ); -} - -#[test] -fn weibull_stability() { - test_samples( - 213, - Weibull::new(1.0, 1.0).unwrap(), - &[0.041495778f32, 0.7531094, 1.4189332, 0.38386202], - ); - test_samples( - 213, - Weibull::new(2.0, 0.5).unwrap(), - &[ - 1.1343478702739669f64, - 0.29470010050655226, - 0.7556151370284702, - 7.877212340241561, - ], - ); -} - -#[cfg(feature = "alloc")] -#[test] -fn dirichlet_stability() { - let mut rng = get_rng(223); - assert_eq!( - rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), - [0.12941567177708177, 0.4702121891675036, 0.4003721390554146] - ); - assert_eq!( - rng.sample(Dirichlet::new([8.0; 5]).unwrap()), - [ - 0.17684200044809556, - 0.29915953935953055, - 0.1832858056608014, - 0.1425623503573967, - 0.19815030417417595 - ] - ); - // Test stability for the case where all alphas are less than 0.1. - assert_eq!( - rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), - [ - 0.00027580456855692104, - 2.296135759821706e-20, - 3.004118281150937e-9, - 0.9997241924273248 - ] - ); -} - -#[test] -fn cauchy_stability() { - test_samples( - 353, - Cauchy::new(100f64, 10.0).unwrap(), - &[ - 77.93369152808678f64, - 90.1606912098641, - 125.31516221323625, - 86.10217834773925, - ], - ); - - // Unfortunately this test is not fully portable due to reliance on the - // system's implementation of tanf (see doc on Cauchy struct). - // We use a lower threshold of 1e-5 here. - let distr = Cauchy::new(10f32, 7.0).unwrap(); - let mut rng = get_rng(353); - let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; - for &a in expected.iter() { - let b = rng.sample(distr); - assert_almost_eq!(a, b, 1e-5); - } -} diff --git a/rand_pcg/src/pcg128.rs b/rand_pcg/src/pcg128.rs index 990303c41fb..d2341425673 100644 --- a/rand_pcg/src/pcg128.rs +++ b/rand_pcg/src/pcg128.rs @@ -234,7 +234,7 @@ impl SeedableRng for Mcg128Xsl64 { // Read as if a little-endian u128 value: let mut seed_u64 = [0u64; 2]; le::read_u64_into(&seed, &mut seed_u64); - let state = u128::from(seed_u64[0]) | u128::from(seed_u64[1]) << 64; + let state = u128::from(seed_u64[0]) | (u128::from(seed_u64[1]) << 64); Mcg128Xsl64::new(state) } } diff --git a/src/distr/distribution.rs b/src/distr/distribution.rs index 6f4e202647e..48598ec0fde 100644 --- a/src/distr/distribution.rs +++ b/src/distr/distribution.rs @@ -250,7 +250,7 @@ mod tests { #[test] #[cfg(feature = "alloc")] fn test_dist_string() { - use crate::distr::{Alphanumeric, SampleString, StandardUniform}; + use crate::distr::{Alphabetic, Alphanumeric, SampleString, StandardUniform}; use core::str; let mut rng = crate::test::rng(213); @@ -261,5 +261,9 @@ mod tests { let s2 = StandardUniform.sample_string(&mut rng, 20); assert_eq!(s2.chars().count(), 20); assert_eq!(str::from_utf8(s2.as_bytes()), Ok(s2.as_str())); + + let s3 = Alphabetic.sample_string(&mut rng, 20); + assert_eq!(s3.len(), 20); + assert_eq!(str::from_utf8(s3.as_bytes()), Ok(s3.as_str())); } } diff --git a/src/distr/float.rs b/src/distr/float.rs index ec380b4bd4d..44c33e99268 100644 --- a/src/distr/float.rs +++ b/src/distr/float.rs @@ -175,7 +175,7 @@ float_impls! { feature = "simd_support", f64x8, u64x8, f64, u64, 52, 1023 } #[cfg(test)] mod tests { use super::*; - use crate::rngs::mock::StepRng; + use crate::test::const_rng; const EPSILON32: f32 = f32::EPSILON; const EPSILON64: f64 = f64::EPSILON; @@ -187,30 +187,30 @@ mod tests { let two = $ty::splat(2.0); // StandardUniform - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.random::<$ty>(), $ZERO); - let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); + let mut one = const_rng(1 << 8 | 1 << (8 + 32)); assert_eq!(one.random::<$ty>(), $EPSILON / two); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0); + let mut one = const_rng(1 << 8 | 1 << (8 + 32)); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); + let mut one = const_rng(1 << 9 | 1 << (9 + 32)); assert_eq!( one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0) ); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!( max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two @@ -235,30 +235,30 @@ mod tests { let two = $ty::splat(2.0); // StandardUniform - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.random::<$ty>(), $ZERO); - let mut one = StepRng::new(1 << 11, 0); + let mut one = const_rng(1 << 11); assert_eq!(one.random::<$ty>(), $EPSILON / two); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!(max.random::<$ty>(), $ty::splat(1.0) - $EPSILON / two); // OpenClosed01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 11, 0); + let mut one = const_rng(1 << 11); assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0)); // Open01 - let mut zeros = StepRng::new(0, 0); + let mut zeros = const_rng(0); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); - let mut one = StepRng::new(1 << 12, 0); + let mut one = const_rng(1 << 12); assert_eq!( one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0) ); - let mut max = StepRng::new(!0, 0); + let mut max = const_rng(!0); assert_eq!( max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two diff --git a/src/distr/integer.rs b/src/distr/integer.rs index d0040e69e7e..37b2081c471 100644 --- a/src/distr/integer.rs +++ b/src/distr/integer.rs @@ -107,21 +107,50 @@ impl_nzint!(NonZeroI64, NonZeroI64::new); impl_nzint!(NonZeroI128, NonZeroI128::new); #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -macro_rules! x86_intrinsic_impl { - ($meta:meta, $($intrinsic:ident),+) => {$( - #[cfg($meta)] - impl Distribution<$intrinsic> for StandardUniform { - #[inline] - fn sample(&self, rng: &mut R) -> $intrinsic { - // On proper hardware, this should compile to SIMD instructions - // Verified on x86 Haswell with __m128i, __m256i - let mut buf = [0_u8; core::mem::size_of::<$intrinsic>()]; - rng.fill_bytes(&mut buf); - // x86 is little endian so no need for conversion - zerocopy::transmute!(buf) - } - } - )+}; +impl Distribution<__m128i> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> __m128i { + // NOTE: It's tempting to use the u128 impl here, but confusingly this + // results in different code (return via rdx, r10 instead of rax, rdx + // with u128 impl) and is much slower (+130 time). This version calls + // impls::fill_bytes_via_next but performs well. + + let mut buf = [0_u8; core::mem::size_of::<__m128i>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + + // SAFETY: All byte sequences of `buf` represent values of the output type. + unsafe { core::mem::transmute(buf) } + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +impl Distribution<__m256i> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> __m256i { + let mut buf = [0_u8; core::mem::size_of::<__m256i>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + + // SAFETY: All byte sequences of `buf` represent values of the output type. + unsafe { core::mem::transmute(buf) } + } +} + +#[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + feature = "simd_support" +))] +impl Distribution<__m512i> for StandardUniform { + #[inline] + fn sample(&self, rng: &mut R) -> __m512i { + let mut buf = [0_u8; core::mem::size_of::<__m512i>()]; + rng.fill_bytes(&mut buf); + // x86 is little endian so no need for conversion + + // SAFETY: All byte sequences of `buf` represent values of the output type. + unsafe { core::mem::transmute(buf) } + } } #[cfg(feature = "simd_support")] @@ -148,24 +177,6 @@ macro_rules! simd_impl { #[cfg(feature = "simd_support")] simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64); -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -x86_intrinsic_impl!( - any(target_arch = "x86", target_arch = "x86_64"), - __m128i, - __m256i -); -#[cfg(all( - any(target_arch = "x86", target_arch = "x86_64"), - feature = "simd_support" -))] -x86_intrinsic_impl!( - all( - any(target_arch = "x86", target_arch = "x86_64"), - feature = "simd_support" - ), - __m512i -); - #[cfg(test)] mod tests { use super::*; diff --git a/src/distr/mod.rs b/src/distr/mod.rs index 10016119ba2..a66504624bb 100644 --- a/src/distr/mod.rs +++ b/src/distr/mod.rs @@ -46,6 +46,9 @@ //! numbers of the `char` type; in contrast [`StandardUniform`] may sample any valid //! `char`. //! +//! There's also an [`Alphabetic`] distribution which acts similarly to [`Alphanumeric`] but +//! doesn't include digits. +//! //! For floats (`f32`, `f64`), [`StandardUniform`] samples from `[0, 1)`. Also //! provided are [`Open01`] (samples from `(0, 1)`) and [`OpenClosed01`] //! (samples from `(0, 1]`). No option is provided to sample from `[0, 1]`; it @@ -104,7 +107,7 @@ pub use self::bernoulli::{Bernoulli, BernoulliError}; pub use self::distribution::SampleString; pub use self::distribution::{Distribution, Iter, Map}; pub use self::float::{Open01, OpenClosed01}; -pub use self::other::Alphanumeric; +pub use self::other::{Alphabetic, Alphanumeric}; #[doc(inline)] pub use self::uniform::Uniform; @@ -126,7 +129,8 @@ use crate::Rng; /// code points in the range `0...0x10_FFFF`, except for the range /// `0xD800...0xDFFF` (the surrogate code points). This includes /// unassigned/reserved code points. -/// For some uses, the [`Alphanumeric`] distribution will be more appropriate. +/// For some uses, the [`Alphanumeric`] or [`Alphabetic`] distribution will be more +/// appropriate. /// * `bool` samples `false` or `true`, each with probability 0.5. /// * Floating point types (`f32` and `f64`) are uniformly distributed in the /// half-open range `[0, 1)`. See also the [notes below](#floating-point-implementation). diff --git a/src/distr/other.rs b/src/distr/other.rs index 9890bdafe6d..47b99323d6b 100644 --- a/src/distr/other.rs +++ b/src/distr/other.rs @@ -10,6 +10,7 @@ #[cfg(feature = "alloc")] use alloc::string::String; +use core::array; use core::char; use core::num::Wrapping; @@ -18,7 +19,6 @@ use crate::distr::SampleString; use crate::distr::{Distribution, StandardUniform, Uniform}; use crate::Rng; -use core::mem::{self, MaybeUninit}; #[cfg(feature = "simd_support")] use core::simd::prelude::*; #[cfg(feature = "simd_support")] @@ -70,6 +70,35 @@ use serde::{Deserialize, Serialize}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Alphanumeric; +/// Sample a [`u8`], uniformly distributed over letters: +/// a-z and A-Z. +/// +/// # Example +/// +/// You're able to generate random Alphabetic characters via mapping or via the +/// [`SampleString::sample_string`] method like so: +/// +/// ``` +/// use rand::Rng; +/// use rand::distr::{Alphabetic, SampleString}; +/// +/// // Manual mapping +/// let mut rng = rand::rng(); +/// let chars: String = (0..7).map(|_| rng.sample(Alphabetic) as char).collect(); +/// println!("Random chars: {}", chars); +/// +/// // Using [`SampleString::sample_string`] +/// let string = Alphabetic.sample_string(&mut rand::rng(), 16); +/// println!("Random string: {}", string); +/// ``` +/// +/// # Passwords +/// +/// Refer to [`Alphanumeric#Passwords`]. +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Alphabetic; + // ----- Implementations of distributions ----- impl Distribution for StandardUniform { @@ -89,6 +118,7 @@ impl Distribution for StandardUniform { if n <= 0xDFFF { n -= GAP_SIZE; } + // SAFETY: We ensure above that `n` represents a `char`. unsafe { char::from_u32_unchecked(n) } } } @@ -123,11 +153,41 @@ impl Distribution for Alphanumeric { } } +impl Distribution for Alphabetic { + fn sample(&self, rng: &mut R) -> u8 { + const RANGE: u8 = 26 + 26; + + let offset = rng.random_range(0..RANGE) + b'A'; + + // Account for upper-cases + offset + (offset > b'Z') as u8 * (b'a' - b'Z' - 1) + } +} + #[cfg(feature = "alloc")] impl SampleString for Alphanumeric { fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + // SAFETY: `self` only samples alphanumeric characters, which are valid UTF-8. + unsafe { + let v = string.as_mut_vec(); + v.extend( + self.sample_iter(rng) + .take(len) + .inspect(|b| debug_assert!(b.is_ascii_alphanumeric())), + ); + } + } +} + +#[cfg(feature = "alloc")] +impl SampleString for Alphabetic { + fn append_string(&self, rng: &mut R, string: &mut String, len: usize) { + // SAFETY: With this distribution we guarantee that we're working with valid ASCII + // characters. + // See [#1590](https://github.com/rust-random/rand/issues/1590). unsafe { let v = string.as_mut_vec(); + v.reserve_exact(len); v.extend(self.sample_iter(rng).take(len)); } } @@ -237,14 +297,8 @@ where StandardUniform: Distribution, { #[inline] - fn sample(&self, _rng: &mut R) -> [T; N] { - let mut buff: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; - - for elem in &mut buff { - *elem = MaybeUninit::new(_rng.random()); - } - - unsafe { mem::transmute_copy::<_, _>(&buff) } + fn sample(&self, rng: &mut R) -> [T; N] { + array::from_fn(|_| rng.random()) } } @@ -300,6 +354,20 @@ mod tests { assert!(!incorrect); } + #[test] + fn test_alphabetic() { + let mut rng = crate::test::rng(806); + + // Test by generating a relatively large number of chars, so we also + // take the rejection sampling path. + let mut incorrect = false; + for _ in 0..100 { + let c: char = rng.sample(Alphabetic).into(); + incorrect |= !c.is_ascii_alphabetic(); + } + assert!(!incorrect); + } + #[test] fn value_stability() { fn test_samples>( @@ -327,6 +395,7 @@ mod tests { ], ); test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); + test_samples(&Alphabetic, 0, &[97, 102, 89, 116, 75]); test_samples(&StandardUniform, false, &[true, true, false, true, false]); test_samples( &StandardUniform, diff --git a/src/distr/uniform_float.rs b/src/distr/uniform_float.rs index adcc7b710d6..e9b0421aaf0 100644 --- a/src/distr/uniform_float.rs +++ b/src/distr/uniform_float.rs @@ -219,14 +219,14 @@ uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 mod tests { use super::*; use crate::distr::{utils::FloatSIMDScalarUtils, Uniform}; - use crate::rngs::mock::StepRng; + use crate::test::{const_rng, step_rng}; #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_floats() { let mut rng = crate::test::rng(252); - let mut zero_rng = StepRng::new(0, 0); - let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0); + let mut zero_rng = const_rng(0); + let mut max_rng = const_rng(0xffff_ffff_ffff_ffff); macro_rules! t { ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{ let v: &[($f_scalar, $f_scalar)] = &[ @@ -248,31 +248,34 @@ mod tests { let my_uniform = Uniform::new(low, high).unwrap(); let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); for _ in 0..100 { - let v = rng.sample(my_uniform).extract(lane); + let v = rng.sample(my_uniform).extract_lane(lane); assert!(low_scalar <= v && v <= high_scalar); - let v = rng.sample(my_incl_uniform).extract(lane); + let v = rng.sample(my_incl_uniform).extract_lane(lane); assert!(low_scalar <= v && v <= high_scalar); let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) .unwrap() - .extract(lane); + .extract_lane(lane); assert!(low_scalar <= v && v <= high_scalar); let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive( low, high, &mut rng, ) .unwrap() - .extract(lane); + .extract_lane(lane); assert!(low_scalar <= v && v <= high_scalar); } assert_eq!( rng.sample(Uniform::new_inclusive(low, low).unwrap()) - .extract(lane), + .extract_lane(lane), low_scalar ); - assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); - assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); + assert_eq!(zero_rng.sample(my_uniform).extract_lane(lane), low_scalar); + assert_eq!( + zero_rng.sample(my_incl_uniform).extract_lane(lane), + low_scalar + ); assert_eq!( <$ty as SampleUniform>::Sampler::sample_single( low, @@ -280,7 +283,7 @@ mod tests { &mut zero_rng ) .unwrap() - .extract(lane), + .extract_lane(lane), low_scalar ); assert_eq!( @@ -290,12 +293,12 @@ mod tests { &mut zero_rng ) .unwrap() - .extract(lane), + .extract_lane(lane), low_scalar ); - assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar); - assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); + assert!(max_rng.sample(my_uniform).extract_lane(lane) <= high_scalar); + assert!(max_rng.sample(my_incl_uniform).extract_lane(lane) <= high_scalar); // sample_single cannot cope with max_rng: // assert!(<$ty as SampleUniform>::Sampler // ::sample_single(low, high, &mut max_rng).unwrap() @@ -307,7 +310,7 @@ mod tests { &mut max_rng ) .unwrap() - .extract(lane) + .extract_lane(lane) <= high_scalar ); @@ -315,10 +318,8 @@ mod tests { // since for those rounding might result in selecting high for a very // long time. if (high_scalar - low_scalar) > 0.0001 { - let mut lowering_max_rng = StepRng::new( - 0xffff_ffff_ffff_ffff, - (-1i64 << $bits_shifted) as u64, - ); + let mut lowering_max_rng = + step_rng(0xffff_ffff_ffff_ffff, (-1i64 << $bits_shifted) as u64); assert!( <$ty as SampleUniform>::Sampler::sample_single( low, @@ -326,7 +327,7 @@ mod tests { &mut lowering_max_rng ) .unwrap() - .extract(lane) + .extract_lane(lane) <= high_scalar ); } diff --git a/src/distr/utils.rs b/src/distr/utils.rs index b54dc6d6c4e..784534f48b0 100644 --- a/src/distr/utils.rs +++ b/src/distr/utils.rs @@ -236,7 +236,7 @@ pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils { type Scalar; fn replace(self, index: usize, new_value: Self::Scalar) -> Self; - fn extract(self, index: usize) -> Self::Scalar; + fn extract_lane(self, index: usize) -> Self::Scalar; } /// Implement functions on f32/f64 to give them APIs similar to SIMD types @@ -320,7 +320,7 @@ macro_rules! scalar_float_impl { } #[inline] - fn extract(self, index: usize) -> Self::Scalar { + fn extract_lane(self, index: usize) -> Self::Scalar { debug_assert_eq!(index, 0); self } @@ -395,7 +395,7 @@ macro_rules! simd_impl { } #[inline] - fn extract(self, index: usize) -> Self::Scalar { + fn extract_lane(self, index: usize) -> Self::Scalar { self.as_array()[index] } } diff --git a/src/lib.rs b/src/lib.rs index e1a9ef4ddc1..9187c9cc16a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,6 +59,7 @@ clippy::neg_cmp_op_on_partial_ord, clippy::nonminimal_bool )] +#![deny(clippy::undocumented_unsafe_blocks)] #[cfg(feature = "alloc")] extern crate alloc; @@ -96,6 +97,9 @@ macro_rules! error { ($($x:tt)*) => ( } ) } +// Re-export rand_core itself +pub use rand_core; + // Re-exports from rand_core pub use rand_core::{CryptoRng, RngCore, SeedableRng, TryCryptoRng, TryRngCore}; @@ -114,7 +118,7 @@ pub use crate::rngs::thread::rng; /// /// Use [`rand::rng()`](rng()) instead. #[cfg(feature = "thread_rng")] -#[deprecated(since = "0.9.0", note = "renamed to `rng`")] +#[deprecated(since = "0.9.0", note = "Renamed to `rng`")] #[inline] pub fn thread_rng() -> crate::rngs::ThreadRng { rng() @@ -305,6 +309,34 @@ mod test { rand_pcg::Pcg32::new(seed, INC) } + /// Construct a generator yielding a constant value + pub fn const_rng(x: u64) -> StepRng { + StepRng(x, 0) + } + + /// Construct a generator yielding an arithmetic sequence + pub fn step_rng(x: u64, increment: u64) -> StepRng { + StepRng(x, increment) + } + + #[derive(Clone)] + pub struct StepRng(u64, u64); + impl RngCore for StepRng { + fn next_u32(&mut self) -> u32 { + self.next_u64() as u32 + } + + fn next_u64(&mut self) -> u64 { + let res = self.0; + self.0 = self.0.wrapping_add(self.1); + res + } + + fn fill_bytes(&mut self, dst: &mut [u8]) { + rand_core::impls::fill_bytes_via_next(self, dst) + } + } + #[test] #[cfg(feature = "thread_rng")] fn test_random() { diff --git a/src/rng.rs b/src/rng.rs index 258c87de273..c502e1ba476 100644 --- a/src/rng.rs +++ b/src/rng.rs @@ -12,8 +12,8 @@ use crate::distr::uniform::{SampleRange, SampleUniform}; use crate::distr::{self, Distribution, StandardUniform}; use core::num::Wrapping; +use core::{mem, slice}; use rand_core::RngCore; -use zerocopy::IntoBytes; /// User-level interface for RNGs /// @@ -110,11 +110,11 @@ pub trait Rng: RngCore { /// # Example /// /// ``` - /// use rand::{rngs::mock::StepRng, Rng}; + /// use rand::{rngs::SmallRng, Rng, SeedableRng}; /// - /// let rng = StepRng::new(1, 1); + /// let rng = SmallRng::seed_from_u64(0); /// let v: Vec = rng.random_iter().take(5).collect(); - /// assert_eq!(&v, &[1, 2, 3, 4, 5]); + /// assert_eq!(v.len(), 5); /// ``` #[inline] fn random_iter(self) -> distr::Iter @@ -393,14 +393,36 @@ impl Fill for [u8] { } } +/// Call target for unsafe macros +const unsafe fn __unsafe() {} + +/// Implement `Fill` for given type `$t`. +/// +/// # Safety +/// All bit patterns of `[u8; size_of::<$t>()]` must represent values of `$t`. macro_rules! impl_fill { () => {}; - ($t:ty) => { + ($t:ty) => {{ + // Force caller to wrap with an `unsafe` block + __unsafe(); + impl Fill for [$t] { - #[inline(never)] // in micro benchmarks, this improves performance fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.fill_bytes(self.as_mut_bytes()); + let size = mem::size_of_val(self); + rng.fill_bytes( + // SAFETY: `self` non-null and valid for reads and writes within its `size` + // bytes. `self` meets the alignment requirements of `&mut [u8]`. + // The contents of `self` are initialized. Both `[u8]` and `[$t]` are valid + // for all bit-patterns of their contents (note that the SAFETY requirement + // on callers of this macro). `self` is not borrowed. + unsafe { + slice::from_raw_parts_mut(self.as_mut_ptr() + as *mut u8, + size + ) + } + ); for x in self { *x = x.to_le(); } @@ -409,27 +431,41 @@ macro_rules! impl_fill { } impl Fill for [Wrapping<$t>] { - #[inline(never)] fn fill(&mut self, rng: &mut R) { if self.len() > 0 { - rng.fill_bytes(self.as_mut_bytes()); + let size = self.len() * mem::size_of::<$t>(); + rng.fill_bytes( + // SAFETY: `self` non-null and valid for reads and writes within its `size` + // bytes. `self` meets the alignment requirements of `&mut [u8]`. + // The contents of `self` are initialized. Both `[u8]` and `[$t]` are valid + // for all bit-patterns of their contents (note that the SAFETY requirement + // on callers of this macro). `self` is not borrowed. + unsafe { + slice::from_raw_parts_mut(self.as_mut_ptr() + as *mut u8, + size + ) + } + ); for x in self { - *x = Wrapping(x.0.to_le()); + *x = Wrapping(x.0.to_le()); } } } - } + }} }; - ($t:ty, $($tt:ty,)*) => { + ($t:ty, $($tt:ty,)*) => {{ impl_fill!($t); // TODO: this could replace above impl once Rust #32463 is fixed // impl_fill!(Wrapping<$t>); impl_fill!($($tt,)*); - } + }} } -impl_fill!(u16, u32, u64, u128,); -impl_fill!(i8, i16, i32, i64, i128,); +// SAFETY: All bit patterns of `[u8; size_of::<$t>()]` represent values of `u*`. +const _: () = unsafe { impl_fill!(u16, u32, u64, u128,) }; +// SAFETY: All bit patterns of `[u8; size_of::<$t>()]` represent values of `i*`. +const _: () = unsafe { impl_fill!(i8, i16, i32, i64, i128,) }; impl Fill for [T; N] where @@ -443,14 +479,13 @@ where #[cfg(test)] mod test { use super::*; - use crate::rngs::mock::StepRng; - use crate::test::rng; + use crate::test::{const_rng, rng}; #[cfg(feature = "alloc")] use alloc::boxed::Box; #[test] fn test_fill_bytes_default() { - let mut r = StepRng::new(0x11_22_33_44_55_66_77_88, 0); + let mut r = const_rng(0x11_22_33_44_55_66_77_88); // check every remainder mod 8, both in small and big vectors. let lengths = [0, 1, 2, 3, 4, 5, 6, 7, 80, 81, 82, 83, 84, 85, 86, 87]; @@ -471,7 +506,7 @@ mod test { #[test] fn test_fill() { let x = 9041086907909331047; // a random u64 - let mut rng = StepRng::new(x, 0); + let mut rng = const_rng(x); // Convert to byte sequence and back to u64; byte-swap twice if BE. let mut array = [0u64; 2]; @@ -501,7 +536,7 @@ mod test { #[test] fn test_fill_empty() { let mut array = [0u32; 0]; - let mut rng = StepRng::new(0, 1); + let mut rng = rng(1); rng.fill(&mut array); rng.fill(&mut array[..]); } diff --git a/src/rngs/mock.rs b/src/rngs/mock.rs index b6da66a8565..5b6a2253b18 100644 --- a/src/rngs/mock.rs +++ b/src/rngs/mock.rs @@ -8,6 +8,8 @@ //! Mock random number generator +#![allow(deprecated)] + use rand_core::{impls, RngCore}; #[cfg(feature = "serde")] @@ -31,6 +33,7 @@ use serde::{Deserialize, Serialize}; /// # Example /// /// ``` +/// # #![allow(deprecated)] /// use rand::Rng; /// use rand::rngs::mock::StepRng; /// @@ -40,6 +43,7 @@ use serde::{Deserialize, Serialize}; /// ``` #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[deprecated(since = "0.9.2", note = "Deprecated without replacement")] pub struct StepRng { v: u64, a: u64, @@ -74,30 +78,3 @@ impl RngCore for StepRng { impls::fill_bytes_via_next(self, dst) } } - -#[cfg(test)] -mod tests { - #[cfg(any(feature = "alloc", feature = "serde"))] - use super::StepRng; - - #[test] - #[cfg(feature = "serde")] - fn test_serialization_step_rng() { - let some_rng = StepRng::new(42, 7); - let de_some_rng: StepRng = - bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap(); - assert_eq!(some_rng.v, de_some_rng.v); - assert_eq!(some_rng.a, de_some_rng.a); - } - - #[test] - #[cfg(feature = "alloc")] - fn test_bool() { - use crate::{distr::StandardUniform, Rng}; - - // If this result ever changes, update doc on StepRng! - let rng = StepRng::new(0, 1 << 31); - let result: alloc::vec::Vec = rng.sample_iter(StandardUniform).take(6).collect(); - assert_eq!(&result, &[false, true, false, true, false, true]); - } -} diff --git a/src/rngs/mod.rs b/src/rngs/mod.rs index cb7ed57f33e..8ce25759a26 100644 --- a/src/rngs/mod.rs +++ b/src/rngs/mod.rs @@ -80,6 +80,7 @@ mod reseeding; pub use reseeding::ReseedingRng; +#[deprecated(since = "0.9.2")] pub mod mock; // Public so we don't export `StepRng` directly, making it a bit // more clear it is intended for testing. diff --git a/src/rngs/reseeding.rs b/src/rngs/reseeding.rs index 570d04eeba4..69b9045e0de 100644 --- a/src/rngs/reseeding.rs +++ b/src/rngs/reseeding.rs @@ -253,15 +253,15 @@ where #[cfg(feature = "std_rng")] #[cfg(test)] mod test { - use crate::rngs::mock::StepRng; use crate::rngs::std::Core; + use crate::test::const_rng; use crate::Rng; use super::ReseedingRng; #[test] fn test_reseeding() { - let zero = StepRng::new(0, 0); + let zero = const_rng(0); let thresh = 1; // reseed every time the buffer is exhausted let mut reseeding = ReseedingRng::::new(thresh, zero).unwrap(); @@ -281,7 +281,7 @@ mod test { #[test] #[allow(clippy::redundant_clone)] fn test_clone_reseeding() { - let zero = StepRng::new(0, 0); + let zero = const_rng(0); let mut rng1 = ReseedingRng::::new(32 * 4, zero).unwrap(); let first: u32 = rng1.random(); diff --git a/src/seq/index.rs b/src/seq/index.rs index 852bdac76c4..7dd0513850c 100644 --- a/src/seq/index.rs +++ b/src/seq/index.rs @@ -282,10 +282,11 @@ where } } -/// Randomly sample exactly `amount` distinct indices from `0..length` +/// Randomly sample `amount` distinct indices from `0..length` /// -/// Results are in arbitrary order (there is no guarantee of shuffling or -/// ordering). +/// The result may contain less than `amount` indices if insufficient non-zero +/// weights are available. Results are returned in an arbitrary order (there is +/// no guarantee of shuffling or ordering). /// /// Function `weight` is called once for each index to provide weights. /// @@ -295,7 +296,6 @@ where /// /// Error cases: /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. -/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. /// /// This implementation uses `O(length + amount)` space and `O(length)` time. #[cfg(feature = "std")] @@ -328,10 +328,13 @@ where } } -/// Randomly sample exactly `amount` distinct indices from `0..length`, and -/// return them in an arbitrary order (there is no guarantee of shuffling or -/// ordering). The weights are to be provided by the input function `weights`, -/// which will be called once for each index. +/// Randomly sample `amount` distinct indices from `0..length` +/// +/// The result may contain less than `amount` indices if insufficient non-zero +/// weights are available. Results are returned in an arbitrary order (there is +/// no guarantee of shuffling or ordering). +/// +/// Function `weight` is called once for each index to provide weights. /// /// This implementation is based on the algorithm A-ExpJ as found in /// [Efraimidis and Spirakis, 2005](https://doi.org/10.1016/j.ipl.2005.11.003). @@ -339,7 +342,6 @@ where /// /// Error cases: /// - [`WeightError::InvalidWeight`] when a weight is not-a-number or negative. -/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. #[cfg(feature = "std")] fn sample_efraimidis_spirakis( rng: &mut R, @@ -403,28 +405,26 @@ where index += N::one(); } - if candidates.len() < amount.as_usize() { - return Err(WeightError::InsufficientNonZero); - } + if index < length { + let mut x = rng.random::().ln() / candidates.peek().unwrap().key; + while index < length { + let weight = weight(index.as_usize()).into(); + if weight > 0.0 { + x -= weight; + if x <= 0.0 { + let min_candidate = candidates.pop().unwrap(); + let t = (min_candidate.key * weight).exp(); + let key = rng.random_range(t..1.0).ln() / weight; + candidates.push(Element { index, key }); - let mut x = rng.random::().ln() / candidates.peek().unwrap().key; - while index < length { - let weight = weight(index.as_usize()).into(); - if weight > 0.0 { - x -= weight; - if x <= 0.0 { - let min_candidate = candidates.pop().unwrap(); - let t = (min_candidate.key * weight).exp(); - let key = rng.random_range(t..1.0).ln() / weight; - candidates.push(Element { index, key }); - - x = rng.random::().ln() / candidates.peek().unwrap().key; + x = rng.random::().ln() / candidates.peek().unwrap().key; + } + } else if !(weight >= 0.0) { + return Err(WeightError::InvalidWeight); } - } else if !(weight >= 0.0) { - return Err(WeightError::InvalidWeight); - } - index += N::one(); + index += N::one(); + } } Ok(IndexVec::from( @@ -653,7 +653,7 @@ mod test { } let r = sample_weighted(&mut seed_rng(423), 10, |i| i as f64, 10); - assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); + assert_eq!(r.unwrap().len(), 9); } #[test] diff --git a/src/seq/iterator.rs b/src/seq/iterator.rs index b10d205676a..a9a9e56155c 100644 --- a/src/seq/iterator.rs +++ b/src/seq/iterator.rs @@ -134,6 +134,10 @@ pub trait IteratorRandom: Iterator + Sized { /// force every element to be created regardless call `.inspect(|e| ())`. /// /// [`choose`]: IteratorRandom::choose + // + // Clippy is wrong here: we need to iterate over all entries with the RNG to + // ensure that choosing is *stable*. + #[allow(clippy::double_ended_iterator_last)] fn choose_stable(mut self, rng: &mut R) -> Option where R: Rng + ?Sized, diff --git a/src/seq/slice.rs b/src/seq/slice.rs index d48d9d2e9f3..f909418bc48 100644 --- a/src/seq/slice.rs +++ b/src/seq/slice.rs @@ -173,26 +173,18 @@ pub trait IndexedRandom: Index { /// Biased sampling of `amount` distinct elements /// - /// Similar to [`choose_multiple`], but where the likelihood of each element's - /// inclusion in the output may be specified. The elements are returned in an - /// arbitrary, unspecified order. + /// Similar to [`choose_multiple`], but where the likelihood of each + /// element's inclusion in the output may be specified. Zero-weighted + /// elements are never returned; the result may therefore contain fewer + /// elements than `amount` even when `self.len() >= amount`. The elements + /// are returned in an arbitrary, unspecified order. /// /// The specified function `weight` maps each item `x` to a relative /// likelihood `weight(x)`. The probability of each item being selected is /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`. /// - /// If all of the weights are equal, even if they are all zero, each element has - /// an equal likelihood of being selected. - /// - /// This implementation uses `O(length + amount)` space and `O(length)` time - /// if the "nightly" feature is enabled, or `O(length)` space and - /// `O(length + amount * log length)` time otherwise. - /// - /// # Known issues - /// - /// The algorithm currently used to implement this method loses accuracy - /// when small values are used for weights. - /// See [#1476](https://github.com/rust-random/rand/issues/1476). + /// This implementation uses `O(length + amount)` space and `O(length)` time. + /// See [`index::sample_weighted`] for details. /// /// # Example /// @@ -687,7 +679,7 @@ mod test { // Case 2: All of the weights are 0 let choices = [('a', 0), ('b', 0), ('c', 0)]; let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1); - assert_eq!(r.unwrap_err(), WeightError::InsufficientNonZero); + assert_eq!(r.unwrap().len(), 0); // Case 3: Negative weights let choices = [('a', -1), ('b', 1), ('c', 1)]; diff --git a/utils/ziggurat_tables.py b/utils/ziggurat_tables.py deleted file mode 100755 index 87a766ccc36..00000000000 --- a/utils/ziggurat_tables.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python -# -# Copyright 2018 Developers of the Rand project. -# Copyright 2013 The Rust Project Developers. -# -# Licensed under the Apache License, Version 2.0 or the MIT license -# , at your -# option. This file may not be copied, modified, or distributed -# except according to those terms. - -# This creates the tables used for distributions implemented using the -# ziggurat algorithm in `rand::distr;`. They are -# (basically) the tables as used in the ZIGNOR variant (Doornik 2005). -# They are changed rarely, so the generated file should be checked in -# to git. -# -# It creates 3 tables: X as in the paper, F which is f(x_i), and -# F_DIFF which is f(x_i) - f(x_{i-1}). The latter two are just cached -# values which is not done in that paper (but is done in other -# variants). Note that the adZigR table is unnecessary because of -# algebra. -# -# It is designed to be compatible with Python 2 and 3. - -from math import exp, sqrt, log, floor -import random - -# The order should match the return value of `tables` -TABLE_NAMES = ['X', 'F'] - -# The actual length of the table is 1 more, to stop -# index-out-of-bounds errors. This should match the bitwise operation -# to find `i` in `zigurrat` in `libstd/rand/mod.rs`. Also the *_R and -# *_V constants below depend on this value. -TABLE_LEN = 256 - -# equivalent to `zigNorInit` in Doornik2005, but generalised to any -# distribution. r = dR, v = dV, f = probability density function, -# f_inv = inverse of f -def tables(r, v, f, f_inv): - # compute the x_i - xvec = [0]*(TABLE_LEN+1) - - xvec[0] = v / f(r) - xvec[1] = r - - for i in range(2, TABLE_LEN): - last = xvec[i-1] - xvec[i] = f_inv(v / last + f(last)) - - # cache the f's - fvec = [0]*(TABLE_LEN+1) - for i in range(TABLE_LEN+1): - fvec[i] = f(xvec[i]) - - return xvec, fvec - -# Distributions -# N(0, 1) -def norm_f(x): - return exp(-x*x/2.0) -def norm_f_inv(y): - return sqrt(-2.0*log(y)) - -NORM_R = 3.6541528853610088 -NORM_V = 0.00492867323399 - -NORM = tables(NORM_R, NORM_V, - norm_f, norm_f_inv) - -# Exp(1) -def exp_f(x): - return exp(-x) -def exp_f_inv(y): - return -log(y) - -EXP_R = 7.69711747013104972 -EXP_V = 0.0039496598225815571993 - -EXP = tables(EXP_R, EXP_V, - exp_f, exp_f_inv) - - -# Output the tables/constants/types - -def render_static(name, type, value): - # no space or - return 'pub static %s: %s =%s;\n' % (name, type, value) - -# static `name`: [`type`, .. `len(values)`] = -# [values[0], ..., values[3], -# values[4], ..., values[7], -# ... ]; -def render_table(name, values): - rows = [] - # 4 values on each row - for i in range(0, len(values), 4): - row = values[i:i+4] - rows.append(', '.join('%.18f' % f for f in row)) - - rendered = '\n [%s]' % ',\n '.join(rows) - return render_static(name, '[f64, .. %d]' % len(values), rendered) - - -with open('ziggurat_tables.rs', 'w') as f: - f.write('''// Copyright 2018 Developers of the Rand project. -// Copyright 2013 The Rust Project Developers. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -// Tables for distributions which are sampled using the ziggurat -// algorithm. Autogenerated by `ziggurat_tables.py`. - -pub type ZigTable = &\'static [f64, .. %d]; -''' % (TABLE_LEN + 1)) - for name, tables, r in [('NORM', NORM, NORM_R), - ('EXP', EXP, EXP_R)]: - f.write(render_static('ZIG_%s_R' % name, 'f64', ' %.18f' % r)) - for (tabname, table) in zip(TABLE_NAMES, tables): - f.write(render_table('ZIG_%s_%s' % (name, tabname), table))