// rand/seq/slice.rs

// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! `IndexedRandom`, `IndexedMutRandom`, `SliceRandom`

11use super::increasing_uniform::IncreasingUniform;
12use super::index;
13#[cfg(feature = "alloc")]
14use crate::distr::uniform::{SampleBorrow, SampleUniform};
15#[cfg(feature = "alloc")]
16use crate::distr::weighted::{Error as WeightError, Weight};
17use crate::Rng;
18use core::ops::{Index, IndexMut};
19
20/// Extension trait on indexable lists, providing random sampling methods.
21///
22/// This trait is implemented on `[T]` slice types. Other types supporting
23/// [`std::ops::Index<usize>`] may implement this (only [`Self::len`] must be
24/// specified).
25pub trait IndexedRandom: Index<usize> {
26    /// The length
27    fn len(&self) -> usize;
28
29    /// True when the length is zero
30    #[inline]
31    fn is_empty(&self) -> bool {
32        self.len() == 0
33    }
34
35    /// Uniformly sample one element
36    ///
37    /// Returns a reference to one uniformly-sampled random element of
38    /// the slice, or `None` if the slice is empty.
39    ///
40    /// For slices, complexity is `O(1)`.
41    ///
42    /// # Example
43    ///
44    /// ```
45    /// use rand::seq::IndexedRandom;
46    ///
47    /// let choices = [1, 2, 4, 8, 16, 32];
48    /// let mut rng = rand::rng();
49    /// println!("{:?}", choices.choose(&mut rng));
50    /// assert_eq!(choices[..0].choose(&mut rng), None);
51    /// ```
52    fn choose<R>(&self, rng: &mut R) -> Option<&Self::Output>
53    where
54        R: Rng + ?Sized,
55    {
56        if self.is_empty() {
57            None
58        } else {
59            Some(&self[rng.random_range(..self.len())])
60        }
61    }
62
63    /// Uniformly sample `amount` distinct elements from self
64    ///
65    /// Chooses `amount` elements from the slice at random, without repetition,
66    /// and in random order. The returned iterator is appropriate both for
67    /// collection into a `Vec` and filling an existing buffer (see example).
68    ///
69    /// In case this API is not sufficiently flexible, use [`index::sample`].
70    ///
71    /// For slices, complexity is the same as [`index::sample`].
72    ///
73    /// # Example
74    /// ```
75    /// use rand::seq::IndexedRandom;
76    ///
77    /// let mut rng = &mut rand::rng();
78    /// let sample = "Hello, audience!".as_bytes();
79    ///
80    /// // collect the results into a vector:
81    /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect();
82    ///
83    /// // store in a buffer:
84    /// let mut buf = [0u8; 5];
85    /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) {
86    ///     *slot = *b;
87    /// }
88    /// ```
89    #[cfg(feature = "alloc")]
90    fn choose_multiple<R>(
91        &self,
92        rng: &mut R,
93        amount: usize,
94    ) -> SliceChooseIter<'_, Self, Self::Output>
95    where
96        Self::Output: Sized,
97        R: Rng + ?Sized,
98    {
99        let amount = core::cmp::min(amount, self.len());
100        SliceChooseIter {
101            slice: self,
102            _phantom: Default::default(),
103            indices: index::sample(rng, self.len(), amount).into_iter(),
104        }
105    }
106
107    /// Uniformly sample a fixed-size array of distinct elements from self
108    ///
109    /// Chooses `N` elements from the slice at random, without repetition,
110    /// and in random order.
111    ///
112    /// For slices, complexity is the same as [`index::sample_array`].
113    ///
114    /// # Example
115    /// ```
116    /// use rand::seq::IndexedRandom;
117    ///
118    /// let mut rng = &mut rand::rng();
119    /// let sample = "Hello, audience!".as_bytes();
120    ///
121    /// let a: [u8; 3] = sample.choose_multiple_array(&mut rng).unwrap();
122    /// ```
123    fn choose_multiple_array<R, const N: usize>(&self, rng: &mut R) -> Option<[Self::Output; N]>
124    where
125        Self::Output: Clone + Sized,
126        R: Rng + ?Sized,
127    {
128        let indices = index::sample_array(rng, self.len())?;
129        Some(indices.map(|index| self[index].clone()))
130    }
131
132    /// Biased sampling for one element
133    ///
134    /// Returns a reference to one element of the slice, sampled according
135    /// to the provided weights. Returns `None` only if the slice is empty.
136    ///
137    /// The specified function `weight` maps each item `x` to a relative
138    /// likelihood `weight(x)`. The probability of each item being selected is
139    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
140    ///
141    /// For slices of length `n`, complexity is `O(n)`.
142    /// For more information about the underlying algorithm,
143    /// see the [`WeightedIndex`] distribution.
144    ///
145    /// See also [`choose_weighted_mut`].
146    ///
147    /// # Example
148    ///
149    /// ```
150    /// use rand::prelude::*;
151    ///
152    /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)];
153    /// let mut rng = rand::rng();
154    /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c',
155    /// // and 'd' will never be printed
156    /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0);
157    /// ```
158    /// [`choose`]: IndexedRandom::choose
159    /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut
160    /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex
161    #[cfg(feature = "alloc")]
162    fn choose_weighted<R, F, B, X>(
163        &self,
164        rng: &mut R,
165        weight: F,
166    ) -> Result<&Self::Output, WeightError>
167    where
168        R: Rng + ?Sized,
169        F: Fn(&Self::Output) -> B,
170        B: SampleBorrow<X>,
171        X: SampleUniform + Weight + PartialOrd<X>,
172    {
173        use crate::distr::{weighted::WeightedIndex, Distribution};
174        let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?;
175        Ok(&self[distr.sample(rng)])
176    }
177
178    /// Biased sampling of `amount` distinct elements
179    ///
180    /// Similar to [`choose_multiple`], but where the likelihood of each
181    /// element's inclusion in the output may be specified. Zero-weighted
182    /// elements are never returned; the result may therefore contain fewer
183    /// elements than `amount` even when `self.len() >= amount`. The elements
184    /// are returned in an arbitrary, unspecified order.
185    ///
186    /// The specified function `weight` maps each item `x` to a relative
187    /// likelihood `weight(x)`. The probability of each item being selected is
188    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
189    ///
190    /// This implementation uses `O(length + amount)` space and `O(length)` time.
191    /// See [`index::sample_weighted`] for details.
192    ///
193    /// # Example
194    ///
195    /// ```
196    /// use rand::prelude::*;
197    ///
198    /// let choices = [('a', 2), ('b', 1), ('c', 1)];
199    /// let mut rng = rand::rng();
200    /// // First Draw * Second Draw = total odds
201    /// // -----------------------
202    /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
203    /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
204    /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
205    /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
206    /// ```
207    /// [`choose_multiple`]: IndexedRandom::choose_multiple
208    // Note: this is feature-gated on std due to usage of f64::powf.
209    // If necessary, we may use alloc+libm as an alternative (see PR #1089).
210    #[cfg(feature = "std")]
211    fn choose_multiple_weighted<R, F, X>(
212        &self,
213        rng: &mut R,
214        amount: usize,
215        weight: F,
216    ) -> Result<SliceChooseIter<'_, Self, Self::Output>, WeightError>
217    where
218        Self::Output: Sized,
219        R: Rng + ?Sized,
220        F: Fn(&Self::Output) -> X,
221        X: Into<f64>,
222    {
223        let amount = core::cmp::min(amount, self.len());
224        Ok(SliceChooseIter {
225            slice: self,
226            _phantom: Default::default(),
227            indices: index::sample_weighted(
228                rng,
229                self.len(),
230                |idx| weight(&self[idx]).into(),
231                amount,
232            )?
233            .into_iter(),
234        })
235    }
236}
237
238/// Extension trait on indexable lists, providing random sampling methods.
239///
240/// This trait is implemented automatically for every type implementing
241/// [`IndexedRandom`] and [`std::ops::IndexMut<usize>`].
242pub trait IndexedMutRandom: IndexedRandom + IndexMut<usize> {
243    /// Uniformly sample one element (mut)
244    ///
245    /// Returns a mutable reference to one uniformly-sampled random element of
246    /// the slice, or `None` if the slice is empty.
247    ///
248    /// For slices, complexity is `O(1)`.
249    fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Output>
250    where
251        R: Rng + ?Sized,
252    {
253        if self.is_empty() {
254            None
255        } else {
256            let len = self.len();
257            Some(&mut self[rng.random_range(..len)])
258        }
259    }
260
261    /// Biased sampling for one element (mut)
262    ///
263    /// Returns a mutable reference to one element of the slice, sampled according
264    /// to the provided weights. Returns `None` only if the slice is empty.
265    ///
266    /// The specified function `weight` maps each item `x` to a relative
267    /// likelihood `weight(x)`. The probability of each item being selected is
268    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
269    ///
270    /// For slices of length `n`, complexity is `O(n)`.
271    /// For more information about the underlying algorithm,
272    /// see the [`WeightedIndex`] distribution.
273    ///
274    /// See also [`choose_weighted`].
275    ///
276    /// [`choose_mut`]: IndexedMutRandom::choose_mut
277    /// [`choose_weighted`]: IndexedRandom::choose_weighted
278    /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex
279    #[cfg(feature = "alloc")]
280    fn choose_weighted_mut<R, F, B, X>(
281        &mut self,
282        rng: &mut R,
283        weight: F,
284    ) -> Result<&mut Self::Output, WeightError>
285    where
286        R: Rng + ?Sized,
287        F: Fn(&Self::Output) -> B,
288        B: SampleBorrow<X>,
289        X: SampleUniform + Weight + PartialOrd<X>,
290    {
291        use crate::distr::{weighted::WeightedIndex, Distribution};
292        let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?;
293        let index = distr.sample(rng);
294        Ok(&mut self[index])
295    }
296}
297
298/// Extension trait on slices, providing shuffling methods.
299///
300/// This trait is implemented on all `[T]` slice types, providing several
301/// methods for choosing and shuffling elements. You must `use` this trait:
302///
303/// ```
304/// use rand::seq::SliceRandom;
305///
306/// let mut rng = rand::rng();
307/// let mut bytes = "Hello, random!".to_string().into_bytes();
308/// bytes.shuffle(&mut rng);
309/// let str = String::from_utf8(bytes).unwrap();
310/// println!("{}", str);
311/// ```
312/// Example output (non-deterministic):
313/// ```none
314/// l,nmroHado !le
315/// ```
316pub trait SliceRandom: IndexedMutRandom {
317    /// Shuffle a mutable slice in place.
318    ///
319    /// For slices of length `n`, complexity is `O(n)`.
320    /// The resulting permutation is picked uniformly from the set of all possible permutations.
321    ///
322    /// # Example
323    ///
324    /// ```
325    /// use rand::seq::SliceRandom;
326    ///
327    /// let mut rng = rand::rng();
328    /// let mut y = [1, 2, 3, 4, 5];
329    /// println!("Unshuffled: {:?}", y);
330    /// y.shuffle(&mut rng);
331    /// println!("Shuffled:   {:?}", y);
332    /// ```
333    fn shuffle<R>(&mut self, rng: &mut R)
334    where
335        R: Rng + ?Sized;
336
337    /// Shuffle a slice in place, but exit early.
338    ///
339    /// Returns two mutable slices from the source slice. The first contains
340    /// `amount` elements randomly permuted. The second has the remaining
341    /// elements that are not fully shuffled.
342    ///
343    /// This is an efficient method to select `amount` elements at random from
344    /// the slice, provided the slice may be mutated.
345    ///
346    /// If you only need to choose elements randomly and `amount > self.len()/2`
347    /// then you may improve performance by taking
348    /// `amount = self.len() - amount` and using only the second slice.
349    ///
350    /// If `amount` is greater than the number of elements in the slice, this
351    /// will perform a full shuffle.
352    ///
353    /// For slices, complexity is `O(m)` where `m = amount`.
354    fn partial_shuffle<R>(
355        &mut self,
356        rng: &mut R,
357        amount: usize,
358    ) -> (&mut [Self::Output], &mut [Self::Output])
359    where
360        Self::Output: Sized,
361        R: Rng + ?Sized;
362}
363
364impl<T> IndexedRandom for [T] {
365    fn len(&self) -> usize {
366        self.len()
367    }
368}
369
370impl<IR: IndexedRandom + IndexMut<usize> + ?Sized> IndexedMutRandom for IR {}
371
372impl<T> SliceRandom for [T] {
373    fn shuffle<R>(&mut self, rng: &mut R)
374    where
375        R: Rng + ?Sized,
376    {
377        if self.len() <= 1 {
378            // There is no need to shuffle an empty or single element slice
379            return;
380        }
381        self.partial_shuffle(rng, self.len());
382    }
383
384    fn partial_shuffle<R>(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T])
385    where
386        R: Rng + ?Sized,
387    {
388        let m = self.len().saturating_sub(amount);
389
390        // The algorithm below is based on Durstenfeld's algorithm for the
391        // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
392        // for an unbiased permutation.
393        // It ensures that the last `amount` elements of the slice
394        // are randomly selected from the whole slice.
395
396        // `IncreasingUniform::next_index()` is faster than `Rng::random_range`
397        // but only works for 32 bit integers
398        // So we must use the slow method if the slice is longer than that.
399        if self.len() < (u32::MAX as usize) {
400            let mut chooser = IncreasingUniform::new(rng, m as u32);
401            for i in m..self.len() {
402                let index = chooser.next_index();
403                self.swap(i, index);
404            }
405        } else {
406            for i in m..self.len() {
407                let index = rng.random_range(..i + 1);
408                self.swap(i, index);
409            }
410        }
411        let r = self.split_at_mut(m);
412        (r.1, r.0)
413    }
414}
415
/// An iterator over multiple slice elements.
///
/// This struct is created by [`IndexedRandom::choose_multiple`] and, when the
/// `std` feature is enabled, by `IndexedRandom::choose_multiple_weighted`.
/// It yields one reference into the borrowed list per sampled index.
#[cfg(feature = "alloc")]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> {
    /// The list the sampled indices refer to.
    slice: &'a S,
    /// Ties the element type `T` (yielded as `&'a T`) to the struct.
    _phantom: core::marker::PhantomData<T>,
    /// Sampled indices into `slice` that have not been yielded yet.
    indices: index::IndexVecIntoIter,
}
427
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    /// Yields the element at the next sampled index, or `None` once all
    /// indices have been consumed.
    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        let idx = self.indices.next()?;
        Some(&self.slice[idx])
    }

    /// Exact: one item is produced per remaining index.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
441
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    /// Number of elements still to be yielded: one per remaining index.
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.indices)
    }
}
450
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "alloc")]
    use alloc::vec::Vec;

    #[test]
    fn test_slice_choose() {
        let mut r = crate::test::rng(107);
        let chars = [
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        ];
        let mut chosen = [0i32; 14];
        // The below all use a binomial distribution with n=1000, p=1/14.
        // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
        for _ in 0..1000 {
            let picked = *chars.choose(&mut r).unwrap();
            chosen[(picked as usize) - ('a' as usize)] += 1;
        }
        for count in chosen.iter() {
            assert!(40 < *count && *count < 106, "count: {}", count);
        }

        chosen.iter_mut().for_each(|x| *x = 0);
        for _ in 0..1000 {
            *chosen.choose_mut(&mut r).unwrap() += 1;
        }
        for count in chosen.iter() {
            assert!(40 < *count && *count < 106, "count: {}", count);
        }

        let mut v: [isize; 0] = [];
        assert_eq!(v.choose(&mut r), None);
        assert_eq!(v.choose_mut(&mut r), None);
    }

    #[test]
    fn value_stability_slice() {
        let mut r = crate::test::rng(413);
        let chars = [
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        ];
        let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];

        assert_eq!(chars.choose(&mut r), Some(&'l'));
        assert_eq!(nums.choose_mut(&mut r), Some(&mut 3));

        assert_eq!(
            &chars.choose_multiple_array(&mut r),
            &Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'])
        );

        #[cfg(feature = "alloc")]
        assert_eq!(
            &chars
                .choose_multiple(&mut r, 8)
                .cloned()
                .collect::<Vec<char>>(),
            &['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f']
        );

        #[cfg(feature = "alloc")]
        assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i'));
        #[cfg(feature = "alloc")]
        assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2));

        let mut r = crate::test::rng(414);
        nums.shuffle(&mut r);
        assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]);
        nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let res = nums.partial_shuffle(&mut r, 6);
        assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]);
        assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_shuffle() {
        let mut r = crate::test::rng(108);
        let empty: &mut [isize] = &mut [];
        empty.shuffle(&mut r);
        let mut one = [1];
        one.shuffle(&mut r);
        let b: &[_] = &[1];
        assert_eq!(one, b);

        let mut two = [1, 2];
        two.shuffle(&mut r);
        assert!(two == [1, 2] || two == [2, 1]);

        let mut counts = [0i32; 24];
        for _ in 0..10000 {
            let mut arr: [usize; 4] = [0, 1, 2, 3];
            arr.shuffle(&mut r);
            let mut permutation = 0usize;
            let mut pos_value = counts.len();
            for i in 0..4 {
                pos_value /= 4 - i;
                let pos = arr.iter().position(|&x| x == i).unwrap();
                assert!(pos < (4 - i));
                permutation += pos * pos_value;
                // Move value `i` to the end, keeping the order of the others.
                arr[pos..].rotate_left(1);
                assert_eq!(arr[3], i);
            }
            for (i, &a) in arr.iter().enumerate() {
                assert_eq!(a, i);
            }
            counts[permutation] += 1;
        }
        for count in counts.iter() {
            // Binomial(10000, 1/24) with average 416.667
            // Octave: binocdf(n, 10000, 1/24)
            // 99.9% chance samples lie within this range:
            assert!(352 <= *count && *count <= 483, "count: {}", count);
        }
    }

    #[test]
    fn test_partial_shuffle() {
        let mut r = crate::test::rng(118);

        let mut empty: [u32; 0] = [];
        let res = empty.partial_shuffle(&mut r, 10);
        assert_eq!((res.0.len(), res.1.len()), (0, 0));

        let mut v = [1, 2, 3, 4, 5];
        let res = v.partial_shuffle(&mut r, 2);
        assert_eq!((res.0.len(), res.1.len()), (2, 3));
        assert!(res.0[0] != res.0[1]);
        // First elements are only modified if selected, so at least one isn't modified:
        assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3);
    }

    #[test]
    #[cfg(feature = "alloc")]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted() {
        let mut r = crate::test::rng(406);
        const N_REPS: u32 = 3000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25, "index {}: relative error {}", i, err);
            }
        };

        // choose_weighted
        fn get_weight<T>(item: &(u32, T)) -> u32 {
            item.0
        }
        let mut chosen = [0i32; 14];
        let mut items = [(0u32, 0usize); 14]; // (weight, index)
        for (i, item) in items.iter_mut().enumerate() {
            *item = (weights[i], i);
        }
        for _ in 0..N_REPS {
            let item = items.choose_weighted(&mut r, get_weight).unwrap();
            chosen[item.1] += 1;
        }
        verify(chosen);

        // choose_weighted_mut
        let mut items = [(0u32, 0i32); 14]; // (weight, count)
        for (i, item) in items.iter_mut().enumerate() {
            *item = (weights[i], 0);
        }
        for _ in 0..N_REPS {
            items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1;
        }
        for (ch, item) in chosen.iter_mut().zip(items.iter()) {
            *ch = item.1;
        }
        verify(chosen);

        // Check error cases
        let empty_slice = &mut [10][0..0];
        assert_eq!(
            empty_slice.choose_weighted(&mut r, |_| 1),
            Err(WeightError::InvalidInput)
        );
        assert_eq!(
            empty_slice.choose_weighted_mut(&mut r, |_| 1),
            Err(WeightError::InvalidInput)
        );
        assert_eq!(
            ['x'].choose_weighted_mut(&mut r, |_| 0),
            Err(WeightError::InsufficientNonZero)
        );
        assert_eq!(
            [0, -1].choose_weighted_mut(&mut r, |x| *x),
            Err(WeightError::InvalidWeight)
        );
        assert_eq!(
            [-1, 0].choose_weighted_mut(&mut r, |x| *x),
            Err(WeightError::InvalidWeight)
        );
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_multiple_weighted_edge_cases() {
        let mut rng = crate::test::rng(413);

        // Case 1: One of the weights is 0
        let choices = [('a', 2), ('b', 1), ('c', 0)];
        for _ in 0..100 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();

            assert_eq!(result.len(), 2);
            assert!(!result.iter().any(|val| val.0 == 'c'));
        }

        // Case 2: All of the weights are 0
        let choices = [('a', 0), ('b', 0), ('c', 0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap().len(), 0);

        // Case 3: Negative weights
        let choices = [('a', -1), ('b', 1), ('c', 1)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 4: Empty list
        let choices = [];
        let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0);
        assert_eq!(r.unwrap().count(), 0);

        // Case 5: NaN weights
        let choices = [('a', f64::NAN), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 6: +infinity weights
        let choices = [('a', f64::INFINITY), ('b', 1.0), ('c', 1.0)];
        for _ in 0..100 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();
            assert_eq!(result.len(), 2);
            assert!(result.iter().any(|val| val.0 == 'a'));
        }

        // Case 7: -infinity weights
        let choices = [('a', f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 8: -0 weights
        let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert!(r.is_ok());
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_multiple_weighted_distributions() {
        // The theoretical probabilities of the different outcomes are:
        // AB: 0.5   * 0.667 = 0.3333
        // AC: 0.5   * 0.333 = 0.1667
        // BA: 0.333 * 0.75  = 0.25
        // BC: 0.333 * 0.25  = 0.0833
        // CA: 0.167 * 0.6   = 0.1
        // CB: 0.167 * 0.4   = 0.0667
        let choices = [('a', 3), ('b', 2), ('c', 1)];
        let mut rng = crate::test::rng(414);

        let mut results = [0i32; 3];
        let expected_results = [5833, 2667, 1500];
        for _ in 0..10000 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();

            assert_eq!(result.len(), 2);

            match (result[0].0, result[1].0) {
                ('a', 'b') | ('b', 'a') => {
                    results[0] += 1;
                }
                ('a', 'c') | ('c', 'a') => {
                    results[1] += 1;
                }
                ('b', 'c') | ('c', 'b') => {
                    results[2] += 1;
                }
                (_, _) => panic!("unexpected result"),
            }
        }

        for (&got, &want) in results.iter().zip(&expected_results) {
            assert!(
                (got - want).abs() <= 100,
                "observed {} vs expected {}",
                got,
                want
            );
        }
    }
}