Skip to content

shuf: Add --random-seed, make --random-source GNU-compatible, report write failures, optimize #7585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
3 changes: 3 additions & 0 deletions .vscode/cspell.dictionaries/jargon.wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ fileio
filesystem
filesystems
flamegraph
footgun
fsxattr
fullblock
getfacl
Expand Down Expand Up @@ -76,6 +77,7 @@ mergeable
microbenchmark
microbenchmarks
microbenchmarking
monomorphized
multibyte
multicall
nmerge
Expand All @@ -89,6 +91,7 @@ nolinks
nonblock
nonportable
nonprinting
nonrepeating
nonseekable
notrunc
noxfer
Expand Down
3 changes: 3 additions & 0 deletions .vscode/cspell.dictionaries/people.wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Boden Garman
Chirag B Jadwani
Chirag
Jadwani
Daniel Lemire
Daniel
Lemire
Derek Chiang
Derek
Chiang
Expand Down
1 change: 1 addition & 0 deletions .vscode/cspell.dictionaries/workspace.wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ getrandom
globset
indicatif
itertools
itoa
lscolors
mdbook
memchr
Expand Down
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ hostname = "0.4"
iana-time-zone = "0.1.57"
indicatif = "0.17.8"
itertools = "0.14.0"
itoa = "1.0.15"
libc = "0.2.153"
linux-raw-sys = "0.9"
lscolors = { version = "0.20.0", default-features = false, features = [
Expand All @@ -322,6 +323,7 @@ phf_codegen = "0.11.2"
platform-info = "2.0.3"
quick-error = "2.0.1"
rand = { version = "0.9.0", features = ["small_rng"] }
rand_chacha = { version = "0.9.0" }
rand_core = "0.9.0"
rayon = "1.10"
regex = "1.10.4"
Expand Down
3 changes: 3 additions & 0 deletions src/uu/shuf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ path = "src/shuf.rs"

[dependencies]
clap = { workspace = true }
itoa = { workspace = true }
rand = { workspace = true }
rand_chacha = { workspace = true }
rand_core = { workspace = true }
sha3 = { workspace = true }
uucore = { workspace = true }

[[bin]]
Expand Down
119 changes: 119 additions & 0 deletions src/uu/shuf/src/compat_random_source.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// This file is part of the uutils coreutils package.
//
// For the full copyright and license information, please view the LICENSE
// file that was distributed with this source code.

use std::{io::BufRead, ops::RangeInclusive};

use uucore::error::{FromIo, UResult, USimpleError};

/// A uniform integer generator that tries to exactly match GNU shuf's --random-source.
///
/// It's not particularly efficient and possibly not quite uniform. It should *only* be
/// used for compatibility with GNU: other modes shouldn't touch this code.
///
/// All the logic here was black box reverse engineered. It might not match up in all edge
/// cases but it gives identical results on many different large and small inputs.
///
/// It seems that GNU uses fairly textbook rejection sampling to generate integers, reading
/// one byte at a time until it has enough entropy, and recycling leftover entropy after
/// accepting or rejecting a value.
///
/// To do your own experiments, start with commands like these:
///
/// printf '\x01\x02\x03\x04' | shuf -i0-255 -r --random-source=/dev/stdin
///
/// Then vary the integer range and the input and the input length. It can be useful to
/// see when exactly shuf crashes with an "end of file" error.
///
/// To spot small inconsistencies it's useful to run:
///
/// diff -y <(my_shuf ...) <(shuf -i0-{MAX} -r --random-source={INPUT}) | head -n 50
pub struct RandomSourceAdapter<R> {
    /// Byte stream backing --random-source; bytes are consumed one at a time.
    reader: R,
    /// Random bits gathered so far, accumulated base-256 (state = state*256 + byte).
    state: u64,
    /// The largest value `state` could currently hold, maintained in lockstep
    /// with `state` (entropy = entropy*256 + 255); effectively the amount of
    /// buffered entropy available for the next draw.
    entropy: u64,
}

impl<R> RandomSourceAdapter<R> {
    /// Wraps `reader` as an entropy source, starting with no buffered bits.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            state: 0,
            entropy: 0,
        }
    }
}

impl<R: BufRead> RandomSourceAdapter<R> {
    /// Generates a uniform random integer in `0..=at_most` by rejection
    /// sampling bytes from the reader, matching GNU shuf's byte-consumption
    /// pattern exactly (see the module-level notes).
    ///
    /// # Errors
    /// Fails if reading from the source errors, or if the source is
    /// exhausted before enough entropy has been gathered.
    fn generate_at_most(&mut self, at_most: u64) -> UResult<u64> {
        // Iterative rejection sampling. This used to be written with a tail
        // call for the retry, but Rust does not guarantee tail-call
        // elimination, so a degenerate source that keeps hitting the
        // rejection margin could overflow the stack. A loop is equivalent
        // and bounded.
        loop {
            // Pull bytes one at a time until we have enough entropy.
            while self.entropy < at_most {
                let buf = self
                    .reader
                    .fill_buf()
                    .map_err_context(|| "reading random bytes failed".into())?;
                let Some(&byte) = buf.first() else {
                    return Err(USimpleError::new(1, "end of random source"));
                };
                self.reader.consume(1);
                // Is overflow OK here? Won't it cause bias? (Seems to work out...)
                self.state = self.state.wrapping_mul(256).wrapping_add(byte as u64);
                self.entropy = self.entropy.wrapping_mul(256).wrapping_add(255);
            }

            if at_most == u64::MAX {
                // at_most + 1 would overflow but this case is easy.
                let val = self.state;
                self.entropy = 0;
                self.state = 0;
                return Ok(val);
            }

            let num_possibilities = at_most + 1;

            // If the generated number falls within this margin at the upper end of the
            // range then we retry to avoid modulo bias.
            let margin = ((self.entropy as u128 + 1) % num_possibilities as u128) as u64;
            let safe_zone = self.entropy - margin;

            if self.state <= safe_zone {
                let val = self.state % num_possibilities;
                // Reuse the rest of the state.
                self.state /= num_possibilities;
                // We need this subtraction, otherwise we consume new input slightly more
                // slowly than GNU. Not sure if it checks out mathematically.
                self.entropy -= at_most;
                self.entropy /= num_possibilities;
                return Ok(val);
            }

            // Rejected: recycle the leftover entropy, then loop to retry.
            self.state %= num_possibilities;
            self.entropy %= num_possibilities;
        }
    }

    /// Uniformly picks a value from an inclusive range.
    pub fn choose_from_range(&mut self, range: RangeInclusive<u64>) -> UResult<u64> {
        let offset = self.generate_at_most(*range.end() - *range.start())?;
        Ok(*range.start() + offset)
    }

    /// Uniformly picks one element of a non-empty slice.
    ///
    /// # Panics
    /// Panics if `vals` is empty.
    pub fn choose_from_slice<T: Copy>(&mut self, vals: &[T]) -> UResult<T> {
        assert!(!vals.is_empty());
        let idx = self.generate_at_most(vals.len() as u64 - 1)? as usize;
        Ok(vals[idx])
    }

    /// Partially shuffles `vals` in place (Fisher-Yates) and returns the
    /// first `amount.min(vals.len())` elements, in GNU's draw order.
    pub fn shuffle<'a, T>(&mut self, vals: &'a mut [T], amount: usize) -> UResult<&'a mut [T]> {
        // Fisher-Yates shuffle.
        // TODO: GNU does something different if amount <= vals.len() and the input is stdin.
        // The order changes completely and depends on --head-count.
        // No clue what they might do differently and why.
        let amount = amount.min(vals.len());
        for idx in 0..amount {
            let other_idx = self.generate_at_most((vals.len() - idx - 1) as u64)? as usize + idx;
            vals.swap(idx, other_idx);
        }
        Ok(&mut vals[..amount])
    }
}
111 changes: 111 additions & 0 deletions src/uu/shuf/src/nonrepeating_iterator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use std::collections::HashMap;
use std::ops::RangeInclusive;

use uucore::error::UResult;

use crate::WrappedRng;

/// An iterator that samples from an integer range without repetition.
///
/// This is based on Fisher-Yates, and it's required for backward compatibility
/// that it behaves exactly like Fisher-Yates if --random-source or --random-seed
/// is used. But we have a few tricks:
///
/// - In the beginning we use a hash table instead of an array. This way we lazily
/// keep track of swaps without allocating the entire range upfront.
///
/// - When the hash table starts to get big relative to the remaining items
/// we switch over to an array.
///
/// - We store the array backwards so that we can shrink it as we go and free excess
/// memory every now and then.
///
/// Both the hash table and the array give the same output.
///
/// There's room for optimization:
///
/// - Switching over from the hash table to the array is costly. If we happen to know
/// (through --head-count) that only few draws remain then it would be better not
/// to switch.
///
/// - If the entire range gets used then we might as well allocate an array to start
/// with. But if the user e.g. pipes through `head` rather than using --head-count
/// we can't know whether that's the case, so there's a tradeoff.
///
/// GNU decides the other way: --head-count is noticeably faster than | head.
pub(crate) struct NonrepeatingIterator<'a> {
    /// Randomness source shared with the rest of shuf.
    rng: &'a mut WrappedRng,
    /// Current representation of the not-yet-drawn values (sparse hash
    /// table at first, dense array once the table grows large enough).
    values: Values,
}

enum Values {
    /// Remaining values as a dense array, stored backwards so draws pop
    /// from the end and the vector can shrink as we go.
    Full(Vec<u64>),
    /// The untouched range plus the swaps recorded so far
    /// (index -> value displaced into that index).
    Sparse(RangeInclusive<u64>, HashMap<u64, u64>),
}

impl<'a> NonrepeatingIterator<'a> {
    /// Creates an iterator drawing from `range` without repetition, starting
    /// in the cheap sparse (hash-table) representation.
    pub(crate) fn new(range: RangeInclusive<u64>, rng: &'a mut WrappedRng) -> Self {
        let values = Values::Sparse(range, HashMap::default());
        NonrepeatingIterator { rng, values }
    }

    /// Performs one lazy Fisher-Yates step and returns the drawn value.
    /// Callers must ensure the range is non-empty (checked in `next`).
    fn produce(&mut self) -> UResult<u64> {
        match &mut self.values {
            Values::Full(items) => {
                let this_idx = items.len() - 1;

                let other_idx = self.rng.choose_from_range(0..=items.len() as u64 - 1)? as usize;
                // Flip the index to pretend we're going left-to-right
                let other_idx = items.len() - other_idx - 1;

                items.swap(this_idx, other_idx);

                let val = items.pop().unwrap();
                // Periodically release excess capacity as the array shrinks.
                if items.len().is_power_of_two() && items.len() >= 512 {
                    items.shrink_to_fit();
                }
                Ok(val)
            }
            Values::Sparse(range, items) => {
                let this_idx = *range.start();
                // The value currently "at" this_idx: a previously recorded
                // swap, or the identity mapping.
                let this_val = items.remove(&this_idx).unwrap_or(this_idx);

                let other_idx = self.rng.choose_from_range(range.clone())?;

                let val = if this_idx != other_idx {
                    items.insert(other_idx, this_val).unwrap_or(other_idx)
                } else {
                    this_val
                };
                // Advance past the drawn index. checked_add guards the
                // this_idx == u64::MAX edge: a plain `start + 1` would panic
                // in debug builds (and wrap to a non-empty range in release,
                // making the iterator never terminate). Fall back to an
                // explicitly empty range instead.
                *range = match this_idx.checked_add(1) {
                    Some(next_start) => next_start..=*range.end(),
                    None => RangeInclusive::new(1, 0),
                };

                Ok(val)
            }
        }
    }
}

impl Iterator for NonrepeatingIterator<'_> {
    type Item = UResult<u64>;

    /// Returns the next drawn value, or `None` once every value in the
    /// range has been produced.
    fn next(&mut self) -> Option<Self::Item> {
        // First decide whether we're done, and whether it's time to migrate
        // from the sparse hash-table representation to the dense array.
        match &self.values {
            Values::Full(items) if items.is_empty() => return None,
            Values::Sparse(range, _) if range.is_empty() => return None,
            Values::Sparse(range, items) => {
                let remaining = range.size_hint().0 as u64;
                // Once the table holds at least 1/8 of the remaining values,
                // a dense array is the better trade-off.
                if items.len() as u64 >= remaining / 8 {
                    self.values = Values::Full(hashmap_to_vec(range.clone(), items));
                }
            }
            Values::Full(_) => {}
        }

        Some(self.produce())
    }
}

/// Materializes the sparse representation as a dense array: every index in
/// `range`, with recorded swaps applied, laid out in reverse so that draws
/// can pop from the end.
fn hashmap_to_vec(range: RangeInclusive<u64>, map: &HashMap<u64, u64>) -> Vec<u64> {
    range
        .rev()
        .map(|idx| map.get(&idx).copied().unwrap_or(idx))
        .collect()
}
Loading