rand/seq/slice.rs
1// Copyright 2018-2023 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
//! Random sampling and shuffling for indexable sequences: `IndexedRandom`,
//! `IndexedMutRandom` and `SliceRandom`.
10
11use super::increasing_uniform::IncreasingUniform;
12use super::index;
13#[cfg(feature = "alloc")]
14use crate::distr::uniform::{SampleBorrow, SampleUniform};
15#[cfg(feature = "alloc")]
16use crate::distr::weighted::{Error as WeightError, Weight};
17use crate::Rng;
18use core::ops::{Index, IndexMut};
19
20/// Extension trait on indexable lists, providing random sampling methods.
21///
22/// This trait is implemented on `[T]` slice types. Other types supporting
23/// [`std::ops::Index<usize>`] may implement this (only [`Self::len`] must be
24/// specified).
25pub trait IndexedRandom: Index<usize> {
26 /// The length
27 fn len(&self) -> usize;
28
29 /// True when the length is zero
30 #[inline]
31 fn is_empty(&self) -> bool {
32 self.len() == 0
33 }
34
35 /// Uniformly sample one element
36 ///
37 /// Returns a reference to one uniformly-sampled random element of
38 /// the slice, or `None` if the slice is empty.
39 ///
40 /// For slices, complexity is `O(1)`.
41 ///
42 /// # Example
43 ///
44 /// ```
45 /// use rand::seq::IndexedRandom;
46 ///
47 /// let choices = [1, 2, 4, 8, 16, 32];
48 /// let mut rng = rand::rng();
49 /// println!("{:?}", choices.choose(&mut rng));
50 /// assert_eq!(choices[..0].choose(&mut rng), None);
51 /// ```
52 fn choose<R>(&self, rng: &mut R) -> Option<&Self::Output>
53 where
54 R: Rng + ?Sized,
55 {
56 if self.is_empty() {
57 None
58 } else {
59 Some(&self[rng.random_range(..self.len())])
60 }
61 }
62
63 /// Uniformly sample `amount` distinct elements from self
64 ///
65 /// Chooses `amount` elements from the slice at random, without repetition,
66 /// and in random order. The returned iterator is appropriate both for
67 /// collection into a `Vec` and filling an existing buffer (see example).
68 ///
69 /// In case this API is not sufficiently flexible, use [`index::sample`].
70 ///
71 /// For slices, complexity is the same as [`index::sample`].
72 ///
73 /// # Example
74 /// ```
75 /// use rand::seq::IndexedRandom;
76 ///
77 /// let mut rng = &mut rand::rng();
78 /// let sample = "Hello, audience!".as_bytes();
79 ///
80 /// // collect the results into a vector:
81 /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect();
82 ///
83 /// // store in a buffer:
84 /// let mut buf = [0u8; 5];
85 /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) {
86 /// *slot = *b;
87 /// }
88 /// ```
89 #[cfg(feature = "alloc")]
90 fn choose_multiple<R>(
91 &self,
92 rng: &mut R,
93 amount: usize,
94 ) -> SliceChooseIter<'_, Self, Self::Output>
95 where
96 Self::Output: Sized,
97 R: Rng + ?Sized,
98 {
99 let amount = core::cmp::min(amount, self.len());
100 SliceChooseIter {
101 slice: self,
102 _phantom: Default::default(),
103 indices: index::sample(rng, self.len(), amount).into_iter(),
104 }
105 }
106
107 /// Uniformly sample a fixed-size array of distinct elements from self
108 ///
109 /// Chooses `N` elements from the slice at random, without repetition,
110 /// and in random order.
111 ///
112 /// For slices, complexity is the same as [`index::sample_array`].
113 ///
114 /// # Example
115 /// ```
116 /// use rand::seq::IndexedRandom;
117 ///
118 /// let mut rng = &mut rand::rng();
119 /// let sample = "Hello, audience!".as_bytes();
120 ///
121 /// let a: [u8; 3] = sample.choose_multiple_array(&mut rng).unwrap();
122 /// ```
123 fn choose_multiple_array<R, const N: usize>(&self, rng: &mut R) -> Option<[Self::Output; N]>
124 where
125 Self::Output: Clone + Sized,
126 R: Rng + ?Sized,
127 {
128 let indices = index::sample_array(rng, self.len())?;
129 Some(indices.map(|index| self[index].clone()))
130 }
131
132 /// Biased sampling for one element
133 ///
134 /// Returns a reference to one element of the slice, sampled according
135 /// to the provided weights. Returns `None` only if the slice is empty.
136 ///
137 /// The specified function `weight` maps each item `x` to a relative
138 /// likelihood `weight(x)`. The probability of each item being selected is
139 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
140 ///
141 /// For slices of length `n`, complexity is `O(n)`.
142 /// For more information about the underlying algorithm,
143 /// see the [`WeightedIndex`] distribution.
144 ///
145 /// See also [`choose_weighted_mut`].
146 ///
147 /// # Example
148 ///
149 /// ```
150 /// use rand::prelude::*;
151 ///
152 /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)];
153 /// let mut rng = rand::rng();
154 /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c',
155 /// // and 'd' will never be printed
156 /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0);
157 /// ```
158 /// [`choose`]: IndexedRandom::choose
159 /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut
160 /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex
161 #[cfg(feature = "alloc")]
162 fn choose_weighted<R, F, B, X>(
163 &self,
164 rng: &mut R,
165 weight: F,
166 ) -> Result<&Self::Output, WeightError>
167 where
168 R: Rng + ?Sized,
169 F: Fn(&Self::Output) -> B,
170 B: SampleBorrow<X>,
171 X: SampleUniform + Weight + PartialOrd<X>,
172 {
173 use crate::distr::{weighted::WeightedIndex, Distribution};
174 let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?;
175 Ok(&self[distr.sample(rng)])
176 }
177
178 /// Biased sampling of `amount` distinct elements
179 ///
180 /// Similar to [`choose_multiple`], but where the likelihood of each
181 /// element's inclusion in the output may be specified. Zero-weighted
182 /// elements are never returned; the result may therefore contain fewer
183 /// elements than `amount` even when `self.len() >= amount`. The elements
184 /// are returned in an arbitrary, unspecified order.
185 ///
186 /// The specified function `weight` maps each item `x` to a relative
187 /// likelihood `weight(x)`. The probability of each item being selected is
188 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
189 ///
190 /// This implementation uses `O(length + amount)` space and `O(length)` time.
191 /// See [`index::sample_weighted`] for details.
192 ///
193 /// # Example
194 ///
195 /// ```
196 /// use rand::prelude::*;
197 ///
198 /// let choices = [('a', 2), ('b', 1), ('c', 1)];
199 /// let mut rng = rand::rng();
200 /// // First Draw * Second Draw = total odds
201 /// // -----------------------
202 /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
203 /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
204 /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
205 /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
206 /// ```
207 /// [`choose_multiple`]: IndexedRandom::choose_multiple
208 // Note: this is feature-gated on std due to usage of f64::powf.
209 // If necessary, we may use alloc+libm as an alternative (see PR #1089).
210 #[cfg(feature = "std")]
211 fn choose_multiple_weighted<R, F, X>(
212 &self,
213 rng: &mut R,
214 amount: usize,
215 weight: F,
216 ) -> Result<SliceChooseIter<'_, Self, Self::Output>, WeightError>
217 where
218 Self::Output: Sized,
219 R: Rng + ?Sized,
220 F: Fn(&Self::Output) -> X,
221 X: Into<f64>,
222 {
223 let amount = core::cmp::min(amount, self.len());
224 Ok(SliceChooseIter {
225 slice: self,
226 _phantom: Default::default(),
227 indices: index::sample_weighted(
228 rng,
229 self.len(),
230 |idx| weight(&self[idx]).into(),
231 amount,
232 )?
233 .into_iter(),
234 })
235 }
236}
237
238/// Extension trait on indexable lists, providing random sampling methods.
239///
240/// This trait is implemented automatically for every type implementing
241/// [`IndexedRandom`] and [`std::ops::IndexMut<usize>`].
242pub trait IndexedMutRandom: IndexedRandom + IndexMut<usize> {
243 /// Uniformly sample one element (mut)
244 ///
245 /// Returns a mutable reference to one uniformly-sampled random element of
246 /// the slice, or `None` if the slice is empty.
247 ///
248 /// For slices, complexity is `O(1)`.
249 fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Output>
250 where
251 R: Rng + ?Sized,
252 {
253 if self.is_empty() {
254 None
255 } else {
256 let len = self.len();
257 Some(&mut self[rng.random_range(..len)])
258 }
259 }
260
261 /// Biased sampling for one element (mut)
262 ///
263 /// Returns a mutable reference to one element of the slice, sampled according
264 /// to the provided weights. Returns `None` only if the slice is empty.
265 ///
266 /// The specified function `weight` maps each item `x` to a relative
267 /// likelihood `weight(x)`. The probability of each item being selected is
268 /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
269 ///
270 /// For slices of length `n`, complexity is `O(n)`.
271 /// For more information about the underlying algorithm,
272 /// see the [`WeightedIndex`] distribution.
273 ///
274 /// See also [`choose_weighted`].
275 ///
276 /// [`choose_mut`]: IndexedMutRandom::choose_mut
277 /// [`choose_weighted`]: IndexedRandom::choose_weighted
278 /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex
279 #[cfg(feature = "alloc")]
280 fn choose_weighted_mut<R, F, B, X>(
281 &mut self,
282 rng: &mut R,
283 weight: F,
284 ) -> Result<&mut Self::Output, WeightError>
285 where
286 R: Rng + ?Sized,
287 F: Fn(&Self::Output) -> B,
288 B: SampleBorrow<X>,
289 X: SampleUniform + Weight + PartialOrd<X>,
290 {
291 use crate::distr::{weighted::WeightedIndex, Distribution};
292 let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?;
293 let index = distr.sample(rng);
294 Ok(&mut self[index])
295 }
296}
297
298/// Extension trait on slices, providing shuffling methods.
299///
300/// This trait is implemented on all `[T]` slice types, providing several
301/// methods for choosing and shuffling elements. You must `use` this trait:
302///
303/// ```
304/// use rand::seq::SliceRandom;
305///
306/// let mut rng = rand::rng();
307/// let mut bytes = "Hello, random!".to_string().into_bytes();
308/// bytes.shuffle(&mut rng);
309/// let str = String::from_utf8(bytes).unwrap();
310/// println!("{}", str);
311/// ```
312/// Example output (non-deterministic):
313/// ```none
314/// l,nmroHado !le
315/// ```
316pub trait SliceRandom: IndexedMutRandom {
317 /// Shuffle a mutable slice in place.
318 ///
319 /// For slices of length `n`, complexity is `O(n)`.
320 /// The resulting permutation is picked uniformly from the set of all possible permutations.
321 ///
322 /// # Example
323 ///
324 /// ```
325 /// use rand::seq::SliceRandom;
326 ///
327 /// let mut rng = rand::rng();
328 /// let mut y = [1, 2, 3, 4, 5];
329 /// println!("Unshuffled: {:?}", y);
330 /// y.shuffle(&mut rng);
331 /// println!("Shuffled: {:?}", y);
332 /// ```
333 fn shuffle<R>(&mut self, rng: &mut R)
334 where
335 R: Rng + ?Sized;
336
337 /// Shuffle a slice in place, but exit early.
338 ///
339 /// Returns two mutable slices from the source slice. The first contains
340 /// `amount` elements randomly permuted. The second has the remaining
341 /// elements that are not fully shuffled.
342 ///
343 /// This is an efficient method to select `amount` elements at random from
344 /// the slice, provided the slice may be mutated.
345 ///
346 /// If you only need to choose elements randomly and `amount > self.len()/2`
347 /// then you may improve performance by taking
348 /// `amount = self.len() - amount` and using only the second slice.
349 ///
350 /// If `amount` is greater than the number of elements in the slice, this
351 /// will perform a full shuffle.
352 ///
353 /// For slices, complexity is `O(m)` where `m = amount`.
354 fn partial_shuffle<R>(
355 &mut self,
356 rng: &mut R,
357 amount: usize,
358 ) -> (&mut [Self::Output], &mut [Self::Output])
359 where
360 Self::Output: Sized,
361 R: Rng + ?Sized;
362}
363
364impl<T> IndexedRandom for [T] {
365 fn len(&self) -> usize {
366 self.len()
367 }
368}
369
370impl<IR: IndexedRandom + IndexMut<usize> + ?Sized> IndexedMutRandom for IR {}
371
372impl<T> SliceRandom for [T] {
373 fn shuffle<R>(&mut self, rng: &mut R)
374 where
375 R: Rng + ?Sized,
376 {
377 if self.len() <= 1 {
378 // There is no need to shuffle an empty or single element slice
379 return;
380 }
381 self.partial_shuffle(rng, self.len());
382 }
383
384 fn partial_shuffle<R>(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T])
385 where
386 R: Rng + ?Sized,
387 {
388 let m = self.len().saturating_sub(amount);
389
390 // The algorithm below is based on Durstenfeld's algorithm for the
391 // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
392 // for an unbiased permutation.
393 // It ensures that the last `amount` elements of the slice
394 // are randomly selected from the whole slice.
395
396 // `IncreasingUniform::next_index()` is faster than `Rng::random_range`
397 // but only works for 32 bit integers
398 // So we must use the slow method if the slice is longer than that.
399 if self.len() < (u32::MAX as usize) {
400 let mut chooser = IncreasingUniform::new(rng, m as u32);
401 for i in m..self.len() {
402 let index = chooser.next_index();
403 self.swap(i, index);
404 }
405 } else {
406 for i in m..self.len() {
407 let index = rng.random_range(..i + 1);
408 self.swap(i, index);
409 }
410 }
411 let r = self.split_at_mut(m);
412 (r.1, r.0)
413 }
414}
415
/// An iterator over multiple slice elements.
///
/// Yields references into the underlying slice at a pre-sampled sequence of
/// indices. This struct is created by [`IndexedRandom::choose_multiple`]
/// (and, when the `std` feature is enabled, by
/// `IndexedRandom::choose_multiple_weighted`).
#[cfg(feature = "alloc")]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> {
    /// The indexable collection the yielded references point into.
    slice: &'a S,
    /// Records the element type `T`; no value of `T` is stored.
    _phantom: core::marker::PhantomData<T>,
    /// Sampled indices not yet yielded, consumed front to back.
    indices: index::IndexVecIntoIter,
}
427
/// Yields one element reference per sampled index, in sampling order.
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        let idx = self.indices.next()?;
        Some(&self.slice[idx])
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: exactly one element remains per unconsumed index.
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
441
/// The remaining length is known exactly: it is the number of sampled
/// indices not yet consumed.
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    fn len(&self) -> usize {
        self.indices.len()
    }
}
450
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "alloc")]
    use alloc::vec::Vec;

    // Statistical check: `choose` and `choose_mut` hit every element with
    // roughly equal frequency, and both return `None` on an empty array.
    #[test]
    fn test_slice_choose() {
        let mut r = crate::test::rng(107);
        let chars = [
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        ];
        let mut chosen = [0i32; 14];
        // The below all use a binomial distribution with n=1000, p=1/14.
        // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
        for _ in 0..1000 {
            let picked = *chars.choose(&mut r).unwrap();
            // Map 'a'..='n' onto counter indices 0..14.
            chosen[(picked as usize) - ('a' as usize)] += 1;
        }
        for count in chosen.iter() {
            assert!(40 < *count && *count < 106);
        }

        // Reset the counters and reuse them as the sample space for
        // `choose_mut`: each draw increments the selected counter in place.
        chosen.iter_mut().for_each(|x| *x = 0);
        for _ in 0..1000 {
            *chosen.choose_mut(&mut r).unwrap() += 1;
        }
        for count in chosen.iter() {
            assert!(40 < *count && *count < 106);
        }

        // Empty input: there is no element to return.
        let mut v: [isize; 0] = [];
        assert_eq!(v.choose(&mut r), None);
        assert_eq!(v.choose_mut(&mut r), None);
    }

    // Value-stability test: pins the exact outputs produced from fixed seeds,
    // so that any change to the sampling algorithms (or to how much
    // randomness they consume) is detected. All calls before the re-seed draw
    // from the same `r`, so their order is significant.
    #[test]
    fn value_stability_slice() {
        let mut r = crate::test::rng(413);
        let chars = [
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        ];
        let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];

        assert_eq!(chars.choose(&mut r), Some(&'l'));
        assert_eq!(nums.choose_mut(&mut r), Some(&mut 3));

        assert_eq!(
            &chars.choose_multiple_array(&mut r),
            &Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'])
        );

        #[cfg(feature = "alloc")]
        assert_eq!(
            &chars
                .choose_multiple(&mut r, 8)
                .cloned()
                .collect::<Vec<char>>(),
            &['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f']
        );

        #[cfg(feature = "alloc")]
        assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i'));
        #[cfg(feature = "alloc")]
        assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2));

        // Shuffling is pinned against a fresh seed, independent of the draws
        // above.
        let mut r = crate::test::rng(414);
        nums.shuffle(&mut r);
        assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]);
        nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let res = nums.partial_shuffle(&mut r, 6);
        assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]);
        assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]);
    }

    // Checks trivial shuffles (empty, one and two elements). It then checks
    // that all 4! = 24 permutations of a 4-element array occur with roughly
    // equal frequency.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_shuffle() {
        let mut r = crate::test::rng(108);
        let empty: &mut [isize] = &mut [];
        empty.shuffle(&mut r);
        let mut one = [1];
        one.shuffle(&mut r);
        let b: &[_] = &[1];
        assert_eq!(one, b);

        let mut two = [1, 2];
        two.shuffle(&mut r);
        assert!(two == [1, 2] || two == [2, 1]);

        // Moves `slice[pos]` to the end, shifting the later elements one
        // step to the left.
        fn move_last(slice: &mut [usize], pos: usize) {
            // use slice[pos..].rotate_left(1); once we can use that
            let last_val = slice[pos];
            for i in pos..slice.len() - 1 {
                slice[i] = slice[i + 1];
            }
            *slice.last_mut().unwrap() = last_val;
        }
        let mut counts = [0i32; 24];
        for _ in 0..10000 {
            let mut arr: [usize; 4] = [0, 1, 2, 3];
            arr.shuffle(&mut r);
            // Encode the shuffled arrangement as a unique index in 0..24,
            // using mixed-radix digits with place values 6, 2, 1, 1.
            // The digit for value `i` is its position among the elements not
            // yet moved to the back. Moving each found value to the end
            // also restores `arr` to [0, 1, 2, 3], which is verified below.
            let mut permutation = 0usize;
            let mut pos_value = counts.len();
            for i in 0..4 {
                pos_value /= 4 - i;
                let pos = arr.iter().position(|&x| x == i).unwrap();
                assert!(pos < (4 - i));
                permutation += pos * pos_value;
                move_last(&mut arr, pos);
                assert_eq!(arr[3], i);
            }
            for (i, &a) in arr.iter().enumerate() {
                assert_eq!(a, i);
            }
            counts[permutation] += 1;
        }
        for count in counts.iter() {
            // Binomial(10000, 1/24) with average 416.667
            // Octave: binocdf(n, 10000, 1/24)
            // 99.9% chance samples lie within this range:
            assert!(352 <= *count && *count <= 483, "count: {}", count);
        }
    }

    // Checks the sizes of the two halves returned by `partial_shuffle`.
    // An oversized `amount` on an empty slice is also covered.
    #[test]
    fn test_partial_shuffle() {
        let mut r = crate::test::rng(118);

        let mut empty: [u32; 0] = [];
        let res = empty.partial_shuffle(&mut r, 10);
        assert_eq!((res.0.len(), res.1.len()), (0, 0));

        let mut v = [1, 2, 3, 4, 5];
        let res = v.partial_shuffle(&mut r, 2);
        assert_eq!((res.0.len(), res.1.len()), (2, 3));
        assert!(res.0[0] != res.0[1]);
        // First elements are only modified if selected, so at least one isn't modified:
        assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3);
    }

    // Statistical check of `choose_weighted` / `choose_weighted_mut`
    // frequencies against the given weights, plus their error cases.
    #[test]
    #[cfg(feature = "alloc")]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted() {
        let mut r = crate::test::rng(406);
        const N_REPS: u32 = 3000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each count must be within 25% relative error of its expected value.
        // For the zero-weight entry `exp` is 0: a count of 0 gives `err == 0`
        // and skips the division. Any other count divides by zero, giving
        // infinity, and fails the assertion.
        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        // choose_weighted
        fn get_weight<T>(item: &(u32, T)) -> u32 {
            item.0
        }
        let mut chosen = [0i32; 14];
        let mut items = [(0u32, 0usize); 14]; // (weight, index)
        for (i, item) in items.iter_mut().enumerate() {
            *item = (weights[i], i);
        }
        for _ in 0..N_REPS {
            let item = items.choose_weighted(&mut r, get_weight).unwrap();
            chosen[item.1] += 1;
        }
        verify(chosen);

        // choose_weighted_mut
        let mut items = [(0u32, 0i32); 14]; // (weight, count)
        for (i, item) in items.iter_mut().enumerate() {
            *item = (weights[i], 0);
        }
        for _ in 0..N_REPS {
            items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1;
        }
        // Copy the in-place counts out so `verify` can check them.
        for (ch, item) in chosen.iter_mut().zip(items.iter()) {
            *ch = item.1;
        }
        verify(chosen);

        // Check error cases
        let empty_slice = &mut [10][0..0];
        assert_eq!(
            empty_slice.choose_weighted(&mut r, |_| 1),
            Err(WeightError::InvalidInput)
        );
        assert_eq!(
            empty_slice.choose_weighted_mut(&mut r, |_| 1),
            Err(WeightError::InvalidInput)
        );
        assert_eq!(
            ['x'].choose_weighted_mut(&mut r, |_| 0),
            Err(WeightError::InsufficientNonZero)
        );
        // A negative weight is rejected wherever it appears.
        assert_eq!(
            [0, -1].choose_weighted_mut(&mut r, |x| *x),
            Err(WeightError::InvalidWeight)
        );
        assert_eq!(
            [-1, 0].choose_weighted_mut(&mut r, |x| *x),
            Err(WeightError::InvalidWeight)
        );
    }

    // Edge cases for `choose_multiple_weighted`:
    // - zero weights are excluded;
    // - invalid weights (negative, NaN, -inf) are rejected;
    // - +inf is always chosen;
    // - -0.0 is accepted.
    #[test]
    #[cfg(feature = "std")]
    fn test_multiple_weighted_edge_cases() {
        use super::*;

        let mut rng = crate::test::rng(413);

        // Case 1: One of the weights is 0
        let choices = [('a', 2), ('b', 1), ('c', 0)];
        for _ in 0..100 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();

            assert_eq!(result.len(), 2);
            assert!(!result.iter().any(|val| val.0 == 'c'));
        }

        // Case 2: All of the weights are 0
        let choices = [('a', 0), ('b', 0), ('c', 0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap().len(), 0);

        // Case 3: Negative weights
        let choices = [('a', -1), ('b', 1), ('c', 1)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 4: Empty list
        let choices = [];
        let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0);
        assert_eq!(r.unwrap().count(), 0);

        // Case 5: NaN weights
        let choices = [('a', f64::NAN), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 6: +infinity weights
        let choices = [('a', f64::INFINITY), ('b', 1.0), ('c', 1.0)];
        for _ in 0..100 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();
            assert_eq!(result.len(), 2);
            assert!(result.iter().any(|val| val.0 == 'a'));
        }

        // Case 7: -infinity weights
        let choices = [('a', f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 8: -0 weights
        let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert!(r.is_ok());
    }

    // Statistical check that `choose_multiple_weighted` matches successive
    // weighted draws without replacement.
    #[test]
    #[cfg(feature = "std")]
    fn test_multiple_weighted_distributions() {
        use super::*;

        // The theoretical probabilities of the different outcomes are:
        // AB: 0.5 * 0.667 = 0.3333
        // AC: 0.5 * 0.333 = 0.1667
        // BA: 0.333 * 0.75 = 0.25
        // BC: 0.333 * 0.25 = 0.0833
        // CA: 0.167 * 0.6 = 0.1
        // CB: 0.167 * 0.4 = 0.0667
        let choices = [('a', 3), ('b', 2), ('c', 1)];
        let mut rng = crate::test::rng(414);

        // Output order is unspecified, so outcomes are tallied as unordered
        // pairs: {a,b} = AB + BA, {a,c} = AC + CA, {b,c} = BC + CB, each
        // scaled to 10000 trials.
        let mut results = [0i32; 3];
        let expected_results = [5833, 2667, 1500];
        for _ in 0..10000 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();

            assert_eq!(result.len(), 2);

            match (result[0].0, result[1].0) {
                ('a', 'b') | ('b', 'a') => {
                    results[0] += 1;
                }
                ('a', 'c') | ('c', 'a') => {
                    results[1] += 1;
                }
                ('b', 'c') | ('c', 'b') => {
                    results[2] += 1;
                }
                (_, _) => panic!("unexpected result"),
            }
        }

        // Every tally must be within 100 of its expected count.
        let mut diffs = results
            .iter()
            .zip(&expected_results)
            .map(|(a, b)| (a - b).abs());
        assert!(!diffs.any(|deviation| deviation > 100));
    }
}
770}