Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions src/uu/shuf/src/shuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rand::{Rng, RngCore};
use std::collections::HashSet;
use std::fs::File;
use std::io::{stdin, stdout, BufReader, BufWriter, Error, Read, Write};
use std::ops::RangeInclusive;
use uucore::display::Quotable;
use uucore::error::{FromIo, UResult, USimpleError, UUsageError};
use uucore::{format_usage, help_about, help_usage};
Expand All @@ -21,7 +22,7 @@ mod rand_read_adapter;
enum Mode {
Default(String),
Echo(Vec<String>),
InputRange((usize, usize)),
InputRange(RangeInclusive<usize>),
}

static USAGE: &str = help_usage!("shuf.md");
Expand Down Expand Up @@ -119,8 +120,8 @@ pub fn uumain(args: impl uucore::Args) -> UResult<()> {
find_seps(&mut evec, options.sep);
shuf_exec(&mut evec, options)?;
}
Mode::InputRange((b, e)) => {
shuf_exec(&mut (b, e), options)?;
Mode::InputRange(mut range) => {
shuf_exec(&mut range, options)?;
}
Mode::Default(filename) => {
let fdata = read_input_file(&filename)?;
Expand Down Expand Up @@ -289,22 +290,21 @@ impl<'a> Shufable for Vec<&'a [u8]> {
}
}

impl Shufable for (usize, usize) {
impl Shufable for RangeInclusive<usize> {
type Item = usize;
fn is_empty(&self) -> bool {
// Note: This is an inclusive range, so equality means there is 1 element.
self.0 > self.1
self.is_empty()
}
fn choose(&self, rng: &mut WrappedRng) -> usize {
rng.gen_range(self.0..self.1)
rng.gen_range(self.clone())
}
type PartialShuffleIterator<'b> = NonrepeatingIterator<'b> where Self: 'b;
fn partial_shuffle<'b>(
&'b mut self,
rng: &'b mut WrappedRng,
amount: usize,
) -> Self::PartialShuffleIterator<'b> {
NonrepeatingIterator::new(self.0, self.1, rng, amount)
NonrepeatingIterator::new(self.clone(), rng, amount)
}
}

Expand All @@ -314,40 +314,39 @@ enum NumberSet {
}

struct NonrepeatingIterator<'a> {
begin: usize,
end: usize, // exclusive
range: RangeInclusive<usize>,
rng: &'a mut WrappedRng,
remaining_count: usize,
buf: NumberSet,
}

impl<'a> NonrepeatingIterator<'a> {
fn new(
begin: usize,
end: usize,
range: RangeInclusive<usize>,
rng: &'a mut WrappedRng,
amount: usize,
) -> NonrepeatingIterator {
let capped_amount = if begin > end {
let capped_amount = if range.start() > range.end() {
0
} else if *range.start() == 0 && *range.end() == std::usize::MAX {
amount
} else {
amount.min(end - begin)
amount.min(range.end() - range.start() + 1)
};
NonrepeatingIterator {
begin,
end,
range,
rng,
remaining_count: capped_amount,
buf: NumberSet::AlreadyListed(HashSet::default()),
}
}

fn produce(&mut self) -> usize {
debug_assert!(self.begin <= self.end);
debug_assert!(self.range.start() <= self.range.end());
match &mut self.buf {
NumberSet::AlreadyListed(already_listed) => {
let chosen = loop {
let guess = self.rng.gen_range(self.begin..self.end);
let guess = self.rng.gen_range(self.range.clone());
let newly_inserted = already_listed.insert(guess);
if newly_inserted {
break guess;
Expand All @@ -356,9 +355,11 @@ impl<'a> NonrepeatingIterator<'a> {
// Once a significant fraction of the interval has already been enumerated,
// the number of attempts to find a number that hasn't been chosen yet increases.
// Therefore, we need to switch at some point from "set of already returned values" to "list of remaining values".
let range_size = self.end - self.begin;
let range_size = (self.range.end() - self.range.start()).saturating_add(1);
if number_set_should_list_remaining(already_listed.len(), range_size) {
let mut remaining = (self.begin..self.end)
let mut remaining = self
.range
.clone()
.filter(|n| !already_listed.contains(n))
.collect::<Vec<_>>();
assert!(remaining.len() >= self.remaining_count);
Expand All @@ -381,7 +382,7 @@ impl<'a> Iterator for NonrepeatingIterator<'a> {
type Item = usize;

fn next(&mut self) -> Option<usize> {
if self.begin > self.end || self.remaining_count == 0 {
if self.range.is_empty() || self.remaining_count == 0 {
return None;
}
self.remaining_count -= 1;
Expand Down Expand Up @@ -462,15 +463,19 @@ fn shuf_exec(input: &mut impl Shufable, opts: Options) -> UResult<()> {
Ok(())
}

fn parse_range(input_range: &str) -> Result<(usize, usize), String> {
fn parse_range(input_range: &str) -> Result<RangeInclusive<usize>, String> {
if let Some((from, to)) = input_range.split_once('-') {
let begin = from
.parse::<usize>()
.map_err(|_| format!("invalid input range: {}", from.quote()))?;
let end = to
.parse::<usize>()
.map_err(|_| format!("invalid input range: {}", to.quote()))?;
Ok((begin, end + 1))
if begin <= end || begin == end + 1 {
Ok(begin..=end)
} else {
Err(format!("invalid input range: {}", input_range.quote()))
}
} else {
Err(format!("invalid input range: {}", input_range.quote()))
}
Expand Down
147 changes: 146 additions & 1 deletion tests/by-util/test_shuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,99 @@ fn test_very_large_range_offset() {
);
}

#[test]
fn test_range_repeat_no_overflow_1_max() {
let upper_bound = std::usize::MAX;
let result = new_ucmd!()
.arg("-rn1")
.arg(&format!("-i1-{upper_bound}"))
.succeeds();
result.no_stderr();

let result_seq: Vec<usize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), 1, "Miscounted output length!");
}

#[test]
fn test_range_repeat_no_overflow_0_max_minus_1() {
let upper_bound = std::usize::MAX - 1;
let result = new_ucmd!()
.arg("-rn1")
.arg(&format!("-i0-{upper_bound}"))
.succeeds();
result.no_stderr();

let result_seq: Vec<usize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), 1, "Miscounted output length!");
}

#[test]
fn test_range_permute_no_overflow_1_max() {
let upper_bound = std::usize::MAX;
let result = new_ucmd!()
.arg("-n1")
.arg(&format!("-i1-{upper_bound}"))
.succeeds();
result.no_stderr();

let result_seq: Vec<usize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), 1, "Miscounted output length!");
}

#[test]
fn test_range_permute_no_overflow_0_max_minus_1() {
let upper_bound = std::usize::MAX - 1;
let result = new_ucmd!()
.arg("-n1")
.arg(&format!("-i0-{upper_bound}"))
.succeeds();
result.no_stderr();

let result_seq: Vec<usize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), 1, "Miscounted output length!");
}

#[test]
fn test_range_permute_no_overflow_0_max() {
// NOTE: This is different from GNU shuf!
// GNU shuf accepts -i0-MAX-1 and -i1-MAX, but not -i0-MAX.
// This feels like a bug in GNU shuf.
let upper_bound = std::usize::MAX;
let result = new_ucmd!()
.arg("-n1")
.arg(&format!("-i0-{upper_bound}"))
.succeeds();
result.no_stderr();

let result_seq: Vec<usize> = result
.stdout_str()
.split('\n')
.filter(|x| !x.is_empty())
.map(|x| x.parse().unwrap())
.collect();
assert_eq!(result_seq.len(), 1, "Miscounted output length!");
}

#[test]
fn test_very_high_range_full() {
let input_seq = vec![
Expand Down Expand Up @@ -626,7 +719,6 @@ fn test_shuf_multiple_input_line_count() {
}

#[test]
#[ignore = "known issue"]
fn test_shuf_repeat_empty_range() {
new_ucmd!()
.arg("-ri4-3")
Expand All @@ -653,3 +745,56 @@ fn test_shuf_repeat_empty_input() {
.no_stdout()
.stderr_only("shuf: no lines to repeat\n");
}

#[test]
fn test_range_one_elem() {
new_ucmd!()
.arg("-i5-5")
.succeeds()
.no_stderr()
.stdout_only("5\n");
}

#[test]
fn test_range_empty() {
new_ucmd!().arg("-i5-4").succeeds().no_output();
}

#[test]
fn test_range_empty_minus_one() {
new_ucmd!()
.arg("-i5-3")
.fails()
.no_stdout()
.stderr_only("shuf: invalid input range: '5-3'\n");
}

#[test]
fn test_range_repeat_one_elem() {
new_ucmd!()
.arg("-n1")
.arg("-ri5-5")
.succeeds()
.no_stderr()
.stdout_only("5\n");
}

#[test]
fn test_range_repeat_empty() {
new_ucmd!()
.arg("-n1")
.arg("-ri5-4")
.fails()
.no_stdout()
.stderr_only("shuf: no lines to repeat\n");
}

#[test]
fn test_range_repeat_empty_minus_one() {
new_ucmd!()
.arg("-n1")
.arg("-ri5-3")
.fails()
.no_stdout()
.stderr_only("shuf: invalid input range: '5-3'\n");
}