rand/seq/
iterator.rs

1// Copyright 2018-2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! `IteratorRandom`
10
11#[allow(unused)]
12use super::IndexedRandom;
13use super::coin_flipper::CoinFlipper;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18/// Extension trait on iterators, providing random sampling methods.
19///
20/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
21/// and provides methods for
22/// choosing one or more elements. You must `use` this trait:
23///
24/// ```
25/// use rand::seq::IteratorRandom;
26///
27/// let faces = "😀😎😐😕😠😢";
28/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap());
29/// ```
30/// Example output (non-deterministic):
31/// ```none
32/// I am 😀!
33/// ```
34pub trait IteratorRandom: Iterator + Sized {
35    /// Uniformly sample one element
36    ///
37    /// Assuming that the [`Iterator::size_hint`] is correct, this method
38    /// returns one uniformly-sampled random element of the slice, or `None`
39    /// only if the slice is empty. Incorrect bounds on the `size_hint` may
40    /// cause this method to incorrectly return `None` if fewer elements than
41    /// the advertised `lower` bound are present and may prevent sampling of
42    /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint`
43    /// is memory-safe, but may result in unexpected `None` result and
44    /// non-uniform distribution).
45    ///
46    /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is
47    /// a constant-time operation, this method can offer `O(1)` performance.
48    /// Where no size hint is
49    /// available, complexity is `O(n)` where `n` is the iterator length.
50    /// Partial hints (where `lower > 0`) also improve performance.
51    ///
52    /// Note further that [`Iterator::size_hint`] may affect the number of RNG
53    /// samples used as well as the result (while remaining uniform sampling).
54    /// Consider instead using [`IteratorRandom::choose_stable`] to avoid
55    /// [`Iterator`] combinators which only change size hints from affecting the
56    /// results.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use rand::seq::IteratorRandom;
62    ///
63    /// let words = "Mary had a little lamb".split(' ');
64    /// println!("{}", words.choose(&mut rand::rng()).unwrap());
65    /// ```
66    fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67    where
68        R: Rng + ?Sized,
69    {
70        let (mut lower, mut upper) = self.size_hint();
71        let mut result = None;
72
73        // Handling for this condition outside the loop allows the optimizer to eliminate the loop
74        // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
75        // seq_iter_choose_from_1000.
76        if upper == Some(lower) {
77            return match lower {
78                0 => None,
79                1 => self.next(),
80                _ => self.nth(rng.random_range(..lower)),
81            };
82        }
83
84        let mut coin_flipper = CoinFlipper::new(rng);
85        let mut consumed = 0;
86
87        // Continue until the iterator is exhausted
88        loop {
89            if lower > 1 {
90                let ix = coin_flipper.rng.random_range(..lower + consumed);
91                let skip = if ix < lower {
92                    result = self.nth(ix);
93                    lower - (ix + 1)
94                } else {
95                    lower
96                };
97                if upper == Some(lower) {
98                    return result;
99                }
100                consumed += lower;
101                if skip > 0 {
102                    self.nth(skip - 1);
103                }
104            } else {
105                let elem = self.next();
106                if elem.is_none() {
107                    return result;
108                }
109                consumed += 1;
110                if coin_flipper.random_ratio_one_over(consumed) {
111                    result = elem;
112                }
113            }
114
115            let hint = self.size_hint();
116            lower = hint.0;
117            upper = hint.1;
118        }
119    }
120
121    /// Uniformly sample one element (stable)
122    ///
123    /// This method is very similar to [`choose`] except that the result
124    /// only depends on the length of the iterator and the values produced by
125    /// `rng`. Notably for any iterator of a given length this will make the
126    /// same requests to `rng` and if the same sequence of values are produced
127    /// the same index will be selected from `self`. This may be useful if you
128    /// need consistent results no matter what type of iterator you are working
129    /// with. If you do not need this stability prefer [`choose`].
130    ///
131    /// Note that this method still uses [`Iterator::size_hint`] to skip
132    /// constructing elements where possible, however the selection and `rng`
133    /// calls are the same in the face of this optimization. If you want to
134    /// force every element to be created regardless call `.inspect(|e| ())`.
135    ///
136    /// [`choose`]: IteratorRandom::choose
137    //
138    // Clippy is wrong here: we need to iterate over all entries with the RNG to
139    // ensure that choosing is *stable*.
140    // "allow(unknown_lints)" can be removed when switching to at least
141    // rust-version 1.86.0, see:
142    // https://rust-lang.github.io/rust-clippy/master/index.html#double_ended_iterator_last
143    #[allow(unknown_lints)]
144    #[allow(clippy::double_ended_iterator_last)]
145    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146    where
147        R: Rng + ?Sized,
148    {
149        let mut consumed = 0;
150        let mut result = None;
151        let mut coin_flipper = CoinFlipper::new(rng);
152
153        loop {
154            // Currently the only way to skip elements is `nth()`. So we need to
155            // store what index to access next here.
156            // This should be replaced by `advance_by()` once it is stable:
157            // https://github.com/rust-lang/rust/issues/77404
158            let mut next = 0;
159
160            let (lower, _) = self.size_hint();
161            if lower >= 2 {
162                let highest_selected = (0..lower)
163                    .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164                    .last();
165
166                consumed += lower;
167                next = lower;
168
169                if let Some(ix) = highest_selected {
170                    result = self.nth(ix);
171                    next -= ix + 1;
172                    debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173                }
174            }
175
176            let elem = self.nth(next);
177            if elem.is_none() {
178                return result;
179            }
180
181            if coin_flipper.random_ratio_one_over(consumed + 1) {
182                result = elem;
183            }
184            consumed += 1;
185        }
186    }
187
188    /// Uniformly sample `amount` distinct elements into a buffer
189    ///
190    /// Collects values at random from the iterator into a supplied buffer
191    /// until that buffer is filled.
192    ///
193    /// Although the elements are selected randomly, the order of elements in
194    /// the buffer is neither stable nor fully random. If random ordering is
195    /// desired, shuffle the result.
196    ///
197    /// Returns the number of elements added to the buffer. This equals the length
198    /// of the buffer unless the iterator contains insufficient elements, in which
199    /// case this equals the number of elements available.
200    ///
201    /// Complexity is `O(n)` where `n` is the length of the iterator.
202    /// For slices, prefer [`IndexedRandom::sample`].
203    fn sample_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204    where
205        R: Rng + ?Sized,
206    {
207        let amount = buf.len();
208        let mut len = 0;
209        while len < amount {
210            if let Some(elem) = self.next() {
211                buf[len] = elem;
212                len += 1;
213            } else {
214                // Iterator exhausted; stop early
215                return len;
216            }
217        }
218
219        // Continue, since the iterator was not exhausted
220        for (i, elem) in self.enumerate() {
221            let k = rng.random_range(..i + 1 + amount);
222            if let Some(slot) = buf.get_mut(k) {
223                *slot = elem;
224            }
225        }
226        len
227    }
228
229    /// Uniformly sample `amount` distinct elements into a [`Vec`]
230    ///
231    /// This is equivalent to `sample_fill` except for the result type.
232    ///
233    /// Although the elements are selected randomly, the order of elements in
234    /// the buffer is neither stable nor fully random. If random ordering is
235    /// desired, shuffle the result.
236    ///
237    /// The length of the returned vector equals `amount` unless the iterator
238    /// contains insufficient elements, in which case it equals the number of
239    /// elements available.
240    ///
241    /// Complexity is `O(n)` where `n` is the length of the iterator.
242    /// For slices, prefer [`IndexedRandom::sample`].
243    #[cfg(feature = "alloc")]
244    fn sample<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
245    where
246        R: Rng + ?Sized,
247    {
248        let mut reservoir = Vec::with_capacity(amount);
249        reservoir.extend(self.by_ref().take(amount));
250
251        // Continue unless the iterator was exhausted
252        //
253        // note: this prevents iterators that "restart" from causing problems.
254        // If the iterator stops once, then so do we.
255        if reservoir.len() == amount {
256            for (i, elem) in self.enumerate() {
257                let k = rng.random_range(..i + 1 + amount);
258                if let Some(slot) = reservoir.get_mut(k) {
259                    *slot = elem;
260                }
261            }
262        } else {
263            // Don't hang onto extra memory. There is a corner case where
264            // `amount` was much less than `self.len()`.
265            reservoir.shrink_to_fit();
266        }
267        reservoir
268    }
269
270    /// Deprecated: use [`Self::sample_fill`] instead
271    #[deprecated(since = "0.9.2", note = "Renamed to `sample_fill`")]
272    fn choose_multiple_fill<R>(self, rng: &mut R, buf: &mut [Self::Item]) -> usize
273    where
274        R: Rng + ?Sized,
275    {
276        self.sample_fill(rng, buf)
277    }
278
279    /// Deprecated: use [`Self::sample`] instead
280    #[cfg(feature = "alloc")]
281    #[deprecated(since = "0.9.2", note = "Renamed to `sample`")]
282    fn choose_multiple<R>(self, rng: &mut R, amount: usize) -> Vec<Self::Item>
283    where
284        R: Rng + ?Sized,
285    {
286        self.sample(rng, amount)
287    }
288}
289
290impl<I> IteratorRandom for I where I: Iterator + Sized {}
291
292#[cfg(test)]
293mod test {
294    use super::*;
295    #[cfg(all(feature = "alloc", not(feature = "std")))]
296    use alloc::vec::Vec;
297
298    #[derive(Clone)]
299    struct UnhintedIterator<I: Iterator + Clone> {
300        iter: I,
301    }
302    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
303        type Item = I::Item;
304
305        fn next(&mut self) -> Option<Self::Item> {
306            self.iter.next()
307        }
308    }
309
310    #[derive(Clone)]
311    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
312        iter: I,
313        chunk_remaining: usize,
314        chunk_size: usize,
315        hint_total_size: bool,
316    }
317    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
318        type Item = I::Item;
319
320        fn next(&mut self) -> Option<Self::Item> {
321            if self.chunk_remaining == 0 {
322                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
323            }
324            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);
325
326            self.iter.next()
327        }
328
329        fn size_hint(&self) -> (usize, Option<usize>) {
330            (
331                self.chunk_remaining,
332                if self.hint_total_size {
333                    Some(self.iter.len())
334                } else {
335                    None
336                },
337            )
338        }
339    }
340
341    #[derive(Clone)]
342    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
343        iter: I,
344        window_size: usize,
345        hint_total_size: bool,
346    }
347    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
348        type Item = I::Item;
349
350        fn next(&mut self) -> Option<Self::Item> {
351            self.iter.next()
352        }
353
354        fn size_hint(&self) -> (usize, Option<usize>) {
355            (
356                core::cmp::min(self.iter.len(), self.window_size),
357                if self.hint_total_size {
358                    Some(self.iter.len())
359                } else {
360                    None
361                },
362            )
363        }
364    }
365
366    #[test]
367    #[cfg_attr(miri, ignore)] // Miri is too slow
368    fn test_iterator_choose() {
369        let r = &mut crate::test::rng(109);
370        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
371            let mut chosen = [0i32; 9];
372            for _ in 0..1000 {
373                let picked = iter.clone().choose(r).unwrap();
374                chosen[picked] += 1;
375            }
376            for count in chosen.iter() {
377                // Samples should follow Binomial(1000, 1/9)
378                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
379                // Note: have seen 153, which is unlikely but not impossible.
380                assert!(
381                    72 < *count && *count < 154,
382                    "count not close to 1000/9: {}",
383                    count
384                );
385            }
386        }
387
388        test_iter(r, 0..9);
389        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
390        #[cfg(feature = "alloc")]
391        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
392        test_iter(r, UnhintedIterator { iter: 0..9 });
393        test_iter(
394            r,
395            ChunkHintedIterator {
396                iter: 0..9,
397                chunk_size: 4,
398                chunk_remaining: 4,
399                hint_total_size: false,
400            },
401        );
402        test_iter(
403            r,
404            ChunkHintedIterator {
405                iter: 0..9,
406                chunk_size: 4,
407                chunk_remaining: 4,
408                hint_total_size: true,
409            },
410        );
411        test_iter(
412            r,
413            WindowHintedIterator {
414                iter: 0..9,
415                window_size: 2,
416                hint_total_size: false,
417            },
418        );
419        test_iter(
420            r,
421            WindowHintedIterator {
422                iter: 0..9,
423                window_size: 2,
424                hint_total_size: true,
425            },
426        );
427
428        assert_eq!((0..0).choose(r), None);
429        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
430    }
431
432    #[test]
433    #[cfg_attr(miri, ignore)] // Miri is too slow
434    fn test_iterator_choose_stable() {
435        let r = &mut crate::test::rng(109);
436        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
437            let mut chosen = [0i32; 9];
438            for _ in 0..1000 {
439                let picked = iter.clone().choose_stable(r).unwrap();
440                chosen[picked] += 1;
441            }
442            for count in chosen.iter() {
443                // Samples should follow Binomial(1000, 1/9)
444                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
445                // Note: have seen 153, which is unlikely but not impossible.
446                assert!(
447                    72 < *count && *count < 154,
448                    "count not close to 1000/9: {}",
449                    count
450                );
451            }
452        }
453
454        test_iter(r, 0..9);
455        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
456        #[cfg(feature = "alloc")]
457        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
458        test_iter(r, UnhintedIterator { iter: 0..9 });
459        test_iter(
460            r,
461            ChunkHintedIterator {
462                iter: 0..9,
463                chunk_size: 4,
464                chunk_remaining: 4,
465                hint_total_size: false,
466            },
467        );
468        test_iter(
469            r,
470            ChunkHintedIterator {
471                iter: 0..9,
472                chunk_size: 4,
473                chunk_remaining: 4,
474                hint_total_size: true,
475            },
476        );
477        test_iter(
478            r,
479            WindowHintedIterator {
480                iter: 0..9,
481                window_size: 2,
482                hint_total_size: false,
483            },
484        );
485        test_iter(
486            r,
487            WindowHintedIterator {
488                iter: 0..9,
489                window_size: 2,
490                hint_total_size: true,
491            },
492        );
493
494        assert_eq!((0..0).choose(r), None);
495        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
496    }
497
498    #[test]
499    #[cfg_attr(miri, ignore)] // Miri is too slow
500    fn test_iterator_choose_stable_stability() {
501        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
502            let r = &mut crate::test::rng(109);
503            let mut chosen = [0i32; 9];
504            for _ in 0..1000 {
505                let picked = iter.clone().choose_stable(r).unwrap();
506                chosen[picked] += 1;
507            }
508            chosen
509        }
510
511        let reference = test_iter(0..9);
512        assert_eq!(
513            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
514            reference
515        );
516
517        #[cfg(feature = "alloc")]
518        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
519        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
520        assert_eq!(
521            test_iter(ChunkHintedIterator {
522                iter: 0..9,
523                chunk_size: 4,
524                chunk_remaining: 4,
525                hint_total_size: false,
526            }),
527            reference
528        );
529        assert_eq!(
530            test_iter(ChunkHintedIterator {
531                iter: 0..9,
532                chunk_size: 4,
533                chunk_remaining: 4,
534                hint_total_size: true,
535            }),
536            reference
537        );
538        assert_eq!(
539            test_iter(WindowHintedIterator {
540                iter: 0..9,
541                window_size: 2,
542                hint_total_size: false,
543            }),
544            reference
545        );
546        assert_eq!(
547            test_iter(WindowHintedIterator {
548                iter: 0..9,
549                window_size: 2,
550                hint_total_size: true,
551            }),
552            reference
553        );
554    }
555
556    #[test]
557    #[cfg(feature = "alloc")]
558    fn test_sample_iter() {
559        let min_val = 1;
560        let max_val = 100;
561
562        let mut r = crate::test::rng(401);
563        let vals = (min_val..max_val).collect::<Vec<i32>>();
564        let small_sample = vals.iter().sample(&mut r, 5);
565        let large_sample = vals.iter().sample(&mut r, vals.len() + 5);
566
567        assert_eq!(small_sample.len(), 5);
568        assert_eq!(large_sample.len(), vals.len());
569        // no randomization happens when amount >= len
570        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());
571
572        assert!(
573            small_sample
574                .iter()
575                .all(|e| { **e >= min_val && **e <= max_val })
576        );
577    }
578
579    #[test]
580    fn value_stability_choose() {
581        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
582            let mut rng = crate::test::rng(411);
583            iter.choose(&mut rng)
584        }
585
586        assert_eq!(choose([].iter().cloned()), None);
587        assert_eq!(choose(0..100), Some(33));
588        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
589        assert_eq!(
590            choose(ChunkHintedIterator {
591                iter: 0..100,
592                chunk_size: 32,
593                chunk_remaining: 32,
594                hint_total_size: false,
595            }),
596            Some(91)
597        );
598        assert_eq!(
599            choose(ChunkHintedIterator {
600                iter: 0..100,
601                chunk_size: 32,
602                chunk_remaining: 32,
603                hint_total_size: true,
604            }),
605            Some(91)
606        );
607        assert_eq!(
608            choose(WindowHintedIterator {
609                iter: 0..100,
610                window_size: 32,
611                hint_total_size: false,
612            }),
613            Some(34)
614        );
615        assert_eq!(
616            choose(WindowHintedIterator {
617                iter: 0..100,
618                window_size: 32,
619                hint_total_size: true,
620            }),
621            Some(34)
622        );
623    }
624
625    #[test]
626    fn value_stability_choose_stable() {
627        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
628            let mut rng = crate::test::rng(411);
629            iter.choose_stable(&mut rng)
630        }
631
632        assert_eq!(choose([].iter().cloned()), None);
633        assert_eq!(choose(0..100), Some(27));
634        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
635        assert_eq!(
636            choose(ChunkHintedIterator {
637                iter: 0..100,
638                chunk_size: 32,
639                chunk_remaining: 32,
640                hint_total_size: false,
641            }),
642            Some(27)
643        );
644        assert_eq!(
645            choose(ChunkHintedIterator {
646                iter: 0..100,
647                chunk_size: 32,
648                chunk_remaining: 32,
649                hint_total_size: true,
650            }),
651            Some(27)
652        );
653        assert_eq!(
654            choose(WindowHintedIterator {
655                iter: 0..100,
656                window_size: 32,
657                hint_total_size: false,
658            }),
659            Some(27)
660        );
661        assert_eq!(
662            choose(WindowHintedIterator {
663                iter: 0..100,
664                window_size: 32,
665                hint_total_size: true,
666            }),
667            Some(27)
668        );
669    }
670
671    #[test]
672    fn value_stability_sample() {
673        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
674            let mut rng = crate::test::rng(412);
675            let mut buf = [0u32; 8];
676            assert_eq!(iter.clone().sample_fill(&mut rng, &mut buf), v.len());
677            assert_eq!(&buf[0..v.len()], v);
678
679            #[cfg(feature = "alloc")]
680            {
681                let mut rng = crate::test::rng(412);
682                assert_eq!(iter.sample(&mut rng, v.len()), v);
683            }
684        }
685
686        do_test(0..4, &[0, 1, 2, 3]);
687        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
688        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
689    }
690}