Skip to content

Commit e89aaf8

Browse files
committed
shuf: correctness: Use Fisher-Yates for nonrepeating integers
We used to use a clever homegrown way to sample integers. But GNU shuf with --random-source observably uses Fisher-Yates, and the output of the old version depended on a heuristic (making it dangerous for --random-seed). So now we do Fisher-Yates here, just like we do for other inputs. In deterministic modes the output for --input-range is identical to that for piping `seq` into `shuf`. We imitate the old algorithm's method for keeping the resource use in check. The performance of the new version is very close to that of the old version: I haven't found any cases where it's much faster or much slower.
1 parent 163e6f9 commit e89aaf8

File tree

4 files changed

+85
-155
lines changed

4 files changed

+85
-155
lines changed
+82-148
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,85 @@
1-
// spell-checker:ignore nonrepeating
2-
3-
// TODO: this iterator is not compatible with GNU when --random-source is used
4-
5-
use std::{collections::HashSet, ops::RangeInclusive};
1+
use std::collections::HashMap;
2+
use std::ops::RangeInclusive;
63

74
use uucore::error::UResult;
85

96
use crate::WrappedRng;
107

11-
enum NumberSet {
12-
AlreadyListed(HashSet<u64>),
13-
Remaining(Vec<u64>),
14-
}
15-
8+
/// An iterator that samples from an integer range without repetition.
9+
///
10+
/// This is based on Fisher-Yates, and it's required for backward compatibility
11+
/// that it behaves exactly like Fisher-Yates if --random-source or --random-seed
12+
/// is used. But we have a few tricks:
13+
///
14+
/// - In the beginning we use a hash table instead of an array. This way we lazily
15+
/// keep track of swaps without allocating the entire range upfront.
16+
///
17+
/// - When the hash table starts to get big relative to the remaining items
18+
/// we switch over to an array.
19+
///
20+
/// - We store the array backwards so that we can shrink it as we go and free excess
21+
/// memory every now and then.
22+
///
23+
/// Both the hash table and the array give the same output.
24+
///
25+
/// There's room for optimization:
26+
///
27+
/// - Switching over from the hash table to the array is costly. If we happen to know
28+
/// (through --head-count) that only a few draws remain then it would be better not
29+
/// to switch.
30+
///
31+
/// - If the entire range gets used then we might as well allocate an array to start
32+
/// with. But if the user e.g. pipes through `head` rather than using --head-count
33+
/// we can't know whether that's the case, so there's a tradeoff.
34+
///
35+
/// GNU decides the other way: --head-count is noticeably faster than | head.
1636
pub(crate) struct NonrepeatingIterator<'a> {
17-
range: RangeInclusive<u64>,
1837
rng: &'a mut WrappedRng,
19-
remaining_count: u64,
20-
buf: NumberSet,
38+
values: Values,
39+
}
40+
41+
enum Values {
42+
Full(Vec<u64>),
43+
Sparse(RangeInclusive<u64>, HashMap<u64, u64>),
2144
}
2245

2346
impl<'a> NonrepeatingIterator<'a> {
24-
pub(crate) fn new(range: RangeInclusive<u64>, rng: &'a mut WrappedRng, amount: u64) -> Self {
25-
let capped_amount = if range.start() > range.end() {
26-
0
27-
} else if range == (0..=u64::MAX) {
28-
amount
29-
} else {
30-
amount.min(range.end() - range.start() + 1)
31-
};
32-
NonrepeatingIterator {
33-
range,
34-
rng,
35-
remaining_count: capped_amount,
36-
buf: NumberSet::AlreadyListed(HashSet::default()),
37-
}
47+
pub(crate) fn new(range: RangeInclusive<u64>, rng: &'a mut WrappedRng) -> Self {
48+
let values = Values::Sparse(range, HashMap::default());
49+
NonrepeatingIterator { rng, values }
3850
}
3951

4052
fn produce(&mut self) -> UResult<u64> {
41-
debug_assert!(self.range.start() <= self.range.end());
42-
match &mut self.buf {
43-
NumberSet::AlreadyListed(already_listed) => {
44-
let chosen = loop {
45-
let guess = self.rng.choose_from_range(self.range.clone())?;
46-
let newly_inserted = already_listed.insert(guess);
47-
if newly_inserted {
48-
break guess;
49-
}
50-
};
51-
// Once a significant fraction of the interval has already been enumerated,
52-
// the number of attempts to find a number that hasn't been chosen yet increases.
53-
// Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values".
54-
let range_size = (self.range.end() - self.range.start()).saturating_add(1);
55-
if number_set_should_list_remaining(already_listed.len() as u64, range_size) {
56-
let mut remaining = self
57-
.range
58-
.clone()
59-
.filter(|n| !already_listed.contains(n))
60-
.collect::<Vec<_>>();
61-
assert!(remaining.len() as u64 >= self.remaining_count);
62-
remaining.truncate(self.remaining_count as usize);
63-
self.rng.shuffle(&mut remaining, usize::MAX)?;
64-
self.buf = NumberSet::Remaining(remaining);
53+
match &mut self.values {
54+
Values::Full(items) => {
55+
let this_idx = items.len() - 1;
56+
57+
let other_idx = self.rng.choose_from_range(0..=items.len() as u64 - 1)? as usize;
58+
// Flip the index to pretend we're going left-to-right
59+
let other_idx = items.len() - other_idx - 1;
60+
61+
items.swap(this_idx, other_idx);
62+
63+
let val = items.pop().unwrap();
64+
if items.len().is_power_of_two() && items.len() >= 512 {
65+
items.shrink_to_fit();
6566
}
66-
Ok(chosen)
67+
Ok(val)
6768
}
68-
NumberSet::Remaining(remaining_numbers) => {
69-
debug_assert!(!remaining_numbers.is_empty());
70-
// We only enter produce() when there is at least one actual element remaining, so popping must always return an element.
71-
Ok(remaining_numbers.pop().unwrap())
69+
Values::Sparse(range, items) => {
70+
let this_idx = *range.start();
71+
let this_val = items.remove(&this_idx).unwrap_or(this_idx);
72+
73+
let other_idx = self.rng.choose_from_range(range.clone())?;
74+
75+
let val = if this_idx != other_idx {
76+
items.insert(other_idx, this_val).unwrap_or(other_idx)
77+
} else {
78+
this_val
79+
};
80+
*range = *range.start() + 1..=*range.end();
81+
82+
Ok(val)
7283
}
7384
}
7485
}
@@ -77,101 +88,24 @@ impl<'a> NonrepeatingIterator<'a> {
7788
impl Iterator for NonrepeatingIterator<'_> {
7889
type Item = UResult<u64>;
7990

80-
fn next(&mut self) -> Option<UResult<u64>> {
81-
if self.range.is_empty() || self.remaining_count == 0 {
82-
return None;
91+
fn next(&mut self) -> Option<Self::Item> {
92+
match &self.values {
93+
Values::Full(items) if items.is_empty() => return None,
94+
Values::Full(_) => (),
95+
Values::Sparse(range, _) if range.is_empty() => return None,
96+
Values::Sparse(range, items) => {
97+
let range_len = range.size_hint().0 as u64;
98+
if range_len > 16 && items.len() as u64 >= range_len / 8 {
99+
self.values = Values::Full(hashmap_to_vec(range.clone(), items));
100+
}
101+
}
83102
}
84-
self.remaining_count -= 1;
103+
85104
Some(self.produce())
86105
}
87106
}
88107

89-
// This could be a method, but it is much easier to test as a stand-alone function.
90-
fn number_set_should_list_remaining(listed_count: u64, range_size: u64) -> bool {
91-
// Arbitrarily determine the switchover point to be around 25%. This is because:
92-
// - HashSet has a large space overhead for the hash table load factor.
93-
// - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same.
94-
// - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway.
95-
// - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet.
96-
// - Finally, "25%" is computationally the simplest:
97-
listed_count >= range_size / 4
98-
}
99-
100-
#[cfg(test)]
101-
// Since the computed value is a bool, it is more readable to write the expected value out:
102-
#[allow(clippy::bool_assert_comparison)]
103-
mod test_number_set_decision {
104-
use super::number_set_should_list_remaining;
105-
106-
#[test]
107-
fn test_stay_positive_large_remaining_first() {
108-
assert_eq!(false, number_set_should_list_remaining(0, u64::MAX));
109-
}
110-
111-
#[test]
112-
fn test_stay_positive_large_remaining_second() {
113-
assert_eq!(false, number_set_should_list_remaining(1, u64::MAX));
114-
}
115-
116-
#[test]
117-
fn test_stay_positive_large_remaining_tenth() {
118-
assert_eq!(false, number_set_should_list_remaining(9, u64::MAX));
119-
}
120-
121-
#[test]
122-
fn test_stay_positive_smallish_range_first() {
123-
assert_eq!(false, number_set_should_list_remaining(0, 12345));
124-
}
125-
126-
#[test]
127-
fn test_stay_positive_smallish_range_second() {
128-
assert_eq!(false, number_set_should_list_remaining(1, 12345));
129-
}
130-
131-
#[test]
132-
fn test_stay_positive_smallish_range_tenth() {
133-
assert_eq!(false, number_set_should_list_remaining(9, 12345));
134-
}
135-
136-
#[test]
137-
fn test_stay_positive_small_range_not_too_early() {
138-
assert_eq!(false, number_set_should_list_remaining(1, 10));
139-
}
140-
141-
// Don't want to test close to the border, in case we decide to change the threshold.
142-
// However, at 50% coverage, we absolutely should switch:
143-
#[test]
144-
fn test_switch_half() {
145-
assert_eq!(true, number_set_should_list_remaining(1234, 2468));
146-
}
147-
148-
// Ensure that the decision is monotonic:
149-
#[test]
150-
fn test_switch_late1() {
151-
assert_eq!(true, number_set_should_list_remaining(12340, 12345));
152-
}
153-
154-
#[test]
155-
fn test_switch_late2() {
156-
assert_eq!(true, number_set_should_list_remaining(12344, 12345));
157-
}
158-
159-
// Ensure that we are overflow-free:
160-
#[test]
161-
fn test_no_crash_exceed_max_size1() {
162-
assert_eq!(false, number_set_should_list_remaining(12345, u64::MAX));
163-
}
164-
165-
#[test]
166-
fn test_no_crash_exceed_max_size2() {
167-
assert_eq!(
168-
true,
169-
number_set_should_list_remaining(u64::MAX - 1, u64::MAX)
170-
);
171-
}
172-
173-
#[test]
174-
fn test_no_crash_exceed_max_size3() {
175-
assert_eq!(true, number_set_should_list_remaining(u64::MAX, u64::MAX));
176-
}
108+
fn hashmap_to_vec(range: RangeInclusive<u64>, map: &HashMap<u64, u64>) -> Vec<u64> {
109+
let lookup = |idx| *map.get(&idx).unwrap_or(&idx);
110+
range.rev().map(lookup).collect()
177111
}

src/uu/shuf/src/random_seed.rs

-3
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ use sha3::{Digest as _, Sha3_256};
3333
///
3434
/// - Without --repeat, use these to do left-to-right modern Fisher-Yates.
3535
///
36-
/// - Or for --input-range without --repeat, do whatever NonrepeatingIterator does.
37-
/// (We may want to change that. Watch this space.)
38-
///
3936
/// # Why it works like this
4037
///
4138
/// - Unicode string: Greatest common denominator between platforms. Windows doesn't

src/uu/shuf/src/shuf.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,8 @@ impl Shufable for RangeInclusive<u64> {
346346
rng: &'b mut WrappedRng,
347347
amount: u64,
348348
) -> UResult<impl Iterator<Item = UResult<Self::Item>>> {
349-
Ok(NonrepeatingIterator::new(self.clone(), rng, amount))
349+
let amount = usize::try_from(amount).unwrap_or(usize::MAX);
350+
Ok(NonrepeatingIterator::new(self.clone(), rng).take(amount))
350351
}
351352
}
352353

tests/by-util/test_shuf.rs

+1-3
Original file line numberDiff line numberDiff line change
@@ -1001,8 +1001,6 @@ fn test_gnu_compat_limited_from_stdin() {
10011001
.stdout_is("6\n5\n1\n3\n2\n7\n4\n");
10021002
}
10031003

1004-
// We haven't reverse-engineered GNU's nonrepeating integer sampling yet.
1005-
#[ignore = "disabled until fixed"]
10061004
#[test]
10071005
fn test_gnu_compat_range_no_repeat() {
10081006
let (at, mut ucmd) = at_and_ucmd!();
@@ -1064,5 +1062,5 @@ fn test_seed_range_no_repeat() {
10641062
.arg("-i1-10")
10651063
.succeeds()
10661064
.no_stderr()
1067-
.stdout_is("8\n9\n5\n10\n1\n2\n4\n7\n3\n6\n");
1065+
.stdout_is("8\n9\n1\n5\n2\n6\n4\n3\n10\n7\n");
10681066
}

0 commit comments

Comments
 (0)