Skip to content

feat: TopK optimizer&planner&executor #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 6, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
The topk algorithm is implemented to optimize the use of both order b…
…y and limit
  • Loading branch information
wszhdshys committed Aug 4, 2025
commit e5c67a59cd1b3d4ee96725044c2774a4df2fb30e
5 changes: 5 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ impl<S: Storage> State<S> {
NormalizationRuleImpl::CombineFilter,
],
)
.batch(
"TopK".to_string(),
HepBatchStrategy::once_topdown(),
vec![NormalizationRuleImpl::TopK],
)
.batch(
"Expression Remapper".to_string(),
HepBatchStrategy::once_topdown(),
Expand Down
1 change: 1 addition & 0 deletions src/execution/dql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub(crate) mod seq_scan;
pub(crate) mod show_table;
pub(crate) mod show_view;
pub(crate) mod sort;
pub(crate) mod top_k;
pub(crate) mod union;
pub(crate) mod values;

Expand Down
26 changes: 18 additions & 8 deletions src/execution/dql/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::pin::Pin;
pub(crate) type BumpVec<'bump, T> = bumpalo::collections::Vec<'bump, T>;

#[derive(Clone)]
pub(crate) struct NullableVec<'a, T>(BumpVec<'a, Option<T>>);
pub(crate) struct NullableVec<'a, T>(pub(crate) BumpVec<'a, Option<T>>);

impl<'a, T> NullableVec<'a, T> {
#[inline]
Expand Down Expand Up @@ -49,17 +49,31 @@ impl<'a, T> NullableVec<'a, T> {
}
}

struct RemappingIterator<'a> {
pub struct RemappingIterator<'a> {
pos: usize,
tuples: NullableVec<'a, (usize, Tuple)>,
indices: BumpVec<'a, usize>,
}

impl RemappingIterator<'_> {
pub fn new<'a>(
pos: usize,
tuples: NullableVec<'a, (usize, Tuple)>,
indices: BumpVec<'a, usize>,
) -> RemappingIterator<'a> {
RemappingIterator {
pos,
tuples,
indices,
}
}
}

impl Iterator for RemappingIterator<'_> {
type Item = Tuple;

fn next(&mut self) -> Option<Self::Item> {
if self.pos > self.tuples.len() - 1 {
if self.pos > self.indices.len() - 1 {
return None;
}
let (_, tuple) = self.tuples.take(self.indices[self.pos]);
Expand Down Expand Up @@ -147,11 +161,7 @@ impl SortBy {
}
let indices = radix_sort(sort_keys, arena);

Ok(Box::new(RemappingIterator {
pos: 0,
tuples,
indices,
}))
Ok(Box::new(RemappingIterator::new(0, tuples, indices)))
}
SortBy::Fast => {
let fn_nulls_first = |nulls_first: bool| {
Expand Down
153 changes: 153 additions & 0 deletions src/execution/dql/top_k.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use crate::errors::DatabaseError;
use crate::execution::dql::sort::{BumpVec, NullableVec, RemappingIterator};
use crate::execution::{build_read, Executor, ReadExecutor};
use crate::planner::operator::sort::SortField;
use crate::planner::operator::top_k::TopKOperator;
use crate::planner::LogicalPlan;
use crate::storage::table_codec::BumpBytes;
use crate::storage::{StatisticsMetaCache, TableCache, Transaction, ViewCache};
use crate::throw;
use crate::types::tuple::{Schema, Tuple};
use bumpalo::Bump;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::ops::Coroutine;
use std::ops::CoroutineState;
use std::pin::Pin;

fn top_sort<'a>(
arena: &'a Bump,
schema: &Schema,
sort_fields: &[SortField],
tuples: NullableVec<'a, (usize, Tuple)>,
limit: Option<usize>,
offset: Option<usize>,
) -> Result<Box<dyn Iterator<Item = Tuple> + 'a>, DatabaseError> {
let mut sort_keys = BumpVec::with_capacity_in(tuples.len(), arena);
for (i, tuple) in tuples.0.iter().enumerate() {
let mut full_key = BumpVec::new_in(arena);
for SortField {
expr,
nulls_first,
asc,
} in sort_fields
{
let mut key = BumpBytes::new_in(arena);
let tuple = tuple.as_ref().map(|(_, tuple)| tuple).unwrap();
expr.eval(Some((tuple, &**schema)))?
.memcomparable_encode(&mut key)?;
if *asc {
for byte in key.iter_mut() {
*byte ^= 0xFF;
}
}
key.push(if *nulls_first { u8::MIN } else { u8::MAX });
full_key.extend(key);
}
//full_key.extend_from_slice(&(i as u64).to_be_bytes());
sort_keys.push((i, full_key))
}

let keep_count = offset.unwrap_or(0) + limit.unwrap_or(sort_keys.len());

let mut heap: BinaryHeap<Reverse<(&[u8], usize)>> = BinaryHeap::with_capacity(keep_count);
for (i, key) in sort_keys.iter() {
let key = key.as_slice();
if heap.len() < keep_count {
heap.push(Reverse((key, *i)));
} else if let Some(&Reverse((min_key, _))) = heap.peek() {
if key > min_key {
heap.pop();
heap.push(Reverse((key, *i)));
}
}
}

let mut topk: Vec<(Vec<u8>, usize)> = heap
.into_iter()
.map(|Reverse((key, i))| (key.to_vec(), i))
.collect();
topk.sort_by(|(k1, i1), (k2, i2)| k1.cmp(k2).then_with(|| i1.cmp(i2).reverse()));
topk.reverse();

let mut bumped_indices =
BumpVec::with_capacity_in(topk.len().saturating_sub(offset.unwrap_or(0)), arena);
for (_, idx) in topk.into_iter().skip(offset.unwrap_or(0)) {
bumped_indices.push(idx);
}
Ok(Box::new(RemappingIterator::new(0, tuples, bumped_indices)))
}

pub struct TopK {
arena: Bump,
sort_fields: Vec<SortField>,
limit: Option<usize>,
offset: Option<usize>,
input: LogicalPlan,
}

impl From<(TopKOperator, LogicalPlan)> for TopK {
fn from(
(
TopKOperator {
sort_fields,
limit,
offset,
},
input,
): (TopKOperator, LogicalPlan),
) -> Self {
TopK {
arena: Default::default(),
sort_fields,
limit,
offset,
input,
}
}
}

impl<'a, T: Transaction + 'a> ReadExecutor<'a, T> for TopK {
fn execute(
self,
cache: (&'a TableCache, &'a ViewCache, &'a StatisticsMetaCache),
transaction: *mut T,
) -> Executor<'a> {
Box::new(
#[coroutine]
move || {
let TopK {
arena,
sort_fields,
limit,
offset,
mut input,
} = self;

let arena: *const Bump = &arena;

let mut tuples = NullableVec::new(unsafe { &*arena });
let schema = input.output_schema().clone();
let mut tuple_offset = 0;

let mut coroutine = build_read(input, cache, transaction);

while let CoroutineState::Yielded(tuple) = Pin::new(&mut coroutine).resume(()) {
tuples.put((tuple_offset, throw!(tuple)));
tuple_offset += 1;
}

for tuple in throw!(top_sort(
unsafe { &*arena },
&schema,
&sort_fields,
tuples,
limit,
offset
)) {
yield Ok(tuple)
}
},
)
}
}
6 changes: 6 additions & 0 deletions src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::execution::dql::seq_scan::SeqScan;
use crate::execution::dql::show_table::ShowTables;
use crate::execution::dql::show_view::ShowViews;
use crate::execution::dql::sort::Sort;
use crate::execution::dql::top_k::TopK;
use crate::execution::dql::union::Union;
use crate::execution::dql::values::Values;
use crate::planner::operator::join::JoinCondition;
Expand Down Expand Up @@ -133,6 +134,11 @@ pub fn build_read<'a, T: Transaction + 'a>(

Limit::from((op, input)).execute(cache, transaction)
}
Operator::TopK(op) => {
let input = childrens.pop_only();

TopK::from((op, input)).execute(cache, transaction)
}
Operator::Values(op) => Values::from(op).execute(cache, transaction),
Operator::ShowTable => ShowTables.execute(cache, transaction),
Operator::ShowView => ShowViews.execute(cache, transaction),
Expand Down
3 changes: 2 additions & 1 deletion src/optimizer/rule/normalization/column_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ impl ColumnPruning {
| Operator::Join(_)
| Operator::Filter(_)
| Operator::Union(_)
| Operator::Except(_) => {
| Operator::Except(_)
| Operator::TopK(_) => {
let temp_columns = operator.referenced_columns(false);
// why?
let mut column_references = column_references;
Expand Down
10 changes: 10 additions & 0 deletions src/optimizer/rule/normalization/compilation_in_advance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ impl ExpressionRemapper {
TryReference::new(output_exprs).visit(&mut sort_field.expr)?;
}
}
Operator::TopK(op) => {
for sort_field in op.sort_fields.iter_mut() {
TryReference::new(output_exprs).visit(&mut sort_field.expr)?;
}
}
Operator::FunctionScan(op) => {
for expr in op.table_function.args.iter_mut() {
TryReference::new(output_exprs).visit(expr)?;
Expand Down Expand Up @@ -186,6 +191,11 @@ impl EvaluatorBind {
BindEvaluator.visit(&mut sort_field.expr)?;
}
}
Operator::TopK(op) => {
for sort_field in op.sort_fields.iter_mut() {
BindEvaluator.visit(&mut sort_field.expr)?;
}
}
Operator::FunctionScan(op) => {
for expr in op.table_function.args.iter_mut() {
BindEvaluator.visit(expr)?;
Expand Down
7 changes: 6 additions & 1 deletion src/optimizer/rule/normalization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,22 @@ use crate::optimizer::rule::normalization::combine_operators::{
use crate::optimizer::rule::normalization::compilation_in_advance::{
EvaluatorBind, ExpressionRemapper,
};

use crate::optimizer::rule::normalization::pushdown_limit::{
LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin,
};
use crate::optimizer::rule::normalization::pushdown_predicates::PushPredicateIntoScan;
use crate::optimizer::rule::normalization::pushdown_predicates::PushPredicateThroughJoin;
use crate::optimizer::rule::normalization::simplification::ConstantCalculation;
use crate::optimizer::rule::normalization::simplification::SimplifyFilter;

use crate::optimizer::rule::normalization::top_k::TopK;
mod column_pruning;
mod combine_operators;
mod compilation_in_advance;
mod pushdown_limit;
mod pushdown_predicates;
mod simplification;
mod top_k;

#[derive(Debug, Copy, Clone)]
pub enum NormalizationRuleImpl {
Expand All @@ -46,6 +48,7 @@ pub enum NormalizationRuleImpl {
// CompilationInAdvance
ExpressionRemapper,
EvaluatorBind,
TopK,
}

impl MatchPattern for NormalizationRuleImpl {
Expand All @@ -64,6 +67,7 @@ impl MatchPattern for NormalizationRuleImpl {
NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.pattern(),
NormalizationRuleImpl::ExpressionRemapper => ExpressionRemapper.pattern(),
NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.pattern(),
NormalizationRuleImpl::TopK => TopK.pattern(),
}
}
}
Expand Down Expand Up @@ -94,6 +98,7 @@ impl NormalizationRule for NormalizationRuleImpl {
NormalizationRuleImpl::ConstantCalculation => ConstantCalculation.apply(node_id, graph),
NormalizationRuleImpl::ExpressionRemapper => ExpressionRemapper.apply(node_id, graph),
NormalizationRuleImpl::EvaluatorBind => EvaluatorBind.apply(node_id, graph),
NormalizationRuleImpl::TopK => TopK.apply(node_id, graph),
}
}
}
Expand Down
46 changes: 46 additions & 0 deletions src/optimizer/rule/normalization/top_k.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::errors::DatabaseError;
use crate::optimizer::core::pattern::Pattern;
use crate::optimizer::core::pattern::PatternChildrenPredicate;
use crate::optimizer::core::rule::{MatchPattern, NormalizationRule};
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::planner::operator::top_k::TopKOperator;
use crate::planner::operator::Operator;
use std::sync::LazyLock;

static TOP_K_RULE: LazyLock<Pattern> = LazyLock::new(|| Pattern {
predicate: |op| matches!(op, Operator::Limit(_)),
children: PatternChildrenPredicate::Predicate(vec![Pattern {
predicate: |op| matches!(op, Operator::Sort(_)),
children: PatternChildrenPredicate::None,
}]),
});

pub struct TopK;

impl MatchPattern for TopK {
fn pattern(&self) -> &Pattern {
&TOP_K_RULE
}
}

impl NormalizationRule for TopK {
fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> {
if let Operator::Limit(op) = graph.operator(node_id) {
if let Some(child_id) = graph.eldest_child_at(node_id) {
if let Operator::Sort(child_op) = graph.operator(child_id) {
graph.replace_node(
node_id,
Operator::TopK(TopKOperator {
sort_fields: child_op.sort_fields.clone(),
limit: op.limit,
offset: op.offset,
}),
);
graph.remove_node(child_id, false);
}
}
}

Ok(())
}
}
2 changes: 1 addition & 1 deletion src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl LogicalPlan {
mut childrens_iter: ChildrensIter,
) -> SchemaOutput {
match operator {
Operator::Filter(_) | Operator::Sort(_) | Operator::Limit(_) => {
Operator::Filter(_) | Operator::Sort(_) | Operator::Limit(_) | Operator::TopK(_) => {
childrens_iter.next().unwrap().output_schema_direct()
}
Operator::Aggregate(op) => SchemaOutput::Schema(
Expand Down
Loading
Loading