From 0163a0901c578fd3cc8e48bace23e7271c89c7c6 Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Mon, 11 Aug 2025 16:33:01 +0200 Subject: [PATCH 1/3] Enable filter node in optimizer rule --- .../Rule/OptimizerRuleVectorIndex.cpp | 53 +++++++++++++++++-- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/arangod/Aql/Optimizer/Rule/OptimizerRuleVectorIndex.cpp b/arangod/Aql/Optimizer/Rule/OptimizerRuleVectorIndex.cpp index d110c4aef6cf..b5bec7b49105 100644 --- a/arangod/Aql/Optimizer/Rule/OptimizerRuleVectorIndex.cpp +++ b/arangod/Aql/Optimizer/Rule/OptimizerRuleVectorIndex.cpp @@ -32,6 +32,7 @@ #include "Aql/ExecutionNode/CalculationNode.h" #include "Aql/ExecutionNode/MaterializeRocksDBNode.h" #include "Aql/ExecutionNode/LimitNode.h" +#include "Aql/ExecutionNode/FilterNode.h" #include "Aql/ExecutionNode/SortNode.h" #include "Aql/Optimizer.h" #include "Aql/OptimizerRules.h" @@ -46,7 +47,7 @@ using namespace arangodb; using namespace arangodb::aql; using EN = arangodb::aql::ExecutionNode; -#define LOG_RULE_ENABLED false +#define LOG_RULE_ENABLED true #define LOG_RULE_IF(cond) LOG_DEVEL_IF((LOG_RULE_ENABLED) && (cond)) #define LOG_RULE LOG_RULE_IF(true) @@ -256,11 +257,27 @@ void arangodb::aql::useVectorIndexRule(Optimizer* opt, // check that enumerateColNode has both sort and limit auto* currentNode = enumerateCollectionNode->getFirstParent(); - // skip over some calculation nodes - while (currentNode != nullptr && - currentNode->getType() == EN::CALCULATION) { + const auto skipOverCalculationNodes = [¤tNode] { + while (currentNode != nullptr && + currentNode->getType() == EN::CALCULATION) { + currentNode = currentNode->getFirstParent(); + } + }; + skipOverCalculationNodes(); + + // We tolerate post filtering + ExecutionNode* maybeFilterNode{nullptr}; + if (currentNode != nullptr && currentNode->getType() == EN::FILTER) { + // TODO move this in searchParameters detection + LOG_TOPIC("b15ac", WARN, Logger::AQL) + << "When using filtering with vector index it is recommended to " + "enable " + "iterative mode, e.g. APPROX_NEAR_L2(..., ..., {iterative: true, " + "maxNProbe: 10})"; + maybeFilterNode = currentNode; currentNode = currentNode->getFirstParent(); } + skipOverCalculationNodes(); if (currentNode == nullptr || currentNode->getType() != EN::SORT) { LOG_RULE << "DID NOT FIND SORT NODE, but instead " @@ -313,17 +330,41 @@ void arangodb::aql::useVectorIndexRule(Optimizer* opt, auto limit = limitNode->limit(); auto* inVariable = plan->getAst()->variables()->createTemporaryVariable(); + LOG_DEVEL << "SHOW PLAN"; + plan->show(); auto* queryPointCalculationNode = plan->createNode( plan.get(), plan->nextId(), std::make_unique(plan->getAst(), approximatedAttributeExpression), inVariable); + // Remove filter node + // Take expression from CalculationNode and move it into + // EnumerateNearVectorNode + Expression* filterExpression{nullptr}; + if (maybeFilterNode) { + // auto filterNode = ExecutionNode::castTo(maybeFilterNode); + auto* calculationNode = maybeFilterNode->getFirstParent(); + // This is wrong I need to check that the input variable of FilterNode + // is same as output of CalculationNode + TRI_ASSERT(calculationNode != nullptr && + calculationNode->getType() == EN::CALCULATION) + << "Can we ever have a FilterNode which does not depend on " + "CalculationNode?"; + auto const* calcNode = + ExecutionNode::castTo(calculationNode); + LOG_DEVEL << "Expression calcNode: " << calcNode->expression(); + filterExpression = calcNode->expression(); + + plan->unlinkNode(maybeFilterNode); + plan->unlinkNode(calculationNode); + } auto* enumerateNear = plan->createNode( plan.get(), plan->nextId(), inVariable, oldDocumentVariable, documentIdVariable, distanceVariable, limit, ascending, limitNode->offset(), std::move(searchParameters), - enumerateCollectionNode->collection(), index); + enumerateCollectionNode->collection(), index, filterExpression); auto* materializer = plan->createNode( @@ -342,6 +383,8 @@ void arangodb::aql::useVectorIndexRule(Optimizer* opt, plan->unlinkNode(distanceCalculationNode); modified = true; + LOG_DEVEL << "After optimizing"; + plan->show(); break; } } From 4881e80f8fae833e6e8cf4f1583201e173ae0d8a Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Mon, 11 Aug 2025 16:33:56 +0200 Subject: [PATCH 2/3] Pass filterExpression to EnumerateNearVectorNode --- arangod/Aql/ExecutionNode/EnumerateNearVectorNode.cpp | 9 +++++---- arangod/Aql/ExecutionNode/EnumerateNearVectorNode.h | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.cpp b/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.cpp index 0b56ff79e262..8834068de929 100644 --- a/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.cpp +++ b/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.cpp @@ -56,7 +56,7 @@ EnumerateNearVectorNode::EnumerateNearVectorNode( Variable const* documentOutVariable, Variable const* distanceOutVariable, std::size_t limit, bool ascending, std::size_t offset, SearchParameters searchParameters, aql::Collection const* collection, - transaction::Methods::IndexHandle indexHandle) + transaction::Methods::IndexHandle indexHandle, Expression* filterExpression) : ExecutionNode(plan, id), CollectionAccessingNode(collection), _inVariable(inVariable), @@ -67,7 +67,8 @@ EnumerateNearVectorNode::EnumerateNearVectorNode( _ascending(ascending), _offset(offset), _searchParameters(std::move(searchParameters)), - _index(std::move(indexHandle)) {} + _index(std::move(indexHandle)), + _filterExpression(filterExpression) {} ExecutionNode::NodeType EnumerateNearVectorNode::getType() const { return ENUMERATE_NEAR_VECTORS; @@ -110,7 +111,7 @@ std::unique_ptr EnumerateNearVectorNode::createBlock( auto executorInfos = EnumerateNearVectorsExecutorInfos( inNmDocIdRegId, outDocumentRegId, outDistanceRegId, _index, engine.getQuery(), _collectionAccess.collection(), _limit, _offset, - _searchParameters); + _searchParameters, _filterExpression); auto registerInfos = createRegisterInfos(std::move(readableInputRegisters), std::move(writableOutputRegisters)); @@ -123,7 +124,7 @@ ExecutionNode* EnumerateNearVectorNode::clone(ExecutionPlan* plan, auto c = std::make_unique( plan, _id, _inVariable, _oldDocumentVariable, _documentOutVariable, _distanceOutVariable, _limit, _ascending, _offset, _searchParameters, - collection(), _index); + collection(), _index, _filterExpression); CollectionAccessingNode::cloneInto(*c); return cloneHelper(std::move(c), withDependencies); } diff --git a/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.h b/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.h index a3055d15ac53..fc3b2d97a96b 100644 --- a/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.h +++ b/arangod/Aql/ExecutionNode/EnumerateNearVectorNode.h @@ -51,7 +51,8 @@ class EnumerateNearVectorNode : public ExecutionNode, std::size_t limit, bool ascending, std::size_t offset, SearchParameters searchParameters, aql::Collection const* collection, - transaction::Methods::IndexHandle indexHandle); + transaction::Methods::IndexHandle indexHandle, + Expression* filterExpression); EnumerateNearVectorNode(ExecutionPlan*, arangodb::velocypack::Slice base); @@ -109,5 +110,8 @@ class EnumerateNearVectorNode : public ExecutionNode, /// @brief selected index for vector search transaction::Methods::IndexHandle _index; + + // @brief if filter was set this is the filtering expression + Expression* _filterExpression; }; } // namespace arangodb::aql From e40d2b58ce2d7e3322181709c2d3e07a6b725ac4 Mon Sep 17 00:00:00 2001 From: Jure Bajic Date: Mon, 11 Aug 2025 16:34:09 +0200 Subject: [PATCH 3/3] Pass filterExpression to EnumerateNearVectorExecutor --- arangod/Aql/Executor/EnumerateNearVectorExecutor.cpp | 3 ++- arangod/Aql/Executor/EnumerateNearVectorExecutor.h | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/arangod/Aql/Executor/EnumerateNearVectorExecutor.cpp b/arangod/Aql/Executor/EnumerateNearVectorExecutor.cpp index fd2c9f8e1bc8..51673e0bd843 100644 --- a/arangod/Aql/Executor/EnumerateNearVectorExecutor.cpp +++ b/arangod/Aql/Executor/EnumerateNearVectorExecutor.cpp @@ -102,7 +102,8 @@ void EnumerateNearVectorsExecutor::searchResults() { std::tie(_labels, _distances) = vectorIndex->readBatch( _inputRowConverted, _infos.searchParameters, mthds, &_trx, - _collection->getCollection(), 1, _infos.getNumberOfResults()); + _collection->getCollection(), 1, _infos.getNumberOfResults(), + _infos.filterExpression); _currentProcessedResultCount = 0; TRI_ASSERT(hasResults()); LOG_INTERNAL << "Results: " << _labels << " and distances: " << _distances; diff --git a/arangod/Aql/Executor/EnumerateNearVectorExecutor.h b/arangod/Aql/Executor/EnumerateNearVectorExecutor.h index 6df93146ec0e..edad730d8942 100644 --- a/arangod/Aql/Executor/EnumerateNearVectorExecutor.h +++ b/arangod/Aql/Executor/EnumerateNearVectorExecutor.h @@ -23,6 +23,7 @@ #pragma once +#include "Aql/Expression.h" #include "Aql/QueryContext.h" #include "Aql/SingleRowFetcher.h" #include "Aql/ExecutionBlock.h" @@ -47,7 +48,7 @@ struct EnumerateNearVectorsExecutorInfos { RegisterId inNmDocId, RegisterId outDocRegId, RegisterId outDistanceRegId, transaction::Methods::IndexHandle index, QueryContext& queryContext, aql::Collection const* collection, std::size_t topK, std::size_t offset, - SearchParameters searchParameters) + SearchParameters searchParameters, Expression* filterExpression) : inputReg(inNmDocId), outDocumentIdReg(outDocRegId), outDistancesReg(outDistanceRegId), @@ -56,7 +57,8 @@ struct EnumerateNearVectorsExecutorInfos { collection(collection), topK(topK), offset(offset), - searchParameters(searchParameters) {} + searchParameters(searchParameters), + filterExpression(filterExpression) {} EnumerateNearVectorsExecutorInfos() = delete; EnumerateNearVectorsExecutorInfos(EnumerateNearVectorsExecutorInfos&&) = @@ -81,6 +83,7 @@ struct EnumerateNearVectorsExecutorInfos { std::size_t topK; std::size_t offset; SearchParameters searchParameters; + Expression* filterExpression; }; class EnumerateNearVectorsExecutor {