Skip to content

INTPYTHON-635 Improve join performance by pushing simple filter conditions to $lookup #356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions django_mongodb_backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.math import Power
from django.db.models.lookups import IsNull
from django.db.models.lookups import IsNull, Lookup
from django.db.models.sql import compiler
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE
from django.db.models.sql.datastructures import BaseTable
from django.db.models.sql.where import AND, WhereNode
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING

from .query import MongoQuery, wrap_database_errors
from .query_utils import is_direct_value


class SQLCompiler(compiler.SQLCompiler):
Expand Down Expand Up @@ -548,10 +550,26 @@ def get_combinator_queries(self):

def get_lookup_pipeline(self):
result = []
# To improve join performance, push conditions (filters) from the
# WHERE ($match) clause to the JOIN ($lookup) clause.
where = self.get_where()
pushed_filters = defaultdict(list)
for expr in where.children if where and where.connector == AND else ():
# Push only basic lookups; no subqueries or complex conditions.
# To avoid duplication across subqueries, only use the LHS target
# table.
if (
isinstance(expr, Lookup)
and isinstance(expr.lhs, Col)
and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value | Col))
):
pushed_filters[expr.lhs.alias].append(expr)
for alias in tuple(self.query.alias_map):
if not self.query.alias_refcount[alias] or self.collection_name == alias:
continue
result += self.query.alias_map[alias].as_mql(self, self.connection)
result += self.query.alias_map[alias].as_mql(
self, self.connection, WhereNode(pushed_filters[alias], connector=AND)
)
return result

def _get_aggregate_expressions(self, expr):
Expand Down
78 changes: 54 additions & 24 deletions django_mongodb_backend/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,25 +123,21 @@ def extra_where(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.")


def join(self, compiler, connection):
lookup_pipeline = []
lhs_fields = []
rhs_fields = []
# Add a join condition for each pair of joining fields.
def join(self, compiler, connection, pushed_filter_expression=None):
"""
Generate a MongoDB $lookup stage for a join.
`pushed_filter_expression` is a Where expression involving fields from the
joined collection which can be pushed from the WHERE ($match) clause to the
JOIN ($lookup) clause to improve performance.
"""
parent_template = "parent__field__"
for lhs, rhs in self.join_fields:
lhs, rhs = connection.ops.prepare_join_on_clause(
self.parent_alias, lhs, compiler.collection_name, rhs
)
lhs_fields.append(lhs.as_mql(compiler, connection))
# In the lookup stage, the reference to this column doesn't include
# the collection name.
rhs_fields.append(rhs.as_mql(compiler, connection))
# Handle any join conditions besides matching field pairs.
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
if extra:

def _get_reroot_replacements(expression):
if not expression:
return None
columns = []
for expr in extra.leaves():
for expr in expression.leaves():
# Determine whether the column needs to be transformed or rerouted
# as part of the subquery.
for hand_side in ["lhs", "rhs"]:
Expand All @@ -151,27 +147,61 @@ def join(self, compiler, connection):
# lhs_fields.
if hand_side_value.alias != self.table_alias:
pos = len(lhs_fields)
lhs_fields.append(expr.lhs.as_mql(compiler, connection))
lhs_fields.append(hand_side_value.as_mql(compiler, connection))
else:
pos = None
columns.append((hand_side_value, pos))
# Replace columns in the extra conditions with new column references
# based on their rerouted positions in the join pipeline.
replacements = {}
for col, parent_pos in columns:
column_target = Col(compiler.collection_name, expr.output_field.__class__())
target = col.target.clone()
target.remote_field = col.target.remote_field
column_target = Col(compiler.collection_name, target)
if parent_pos is not None:
target_col = f"${parent_template}{parent_pos}"
column_target.target.db_column = target_col
column_target.target.set_attributes_from_name(target_col)
else:
column_target.target = col.target
replacements[col] = column_target
# Apply the transformed expressions in the extra condition.
extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)]
else:
extra_condition = []
return replacements

lookup_pipeline = []
lhs_fields = []
rhs_fields = []
# Add a join condition for each pair of joining fields.
for lhs, rhs in self.join_fields:
lhs, rhs = connection.ops.prepare_join_on_clause(
self.parent_alias, lhs, compiler.collection_name, rhs
)
lhs_fields.append(lhs.as_mql(compiler, connection))
# In the lookup stage, the reference to this column doesn't include the
# collection name.
rhs_fields.append(rhs.as_mql(compiler, connection))
# Handle any join conditions besides matching field pairs.
extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias)
extra_conditions = []
if extra:
replacements = _get_reroot_replacements(extra)
extra_conditions.append(
extra.replace_expressions(replacements).as_mql(compiler, connection)
)
# pushed_filter_expression is a Where expression from the outer WHERE
# clause that involves fields from the joined (right-hand) table and
# possibly the outer (left-hand) table. If it can be safely evaluated
# within the $lookup pipeline (e.g., field comparisons like
# right.status = left.id), it is "pushed" into the join's $match stage to
# reduce the volume of joined documents. This only applies to INNER JOINs,
# as pushing filters into a LEFT JOIN can change the semantics of the
# result. LEFT JOINs may rely on null checks to detect missing RHS.
if pushed_filter_expression and self.join_type == INNER:
rerooted_replacement = _get_reroot_replacements(pushed_filter_expression)
extra_conditions.append(
pushed_filter_expression.replace_expressions(rerooted_replacement).as_mql(
compiler, connection
)
)
lookup_pipeline = [
{
"$lookup": {
Expand All @@ -197,7 +227,7 @@ def join(self, compiler, connection):
{"$eq": [f"$${parent_template}{i}", field]}
for i, field in enumerate(rhs_fields)
]
+ extra_condition
+ extra_conditions
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions docs/source/releases/5.2.x.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ Bug fixes
databases.
- :meth:`QuerySet.explain() <django.db.models.query.QuerySet.explain>` now
:ref:`returns a string that can be parsed as JSON <queryset-explain>`.

Performance improvements
------------------------

- Improved ``QuerySet`` performance by removing the low limit on server-side chunking.
- Improved ``QuerySet`` join (``$lookup``) performance by pushing some simple
conditions from the ``WHERE`` (``$match``) clause to the ``$lookup`` stage.

5.2.0 beta 1
============
Expand Down
15 changes: 15 additions & 0 deletions tests/queries_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,18 @@ class Meta:

def __str__(self):
return str(self.pk)


class Reader(models.Model):
    """
    Test model with a single short ``name`` field.

    ``Library.readers`` points at this model with
    ``related_name="libraries"``.
    """

    # Short display name; also used as the string representation.
    name = models.CharField(max_length=20)

    def __str__(self):
        """Return the reader's name."""
        return self.name


class Library(models.Model):
    """
    Test model with a short ``name`` and a many-to-many link to ``Reader``.
    """

    # Short display name; also used as the string representation.
    name = models.CharField(max_length=20)
    # Many-to-many relation to Reader, declared with related_name="libraries".
    readers = models.ManyToManyField(Reader, related_name="libraries")

    def __str__(self):
        """Return the library's name."""
        return self.name
Loading
Loading