Skip to content

Implement a first 4-ary trie to represent posting lists #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Optimize implementations and make blocks fully configurable.
Fastest results are achieved with 32 bytes per block which results in a 12.5% overhead.
64 bytes per block seems like a reasonable compromise between speed and memory overhead (6.25%).
  • Loading branch information
aneubeck committed Oct 31, 2024
commit 27118092a5c099e765b692ed523f7ebfbd0a2f18
168 changes: 93 additions & 75 deletions crates/quaternary_trie/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,86 +219,103 @@ impl QuarternaryTrie {
if level == 1 {
self.recurse2(node, value, results);
} else {
let n = self.data.get_nibble(node);
let value = value * 4;
let mut n = self.data.get_nibble(node);
let mut value = value * 4;
if n == 0 {
results.extend((0..4 << (level * 2)).map(|i| (value << (level * 2)) + i));
return;
}
let mut r = self.data.rank(node * 4) as usize;
if n & 1 != 0 {
while n > 0 {
let delta = n.trailing_zeros();
r += 1;
self.recurse(r, level - 1, value, results);
}
if n & 2 != 0 {
r += 1;
self.recurse(r, level - 1, value + 1, results);
}
if n & 4 != 0 {
r += 1;
self.recurse(r, level - 1, value + 2, results);
}
if n & 8 != 0 {
r += 1;
self.recurse(r, level - 1, value + 3, results);
self.recurse(r, level - 1, value + delta, results);
value += delta + 1;
n >>= delta + 1;
}
}
}

#[inline(always)]
fn recurse2(&self, node: usize, value: u32, results: &mut Vec<u32>) {
let n = self.data.get_nibble(node);
let value = value * 4;
let mut n = self.data.get_nibble(node);
let mut value = value * 4;
if n == 0 {
results.extend((0..16).map(|i| (value << 2) + i));
return;
}
let mut r = self.data.rank(node * 4) as usize;
if n & 1 != 0 {
r += 1;
self.recurse0(r, value, results);
}
if n & 2 != 0 {
r += 1;
self.recurse0(r, value + 1, results);
}
if n & 4 != 0 {
r += 1;
self.recurse0(r, value + 2, results);
}
if n & 8 != 0 {
while n > 0 {
let delta = n.trailing_zeros();
r += 1;
self.recurse0(r, value + 3, results);
self.recurse0(r, value + delta, results);
value += delta + 1;
n >>= delta + 1;
}
}

#[inline(always)]
fn recurse0(&self, node: usize, value: u32, results: &mut Vec<u32>) {
let n = self.data.get_nibble(node);
let value = value * 4;
let mut n = self.data.get_nibble(node);
let mut value = value * 4;
if n == 0 {
results.extend((0..4).map(|i| value + i));
return;
}
if n & 1 != 0 {
results.push(value);
}
if n & 2 != 0 {
results.push(value + 1);
}
if n & 4 != 0 {
results.push(value + 2);
}
if n & 8 != 0 {
results.push(value + 3);
while n > 0 {
let delta = n.trailing_zeros();
results.push(value + delta);
value += delta + 1;
n >>= delta + 1;
}
}

pub fn collect(&self) -> Vec<u32> {
// This is the "slow" implementation: at every level it computes the rank and
// extracts the corresponding nibble via `recurse`.
pub fn collect2(&self) -> Vec<u32> {
    let mut out = Vec::with_capacity(self.level_idx[0]);
    self.recurse(0, MAX_LEVEL - 1, 0, &mut out);
    out
}

// This is the "fastest" implementation, since it doesn't use rank information at all
// during the traversal. That works because it walks through ALL nodes, so each
// per-level position simply advances by 1; rank is only needed to seed the positions.
// The remaining "expensive" part is the nibble lookup on every iteration, which needs
// the slightly involved position-to-block conversion (virtual mapping or arithmetic).
// The math would be trivial if the counters weren't stored within the same page...
// Alternatively one could cache a u64 and keep shifting until it is exhausted, or
// work with a pointer into the bitrank array directly.
pub fn collect(&mut self) -> Vec<u32> {
    // Seed the per-level cursors top-down: the root lives at position 0, and
    // each lower level starts one past the rank of its parent level's start.
    self.level_idx[MAX_LEVEL - 1] = 0;
    for level in (1..MAX_LEVEL).rev() {
        self.level_idx[level - 1] = self.data.rank(self.level_idx[level] * 4) as usize + 1;
    }
    let mut out = Vec::new();
    self.fast_collect_inner(MAX_LEVEL - 1, 0, &mut out);
    out
}

// Rank-free depth-first emission of all stored values under the current
// cursor of `level`. Reads the nibble at this level's cursor, advances the
// cursor by one, and either expands a fully-populated subtree, pushes leaf
// values (level 0), or recurses into each set child.
fn fast_collect_inner(&mut self, level: usize, value: u32, results: &mut Vec<u32>) {
    let mut nibble = self.data.get_nibble(self.level_idx[level]);
    self.level_idx[level] += 1;
    // Extend the base-4 prefix by one digit before interpreting children.
    let mut value = value * 4;
    if nibble == 0 {
        // An all-zero nibble encodes a fully populated subtree (a real empty
        // node would never be stored): emit every value it covers. The span is
        // 4 << (level * 2) consecutive values.
        // BUG FIX: the base must use the extended prefix (value * 4), matching
        // the equivalent branch in `recurse`; the previous code used the
        // unextended prefix, which is only correct when value == 0.
        results.extend((0..4 << (level * 2)).map(|i| (value << (level * 2)) + i));
        return;
    }
    if level == 0 {
        // Leaf level: each set bit of the nibble is one stored value.
        while nibble > 0 {
            let delta = nibble.trailing_zeros();
            results.push(value + delta);
            value += delta + 1;
            nibble >>= delta + 1;
        }
    } else {
        // Inner level: recurse into each set child, left to right.
        while nibble > 0 {
            let delta = nibble.trailing_zeros();
            self.fast_collect_inner(level - 1, value + delta, results);
            value += delta + 1;
            nibble >>= delta + 1;
        }
    }
}
}

pub trait TrieIteratorTrait {
Expand Down Expand Up @@ -328,53 +345,54 @@ impl TrieIteratorTrait for TrieTraversal<'_> {
fn down(&mut self, level: usize, child: u32) {
let index = self.pos[level] * 4 + child;
let new_index = self.trie.data.rank(index as usize + 1);
self.pos[level - 1] = new_index as u32;
self.pos[level - 1] = new_index;
}
}

pub struct TrieIterator<T> {
trie: T,
level: usize,
item: u32,
nibbles: [u32; MAX_LEVEL],
}

impl<T: TrieIteratorTrait> TrieIterator<T> {
pub fn new(trie: T) -> Self {
let mut result = Self {
Self {
trie,
level: MAX_LEVEL - 1,
item: 0,
nibbles: [0; MAX_LEVEL],
};
result.nibbles[result.level] = result.trie.get(result.level);
result
}
}
}

impl<'a, T: TrieIteratorTrait> Iterator for TrieIterator<T> {
type Item = u32;

fn next(&mut self) -> Option<u32> {
while self.level < MAX_LEVEL {
let child = (self.item >> (2 * self.level)) & 3;
let nibble = self.nibbles[self.level] >> child;
let mut level = if self.item == 0 {
self.nibbles[MAX_LEVEL - 1] = self.trie.get(MAX_LEVEL - 1);
MAX_LEVEL - 1
} else {
(self.item.trailing_zeros() / 2) as usize
};
while level < MAX_LEVEL {
let child = (self.item >> (2 * level)) & 3;
let nibble = self.nibbles[level] >> child;
if nibble != 0 {
let delta = nibble.trailing_zeros();
if self.level == 0 {
if level == 0 {
let res = self.item + delta;
self.item = res + 1;
self.level = (self.item.trailing_zeros() / 2) as usize;
return Some(res);
}
self.item += delta << (2 * self.level);
self.trie.down(self.level, child + delta);
self.level -= 1;
self.nibbles[self.level] = self.trie.get(self.level);
self.item += delta << (2 * level);
self.trie.down(level, child + delta);
level -= 1;
self.nibbles[level] = self.trie.get(level);
} else {
self.item |= 3 << (self.level * 2);
self.item += 1 << (self.level * 2);
self.level = (self.item.trailing_zeros() / 2) as usize;
self.item |= 3 << (level * 2);
self.item += 1 << (level * 2);
level = (self.item.trailing_zeros() / 2) as usize;
}
}
None
Expand Down Expand Up @@ -413,13 +431,13 @@ mod tests {
use crate::{Intersection, Layout, QuarternaryTrie, TrieIterator, TrieTraversal};

#[test]
fn test_bpt() {
fn test_trie() {
let values = vec![3, 6, 7, 10];
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
assert_eq!(trie.collect(), values);

let values: Vec<_> = (1..63).collect();
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
assert_eq!(trie.collect(), values);
}

Expand All @@ -432,7 +450,7 @@ mod tests {
values.dedup();

let start = Instant::now();
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
println!("construction {:?}", start.elapsed() / values.len() as u32);

let start = Instant::now();
Expand All @@ -450,21 +468,21 @@ mod tests {
#[test]
fn test_van_emde_boas_layout() {
let values: Vec<_> = (0..64).collect();
let trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
let mut trie = QuarternaryTrie::new(&values, Layout::VanEmdeBoas);
assert_eq!(trie.collect(), values);
}

#[test]
fn test_intersection() {
let mut page_counts = [0, 0, 0];
for _ in 0..3 {
let mut values: Vec<_> = (0..100000)
let mut values: Vec<_> = (0..10000000)
.map(|_| thread_rng().gen_range(0..100000000))
.collect();
values.sort();
values.dedup();

let mut values2: Vec<_> = (0..100000000)
let mut values2: Vec<_> = (0..10000000)
.map(|_| thread_rng().gen_range(0..100000000))
.collect();
values2.sort();
Expand Down Expand Up @@ -495,7 +513,7 @@ mod tests {
let result: Vec<_> = iter.collect();
let count = trie.page_count();
let count2 = trie2.page_count();
page_counts[i] += count.0 + count.1;
page_counts[i] += count.0 + count2.0;
println!(
"trie intersection {:?} {}",
start.elapsed() / values.len() as u32,
Expand Down
33 changes: 18 additions & 15 deletions crates/quaternary_trie/src/virtual_bitrank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,22 @@

use std::cell::RefCell;

const BLOCK_BYTES: usize = 128;
type Word = u64;

const BLOCK_BYTES: usize = 64;
const BLOCK_BITS: usize = BLOCK_BYTES * 8;
const PAGE_BYTES: usize = 4096;
const PAGE_BITS: usize = PAGE_BYTES * 8;
const BLOCKS_PER_PAGE: usize = PAGE_BYTES / BLOCK_BYTES;
const WORD_BITS: usize = 64;
const WORD_BYTES: usize = WORD_BITS / 8;
const BLOCKS_PER_PAGE: usize = BLOCK_BYTES / 4;
const WORD_BITS: usize = WORD_BYTES * 8;
const WORD_BYTES: usize = std::mem::size_of::<Word>();
const WORDS_PER_BLOCK: usize = BLOCK_BYTES / WORD_BYTES;
const PAGE_BYTES: usize = BLOCKS_PER_PAGE * BLOCK_BYTES;
const PAGE_BITS: usize = PAGE_BYTES * 8;
const SUPER_PAGE_BITS: usize = 4096 * 8;

#[repr(C, align(128))]
#[derive(Default, Clone)]
struct Block {
words: [u64; WORDS_PER_BLOCK],
words: [Word; WORDS_PER_BLOCK],
}

#[derive(Default)]
Expand All @@ -77,10 +80,7 @@ impl VirtualBitRank {
}

pub(crate) fn reset_stats(&mut self) {
self.stats = vec![
RefCell::new(0);
((self.blocks.len() + BLOCKS_PER_PAGE - 1) / BLOCKS_PER_PAGE + 63) / 64
];
self.stats = vec![RefCell::new(0); self.blocks.len() * BLOCK_BITS / SUPER_PAGE_BITS + 1];
}

pub(crate) fn page_count(&self) -> (usize, usize) {
Expand All @@ -94,10 +94,13 @@ impl VirtualBitRank {
}

fn bit_to_block(&self, bit: usize) -> usize {
//let block = bit / BLOCK_BITS;
//let result2 = block + (block / (BLOCKS_PER_PAGE - 1)) + 1;
let result = self.block_mapping[bit / BLOCK_BITS] as usize;
/*if let Some(v) = self.stats.get(result / PAGE_BITS / 64) {
//assert_eq!(result2, result);
if let Some(v) = self.stats.get(result * BLOCK_BITS / SUPER_PAGE_BITS / 64) {
*v.borrow_mut() += 1 << (result % 64);
}*/
}
result
}

Expand Down Expand Up @@ -143,7 +146,7 @@ impl VirtualBitRank {
}
}

fn get_word_mut(&mut self, bit: usize) -> &mut u64 {
fn get_word_mut(&mut self, bit: usize) -> &mut Word {
let block = bit / BLOCK_BITS;
if block >= self.block_mapping.len() {
self.block_mapping.resize(block + 1, 0);
Expand All @@ -168,7 +171,7 @@ impl VirtualBitRank {
let bit_idx = nibble_idx * 4;
// clear all bits...
// *self.get_word(bit_idx) &= !(15 << (bit_idx & (WORD_BITS - 1)));
*self.get_word_mut(bit_idx) |= (nibble_value as u64) << (bit_idx & (WORD_BITS - 1));
*self.get_word_mut(bit_idx) |= (nibble_value as Word) << (bit_idx & (WORD_BITS - 1));
}

pub(crate) fn build(&mut self) {
Expand Down
Loading