Skip to content

Commit 8301a8e

Browse files
authored
Merge pull request #5980 from BenWiederhake/dev-shuf-number-speed
shuf: Fix OOM crash for huge number ranges
2 parents 17174ab + f25b210 commit 8301a8e

File tree

3 files changed

+368
-20
lines changed

3 files changed

+368
-20
lines changed

src/uu/shuf/BENCHMARKING.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ a range of numbers to randomly sample from. An example of a command that works
2828
well for testing:
2929

3030
```shell
31-
hyperfine --warmup 10 "target/release/shuf -i 0-10000000"
31+
hyperfine --warmup 10 "target/release/shuf -i 0-10000000 > /dev/null"
3232
```
3333

3434
To measure the time taken by shuffling an input file, the following command can
35-
be used::
35+
be used:
3636

3737
```shell
3838
hyperfine --warmup 10 "target/release/shuf input.txt > /dev/null"
@@ -49,5 +49,14 @@ should be benchmarked separately. In this case, we have to pass the `-n` flag or
4949
the command will run forever. An example of a hyperfine command is
5050

5151
```shell
52-
hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000"
52+
hyperfine --warmup 10 "target/release/shuf -r -n 10000000 -i 0-1000 > /dev/null"
53+
```
54+
55+
## With huge interval ranges
56+
57+
When `shuf` runs with huge interval ranges, special care must be taken, so it
58+
should be benchmarked separately also. An example of a hyperfine command is
59+
60+
```shell
61+
hyperfine --warmup 10 "target/release/shuf -n 100 -i 1000-2000000000 > /dev/null"
5362
```

src/uu/shuf/src/shuf.rs

Lines changed: 263 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
// For the full copyright and license information, please view the LICENSE
44
// file that was distributed with this source code.
55

6-
// spell-checker:ignore (ToDO) cmdline evec seps rvec fdata
6+
// spell-checker:ignore (ToDO) cmdline evec nonrepeating seps shufable rvec fdata
77

88
use clap::{crate_version, Arg, ArgAction, Command};
99
use memchr::memchr_iter;
1010
use rand::prelude::SliceRandom;
11-
use rand::RngCore;
11+
use rand::{Rng, RngCore};
12+
use std::collections::HashSet;
1213
use std::fs::File;
13-
use std::io::{stdin, stdout, BufReader, BufWriter, Read, Write};
14+
use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write};
1415
use uucore::display::Quotable;
1516
use uucore::error::{FromIo, UResult, USimpleError, UUsageError};
1617
use uucore::{format_usage, help_about, help_usage};
@@ -116,18 +117,16 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> {
116117
Mode::Echo(args) => {
117118
let mut evec = args.iter().map(String::as_bytes).collect::<Vec<_>>();
118119
find_seps(&mut evec, options.sep);
119-
shuf_bytes(&mut evec, options)?;
120+
shuf_exec(&mut evec, options)?;
120121
}
121122
Mode::InputRange((b, e)) => {
122-
let rvec = (b..e).map(|x| format!("{x}")).collect::<Vec<String>>();
123-
let mut rvec = rvec.iter().map(String::as_bytes).collect::<Vec<&[u8]>>();
124-
shuf_bytes(&mut rvec, options)?;
123+
shuf_exec(&mut (b, e), options)?;
125124
}
126125
Mode::Default(filename) => {
127126
let fdata = read_input_file(&filename)?;
128127
let mut fdata = vec![&fdata[..]];
129128
find_seps(&mut fdata, options.sep);
130-
shuf_bytes(&mut fdata, options)?;
129+
shuf_exec(&mut fdata, options)?;
131130
}
132131
}
133132

@@ -251,7 +250,173 @@ fn find_seps(data: &mut Vec<&[u8]>, sep: u8) {
251250
}
252251
}
253252

254-
fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> {
253+
trait Shufable {
254+
type Item: Writable;
255+
fn is_empty(&self) -> bool;
256+
fn choose(&self, rng: &mut WrappedRng) -> Self::Item;
257+
// This type shouldn't even be known. However, because we want to support
258+
// Rust 1.70, it is not possible to return "impl Iterator".
259+
// TODO: When the MSRV is raised, rewrite this to return "impl Iterator".
260+
type PartialShuffleIterator<'b>: Iterator<Item = Self::Item>
261+
where
262+
Self: 'b;
263+
fn partial_shuffle<'b>(
264+
&'b mut self,
265+
rng: &'b mut WrappedRng,
266+
amount: usize,
267+
) -> Self::PartialShuffleIterator<'b>;
268+
}
269+
270+
impl<'a> Shufable for Vec<&'a [u8]> {
271+
type Item = &'a [u8];
272+
fn is_empty(&self) -> bool {
273+
(**self).is_empty()
274+
}
275+
fn choose(&self, rng: &mut WrappedRng) -> Self::Item {
276+
// Note: "copied()" only copies the reference, not the entire [u8].
277+
// Returns None if the slice is empty. We checked this before, so
278+
// this is safe.
279+
(**self).choose(rng).unwrap()
280+
}
281+
type PartialShuffleIterator<'b> = std::iter::Copied<std::slice::Iter<'b, &'a [u8]>> where Self: 'b;
282+
fn partial_shuffle<'b>(
283+
&'b mut self,
284+
rng: &'b mut WrappedRng,
285+
amount: usize,
286+
) -> Self::PartialShuffleIterator<'b> {
287+
// Note: "copied()" only copies the reference, not the entire [u8].
288+
(**self).partial_shuffle(rng, amount).0.iter().copied()
289+
}
290+
}
291+
292+
impl Shufable for (usize, usize) {
293+
type Item = usize;
294+
fn is_empty(&self) -> bool {
295+
// Note: This is an inclusive range, so equality means there is 1 element.
296+
self.0 > self.1
297+
}
298+
fn choose(&self, rng: &mut WrappedRng) -> usize {
299+
rng.gen_range(self.0..self.1)
300+
}
301+
type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b;
302+
fn partial_shuffle<'b>(
303+
&'b mut self,
304+
rng: &'b mut WrappedRng,
305+
amount: usize,
306+
) -> Self::PartialShuffleIterator<'b> {
307+
NonrepeatingIterator::new(self.0, self.1, rng, amount)
308+
}
309+
}
310+
311+
enum NumberSet {
312+
AlreadyListed(HashSet<usize>),
313+
Remaining(Vec<usize>),
314+
}
315+
316+
struct NonrepeatingIterator<'a> {
317+
begin: usize,
318+
end: usize, // exclusive
319+
rng: &'a mut WrappedRng,
320+
remaining_count: usize,
321+
buf: NumberSet,
322+
}
323+
324+
impl<'a> NonrepeatingIterator<'a> {
325+
fn new(
326+
begin: usize,
327+
end: usize,
328+
rng: &'a mut WrappedRng,
329+
amount: usize,
330+
) -> NonrepeatingIterator {
331+
let capped_amount = if begin > end {
332+
0
333+
} else {
334+
amount.min(end - begin)
335+
};
336+
NonrepeatingIterator {
337+
begin,
338+
end,
339+
rng,
340+
remaining_count: capped_amount,
341+
buf: NumberSet::AlreadyListed(HashSet::default()),
342+
}
343+
}
344+
345+
fn produce(&mut self) -> usize {
346+
debug_assert!(self.begin <= self.end);
347+
match &mut self.buf {
348+
NumberSet::AlreadyListed(already_listed) => {
349+
let chosen = loop {
350+
let guess = self.rng.gen_range(self.begin..self.end);
351+
let newly_inserted = already_listed.insert(guess);
352+
if newly_inserted {
353+
break guess;
354+
}
355+
};
356+
// Once a significant fraction of the interval has already been enumerated,
357+
// the number of attempts to find a number that hasn't been chosen yet increases.
358+
// Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values".
359+
let range_size = self.end - self.begin;
360+
if number_set_should_list_remaining(already_listed.len(), range_size) {
361+
let mut remaining = (self.begin..self.end)
362+
.filter(|n| !already_listed.contains(n))
363+
.collect::<Vec<_>>();
364+
assert!(remaining.len() >= self.remaining_count);
365+
remaining.partial_shuffle(&mut self.rng, self.remaining_count);
366+
remaining.truncate(self.remaining_count);
367+
self.buf = NumberSet::Remaining(remaining);
368+
}
369+
chosen
370+
}
371+
NumberSet::Remaining(remaining_numbers) => {
372+
debug_assert!(!remaining_numbers.is_empty());
373+
// We only enter produce() when there is at least one actual element remaining, so popping must always return an element.
374+
remaining_numbers.pop().unwrap()
375+
}
376+
}
377+
}
378+
}
379+
380+
impl<'a> Iterator for NonrepeatingIterator<'a> {
381+
type Item = usize;
382+
383+
fn next(&mut self) -> Option<usize> {
384+
if self.begin > self.end || self.remaining_count == 0 {
385+
return None;
386+
}
387+
self.remaining_count -= 1;
388+
Some(self.produce())
389+
}
390+
}
391+
392+
// This could be a method, but it is much easier to test as a stand-alone function.
393+
fn number_set_should_list_remaining(listed_count: usize, range_size: usize) -> bool {
394+
// Arbitrarily determine the switchover point to be around 25%. This is because:
395+
// - HashSet has a large space overhead for the hash table load factor.
396+
// - This means that somewhere between 25-40%, the memory required for a "positive" HashSet and a "negative" Vec should be the same.
397+
// - HashSet has a small but non-negligible overhead for each lookup, so we have a slight preference for Vec anyway.
398+
// - At 25%, on average 1.33 attempts are needed to find a number that hasn't been taken yet.
399+
// - Finally, "24%" is computationally the simplest:
400+
listed_count >= range_size / 4
401+
}
402+
403+
trait Writable {
404+
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error>;
405+
}
406+
407+
impl<'a> Writable for &'a [u8] {
408+
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> {
409+
output.write_all(self)
410+
}
411+
}
412+
413+
impl Writable for usize {
414+
fn write_all_to(&self, output: &mut impl Write) -> Result<(), Error> {
415+
output.write_all(format!("{self}").as_bytes())
416+
}
417+
}
418+
419+
fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> {
255420
let mut output = BufWriter::new(match opts.output {
256421
None => Box::new(stdout()) as Box<dyn Write>,
257422
Some(s) => {
@@ -276,22 +441,18 @@ fn shuf_bytes(input: &mut Vec<&[u8]>, opts: Options) -> UResult<()> {
276441

277442
if opts.repeat {
278443
for _ in 0..opts.head_count {
279-
// Returns None is the slice is empty. We checked this before, so
280-
// this is safe.
281-
let r = input.choose(&mut rng).unwrap();
444+
let r = input.choose(&mut rng);
282445

283-
output
284-
.write_all(r)
446+
r.write_all_to(&mut output)
285447
.map_err_context(|| "write failed".to_string())?;
286448
output
287449
.write_all(&[opts.sep])
288450
.map_err_context(|| "write failed".to_string())?;
289451
}
290452
} else {
291-
let (shuffled, _) = input.partial_shuffle(&mut rng, opts.head_count);
453+
let shuffled = input.partial_shuffle(&mut rng, opts.head_count);
292454
for r in shuffled {
293-
output
294-
.write_all(r)
455+
r.write_all_to(&mut output)
295456
.map_err_context(|| "write failed".to_string())?;
296457
output
297458
.write_all(&[opts.sep])
@@ -361,3 +522,88 @@ impl RngCore for WrappedRng {
361522
}
362523
}
363524
}
525+
526+
#[cfg(test)]
527+
// Since the computed value is a bool, it is more readable to write the expected value out:
528+
#[allow(clippy::bool_assert_comparison)]
529+
mod test_number_set_decision {
530+
use super::number_set_should_list_remaining;
531+
532+
#[test]
533+
fn test_stay_positive_large_remaining_first() {
534+
assert_eq!(false, number_set_should_list_remaining(0, std::usize::MAX));
535+
}
536+
537+
#[test]
538+
fn test_stay_positive_large_remaining_second() {
539+
assert_eq!(false, number_set_should_list_remaining(1, std::usize::MAX));
540+
}
541+
542+
#[test]
543+
fn test_stay_positive_large_remaining_tenth() {
544+
assert_eq!(false, number_set_should_list_remaining(9, std::usize::MAX));
545+
}
546+
547+
#[test]
548+
fn test_stay_positive_smallish_range_first() {
549+
assert_eq!(false, number_set_should_list_remaining(0, 12345));
550+
}
551+
552+
#[test]
553+
fn test_stay_positive_smallish_range_second() {
554+
assert_eq!(false, number_set_should_list_remaining(1, 12345));
555+
}
556+
557+
#[test]
558+
fn test_stay_positive_smallish_range_tenth() {
559+
assert_eq!(false, number_set_should_list_remaining(9, 12345));
560+
}
561+
562+
#[test]
563+
fn test_stay_positive_small_range_not_too_early() {
564+
assert_eq!(false, number_set_should_list_remaining(1, 10));
565+
}
566+
567+
// Don't want to test close to the border, in case we decide to change the threshold.
568+
// However, at 50% coverage, we absolutely should switch:
569+
#[test]
570+
fn test_switch_half() {
571+
assert_eq!(true, number_set_should_list_remaining(1234, 2468));
572+
}
573+
574+
// Ensure that the decision is monotonous:
575+
#[test]
576+
fn test_switch_late1() {
577+
assert_eq!(true, number_set_should_list_remaining(12340, 12345));
578+
}
579+
580+
#[test]
581+
fn test_switch_late2() {
582+
assert_eq!(true, number_set_should_list_remaining(12344, 12345));
583+
}
584+
585+
// Ensure that we are overflow-free:
586+
#[test]
587+
fn test_no_crash_exceed_max_size1() {
588+
assert_eq!(
589+
false,
590+
number_set_should_list_remaining(12345, std::usize::MAX)
591+
);
592+
}
593+
594+
#[test]
595+
fn test_no_crash_exceed_max_size2() {
596+
assert_eq!(
597+
true,
598+
number_set_should_list_remaining(std::usize::MAX - 1, std::usize::MAX)
599+
);
600+
}
601+
602+
#[test]
603+
fn test_no_crash_exceed_max_size3() {
604+
assert_eq!(
605+
true,
606+
number_set_should_list_remaining(std::usize::MAX, std::usize::MAX)
607+
);
608+
}
609+
}

0 commit comments

Comments
 (0)