1#[allow(unused)]
12use super::IndexedRandom;
13use super::coin_flipper::CoinFlipper;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18pub trait IteratorRandom: Iterator + Sized {
35 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67 where
68 R: Rng + ?Sized,
69 {
70 let (mut lower, mut upper) = self.size_hint();
71 let mut result = None;
72
73 if upper == Some(lower) {
77 return match lower {
78 0 => None,
79 1 => self.next(),
80 _ => self.nth(rng.random_range(..lower)),
81 };
82 }
83
84 let mut coin_flipper = CoinFlipper::new(rng);
85 let mut consumed = 0;
86
87 loop {
89 if lower > 1 {
90 let ix = coin_flipper.rng.random_range(..lower + consumed);
91 let skip = if ix < lower {
92 result = self.nth(ix);
93 lower - (ix + 1)
94 } else {
95 lower
96 };
97 if upper == Some(lower) {
98 return result;
99 }
100 consumed += lower;
101 if skip > 0 {
102 self.nth(skip - 1);
103 }
104 } else {
105 let elem = self.next();
106 if elem.is_none() {
107 return result;
108 }
109 consumed += 1;
110 if coin_flipper.random_ratio_one_over(consumed) {
111 result = elem;
112 }
113 }
114
115 let hint = self.size_hint();
116 lower = hint.0;
117 upper = hint.1;
118 }
119 }
120
121 #[allow(unknown_lints)]
144 #[allow(clippy::double_ended_iterator_last)]
145 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146 where
147 R: Rng + ?Sized,
148 {
149 let mut consumed = 0;
150 let mut result = None;
151 let mut coin_flipper = CoinFlipper::new(rng);
152
153 loop {
154 let mut next = 0;
159
160 let (lower, _) = self.size_hint();
161 if lower >= 2 {
162 let highest_selected = (0..lower)
163 .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164 .last();
165
166 consumed += lower;
167 next = lower;
168
169 if let Some(ix) = highest_selected {
170 result = self.nth(ix);
171 next -= ix + 1;
172 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173 }
174 }
175
176 let elem = self.nth(next);
177 if elem.is_none() {
178 return result;
179 }
180
181 if coin_flipper.random_ratio_one_over(consumed + 1) {
182 result = elem;
183 }
184 consumed += 1;
185 }
186 }
187
188 fn sample_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204 where
205 R: Rng + ?Sized,
206 {
207 let amount = buf.len();
208 let mut len = 0;
209 while len < amount {
210 if let Some(elem) = self.next() {
211 buf[len] = elem;
212 len += 1;
213 } else {
214 return len;
216 }
217 }
218
219 for (i, elem) in self.enumerate() {
221 let k = rng.random_range(..i + 1 + amount);
222 if let Some(slot) = buf.get_mut(k) {
223 *slot = elem;
224 }
225 }
226 len
227 }
228
229 #[cfg(feature = "alloc")]
244 fn sample<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
245 where
246 R: Rng + ?Sized,
247 {
248 let mut reservoir = Vec::with_capacity(amount);
249 reservoir.extend(self.by_ref().take(amount));
250
251 if reservoir.len() == amount {
256 for (i, elem) in self.enumerate() {
257 let k = rng.random_range(..i + 1 + amount);
258 if let Some(slot) = reservoir.get_mut(k) {
259 *slot = elem;
260 }
261 }
262 } else {
263 reservoir.shrink_to_fit();
266 }
267 reservoir
268 }
269
270 #[deprecated(since = "0.9.2", note = "Renamed to `sample_fill`")]
272 fn choose_multiple_fill<R>(self, rng: &mut R, buf: &mut [Self::Item]) -> usize
273 where
274 R: Rng + ?Sized,
275 {
276 self.sample_fill(rng, buf)
277 }
278
279 #[cfg(feature = "alloc")]
281 #[deprecated(since = "0.9.2", note = "Renamed to `sample`")]
282 fn choose_multiple<R>(self, rng: &mut R, amount: usize) -> Vec<Self::Item>
283 where
284 R: Rng + ?Sized,
285 {
286 self.sample(rng, amount)
287 }
288}
289
290impl<I> IteratorRandom for I where I: Iterator + Sized {}
291
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Forwards `next` to `iter` but keeps the default `size_hint` of
    /// `(0, None)`, so the samplers get no length information at all.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Reports as its lower size bound only the elements left in the current
    /// chunk of `chunk_size` elements. The upper bound is the true remaining
    /// length when `hint_total_size` is set, and unknown otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Reports a lower size bound of at most `window_size` remaining elements.
    /// The upper bound behaves as in `ChunkHintedIterator`.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Expected count is 1000/9 (about 111); the bounds are a wide tolerance.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Expected count is 1000/9 (about 111); the bounds are a wide tolerance.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must yield `None` from `choose_stable` too
        // (this previously exercised `choose` by mistake).
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iterator_choose_stable_stability() {
        // Each call uses a freshly seeded RNG, so every iterator over the same
        // elements must produce exactly the same histogram, whatever its size hints.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().sample(&mut r, 5);
        let large_sample = vals.iter().sample(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // A sample larger than the input returns the whole input in order.
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` is the half-open range `min_val..max_val`, so `max_val` itself
        // can never be sampled.
        assert!(
            small_sample
                .iter()
                .all(|e| { **e >= min_val && **e < max_val })
        );
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_sample() {
        // `sample_fill` and `sample` must agree with each other and with the
        // pinned values for the same seed.
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(iter.clone().sample_fill(&mut rng, &mut buf), v.len());
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.sample(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}