FilteredQuery

mike bayer edited this page Apr 11, 2021 · 8 revisions

Note: this recipe is for SQLAlchemy 1.3 and earlier. SQLAlchemy 1.4 introduces an all-new, more capable system for augmenting queries that works fully in all cases; see the examples for do_orm_execute() in conjunction with with_loader_criteria() at https://docs.sqlalchemy.org/en/14/orm/session_events.html#do-orm-execute-global-criteria as well as the new versions of the examples below at https://docs.sqlalchemy.org/en/14/orm/examples.html#examples-session-orm-events .
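
For orientation, here is a minimal sketch of what the 1.4-style approach named in the note can look like. It assumes SQLAlchemy 1.4 or later and will not run on 1.3. The listener name add_public_criteria and the small User model are illustrative only, and the include_private option simply mirrors the one used in Example One below; the linked documentation remains the authoritative version.

```python
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy import create_engine, event, select
from sqlalchemy.orm import Session, declarative_base, with_loader_criteria

Base = declarative_base()


class HasPrivate:
    """Mixin that identifies a class as having private entities"""

    public = Column(Boolean, nullable=False)


class User(HasPrivate, Base):
    __tablename__ = "user"

    id = Column(Integer, primary_key=True)
    name = Column(String)


@event.listens_for(Session, "do_orm_execute")
def add_public_criteria(execute_state):
    """Add "public == True" to top-level ORM SELECTs unless opted out."""
    if (
        execute_state.is_select
        and not execute_state.is_column_load
        and not execute_state.is_relationship_load
        and not execute_state.execution_options.get("include_private", False)
    ):
        # with_loader_criteria() applies the criteria to every
        # HasPrivate subclass that appears in the statement
        execute_state.statement = execute_state.statement.options(
            with_loader_criteria(
                HasPrivate,
                lambda cls: cls.public == True,
                include_aliases=True,
            )
        )


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

sess = Session(engine)
sess.add_all([User(name="u1", public=True), User(name="u2", public=False)])
sess.commit()

stmt = select(User).order_by(User.id)
public_names = [u.name for u in sess.execute(stmt).scalars()]
all_names = [
    u.name
    for u in sess.execute(
        stmt.execution_options(include_private=True)
    ).scalars()
]
```

Because the listener is attached to the Session class, it applies to every session in the process; attaching it to a specific sessionmaker instead would narrow its scope.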

Filtered Query (SQLAlchemy 1.3 and earlier only)

Illustrates how to augment Query so that it applies predefined criteria to all SELECT statements. The event that applies the criteria can also respond to options applied to the query via the query.execution_options() method, as well as to custom MapperOption objects.

This recipe was added in November 2018 as an update to the "PreFilteredQuery" and "GlobalFilter" examples, which were very old and out of date. Both examples now use the before_compile query event to modify the Query object using the same techniques.

The two techniques have some caveats. The good news is that, because the basic technique for both has been unified, we now have a much better idea of where more event hooks could be added to SQLAlchemy to make them fully possible.

Caveats:

  • The recipe as given does not work with joined eager loading; use the "selectinload" loader strategy instead, which is much more efficient for collections in any case. The recipe includes an on-load check that raises an error if an object was loaded with the wrong flag, as would happen if joined eager loading were used. A new event hook would need to be provided in SQLAlchemy for this to be more straightforward.

  • The recipe currently does not work with bulk updates or deletes issued from the Query, i.e. the query.update() and query.delete() methods. Making this straightforward requires a new event hook in SQLAlchemy, which is planned for 1.2.17; see https://github.com/sqlalchemy/sqlalchemy/issues/4461. Otherwise, the Query object can be subclassed and the update() and delete() methods overridden to provide the criteria.

  • Lazy load queries are cached when a custom MapperOption is not used, so to ensure this caching does not take place (at the expense of some performance), we set bake_queries=False on relationships. If a particular filter is not expected to change across loads, this option may be omitted. (This situation is expected to improve in upcoming releases.)
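
The subclassing workaround mentioned in the bulk update/delete caveat above could be sketched roughly as follows. This is not part of the original recipe: PublicOnlyQuery and its helper _with_public_criteria() are hypothetical names, the HasPrivate mixin and the include_private option mirror Example One below, and the private _execution_options attribute is used only because the recipe itself relies on it.

```python
from sqlalchemy import Boolean, Column, Integer, String
from sqlalchemy import create_engine, inspect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Query, sessionmaker


class HasPrivate(object):
    """Mixin that identifies a class as having private entities"""

    public = Column(Boolean, nullable=False)


class PublicOnlyQuery(Query):
    """Hypothetical Query subclass adding the criteria to bulk operations."""

    def _with_public_criteria(self):
        # honor the same opt-out execution option the recipe uses
        if self._execution_options.get("include_private", False):
            return self
        query = self
        for desc in self.column_descriptions:
            entity = desc["entity"]
            if entity is None:
                continue
            mapper = getattr(inspect(entity), "mapper", None)
            if mapper and issubclass(mapper.class_, HasPrivate):
                query = query.filter(entity.public == True)
        return query

    def update(self, values, **kw):
        # call the base implementation directly so that the filtered
        # copy does not re-enter this override
        return Query.update(self._with_public_criteria(), values, **kw)

    def delete(self, **kw):
        return Query.delete(self._with_public_criteria(), **kw)


Base = declarative_base()


class User(HasPrivate, Base):
    __tablename__ = "user"

    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# query_cls makes every sess.query() call return a PublicOnlyQuery
Session = sessionmaker(bind=engine, query_cls=PublicOnlyQuery)
sess = Session()
sess.add_all([User(name="u1", public=True), User(name="u2", public=False)])
sess.commit()

# the bulk UPDATE now carries the "public" criteria as well
updated = sess.query(User).update(
    {"name": "renamed"}, synchronize_session=False
)
sess.commit()
```

This sketch covers only update() and delete(); SELECT statements are still handled by the before_compile event shown in the examples below.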

Example One - the "public" flag. Only load rows whose "public" flag is True. Passing include_private=True to query.execution_options() loads all rows.

from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm.query import Query
from sqlalchemy import Column, Boolean


@event.listens_for(Query, "before_compile", retval=True)
def before_compile(query):
    """A query compilation rule that will add limiting criteria for every
    subclass of HasPrivate"""

    if query._execution_options.get("include_private", False):
        return query

    for ent in query.column_descriptions:
        entity = ent['entity']
        if entity is None:
            continue
        insp = inspect(ent['entity'])
        mapper = getattr(insp, 'mapper', None)
        if mapper and issubclass(mapper.class_, HasPrivate):
            query = query.enable_assertions(False).filter(
                ent['entity'].public == True)

    return query


class HasPrivate(object):
    """Mixin that identifies a class as having private entities"""

    public = Column(Boolean, nullable=False)


# the recipe has a few holes in it, unfortunately, including that as given,
# it doesn't impact the JOIN added by joined eager loading.   As a guard
# against this and other potential scenarios, we can check every object as
# it's loaded and refuse to continue if there's a problem
@event.listens_for(HasPrivate, "load", propagate=True)
def load(obj, context):
    if not obj.public and not \
            context.query._execution_options.get("include_private", False):
        raise TypeError(
            "private object %s was loaded, did you use "
            "joined eager loading?" % obj)

if __name__ == '__main__':

    from sqlalchemy import Integer, Column, String, ForeignKey, Boolean
    from sqlalchemy import create_engine
    from sqlalchemy.orm import relationship, sessionmaker
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(HasPrivate, Base):
        __tablename__ = 'user'

        id = Column(Integer, primary_key=True)
        name = Column(String)
        addresses = relationship("Address", back_populates="user", bake_queries=False)

    class Address(HasPrivate, Base):
        __tablename__ = 'address'

        id = Column(Integer, primary_key=True)
        email = Column(String)
        user_id = Column(Integer, ForeignKey('user.id'))

        user = relationship("User", back_populates="addresses", bake_queries=False)

    engine = create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)

    Session = sessionmaker(bind=engine)

    sess = Session()

    sess.add_all([
        User(name='u1', public=True,
             addresses=[
                 Address(email='u1a1', public=True),
                 Address(email='u1a2', public=True)]),
        User(name='u2', public=True,
             addresses=[
                 Address(email='u2a1', public=False),
                 Address(email='u2a2', public=True)]),
        User(name='u3', public=False,
             addresses=[
                 Address(email='u3a1', public=False),
                 Address(email='u3a2', public=False)]),
        User(name='u4', public=False,
             addresses=[
                 Address(email='u4a1', public=False),
                 Address(email='u4a2', public=True)]),
        User(name='u5', public=True,
             addresses=[
                 Address(email='u5a1', public=True),
                 Address(email='u5a2', public=False)])
    ])

    sess.commit()

    # now querying Address or User objects only gives us the public ones
    for u1 in sess.query(User):
        assert u1.public

        # the addresses collection will also be "public only", which works
        # for all relationship loaders **except** for joinedload()
        for address in u1.addresses:
            assert address.public

    # works for columns too
    cols = sess.query(User.id, Address.id).join(User.addresses).\
        order_by(User.id, Address.id).all()
    assert cols == [(1, 1), (1, 2), (2, 4), (5, 9)]

    cols = sess.query(User.id, Address.id).join(User.addresses).\
        order_by(User.id, Address.id).execution_options(include_private=True).\
        all()
    assert cols == [
        (1, 1), (1, 2), (2, 3), (2, 4), (3, 5),
        (3, 6), (4, 7), (4, 8), (5, 9), (5, 10)]

    # count all public addresses
    assert sess.query(Address).count() == 5

    # count all addresses public and private
    assert sess.query(Address).execution_options(include_private=True).\
        count() == 10

    # load an Address that is public, but its parent User is private
    a1 = sess.query(Address).filter_by(email='u4a2').first()

    # assuming the User isn't already in the Session, it returns None
    assert a1.user is None

    # however, if that user is present in the session, then a many-to-one
    # does a simple get() and it will be present
    sess.expire(a1, ['user'])
    u1 = sess.query(User).filter_by(name='u4').\
        execution_options(include_private=True).first()
    assert a1.user is u1

Example Two - temporal range. A custom MapperOption can be added to the query to specify custom criteria that should be applied to selected entities in the query. This was previously the "GlobalFilter" example; it is simplified here to use the same before_compile() event as above. Note that using a custom MapperOption also side-steps the lazyload query caching issue for now.

from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.interfaces import MapperOption
from sqlalchemy import Column, DateTime
import datetime


@event.listens_for(Query, "before_compile", retval=True)
def before_compile(query):
    """A query compilation rule that will add limiting criteria for every
    subclass of HasTemporal when the TemporalOption is used"""

    for opt in query._with_options:
        if isinstance(opt, TemporalOption):
            lower, upper = opt.range_lower, opt.range_upper

            # tell Query to definitely repopulate everything.
            # this also clears "sticky" MapperOption objects that may
            # be associated with existing objects
            query = query.populate_existing()

            break
    else:
        # no criteria located, do nothing
        return query

    for ent in query.column_descriptions:
        entity = ent['entity']
        if entity is None:
            continue
        insp = inspect(ent['entity'])
        mapper = getattr(insp, 'mapper', None)
        if mapper and issubclass(mapper.class_, HasTemporal):
            query = query.enable_assertions(False).filter(
                ent['entity'].timestamp.between(lower, upper))

    return query


class HasTemporal(object):
    """Mixin that identifies a class as having a timestamp column"""

    timestamp = Column(
        DateTime, default=datetime.datetime.utcnow, nullable=False)


class TemporalOption(MapperOption):
    """A MapperOption that specifies a range of dates to apply to a query."""
    propagate_to_loaders = True

    def __init__(self, range_lower, range_upper):
        self.range_lower = range_lower
        self.range_upper = range_upper


if __name__ == '__main__':

    from sqlalchemy import Integer, Column, ForeignKey
    from sqlalchemy import create_engine
    from sqlalchemy.orm import relationship, sessionmaker, selectinload
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Parent(HasTemporal, Base):
        __tablename__ = 'parent'
        id = Column(Integer, primary_key=True)
        children = relationship("Child", bake_queries=False)

    class Child(HasTemporal, Base):
        __tablename__ = 'child'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('parent.id'), nullable=False)

    engine = create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)

    Session = sessionmaker(bind=engine)

    sess = Session()

    c1, c2, c3, c4, c5 = [
        Child(timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 20, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 12, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
    ]

    p1 = Parent(
        timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00),
        children=[c1, c2, c3]
    )
    p2 = Parent(
        timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00),
        children=[c4, c5]
    )

    sess.add_all([p1, p2])
    sess.commit()

    parents = sess.query(Parent).\
        options(
            TemporalOption(
                datetime.datetime(2009, 10, 16, 12, 00, 00),
                datetime.datetime(2009, 10, 18, 12, 00, 00))
    ).all()

    assert parents[0] == p2
    assert parents[0].children == [c5]

    sess.expire_all()

    # try it with eager load
    parents = sess.query(Parent).\
        options(
            TemporalOption(
                datetime.datetime(2009, 10, 16, 12, 00, 00),
                datetime.datetime(2009, 10, 18, 12, 00, 00))
    ).options(selectinload(Parent.children)).all()

    assert parents[0] == p2
    assert parents[0].children == [c5]

    sess.expire_all()

    parents = sess.query(Parent).\
        options(
            TemporalOption(
                datetime.datetime(2009, 10, 15, 11, 00, 00),
                datetime.datetime(2009, 10, 18, 12, 00, 00))
    ).\
        join(Parent.children).filter(Child.id == 2).\
        all()

    assert parents[0] == p1
    assert parents[0].children == [c1, c2]