FilteredQuery
Note: this recipe is for SQLAlchemy 1.3 and earlier. SQLAlchemy 1.4 introduces an all-new, more capable system for augmenting queries that works fully in all cases; see the examples for do_orm_execute() in conjunction with with_loader_criteria at https://docs.sqlalchemy.org/en/14/orm/session_events.html#do-orm-execute-global-criteria, as well as the new versions of the examples below at https://docs.sqlalchemy.org/en/14/orm/examples.html#examples-session-orm-events.
Illustrates how to augment Query so that it applies predefined criteria to all SELECT statements. The event that applies the criteria can also respond to options applied to the query, either through the query.execution_options() method or through custom MapperOption objects.
This recipe is new as of November 2018, as an update to the "PreFilteredQuery" and "GlobalFilter" examples, which were very old and out of date. Both examples now use the before_compile query event to modify the Query object using the same techniques.
Both techniques have some caveats. The good news is that, because the underlying technique is now unified, it is much clearer where new event hooks could be added to SQLAlchemy to make them fully workable.
Caveats:
- The recipe as given does not work with joined eager loading; use the "selectinload" loader strategy instead, which is much more efficient for collections in any case. The recipe includes an on-load test that refuses to continue if an object with the wrong flag was loaded, which catches the case where joined eager loading was used. A new event hook would need to be provided in SQLAlchemy for this to be more straightforward.
- The recipe currently does not work with bulk updates or deletes from the Query, i.e. the query.update() or query.delete() methods. A new event hook is needed in SQLAlchemy for this to be straightforward; it will be present in 1.2.17, see https://github.com/sqlalchemy/sqlalchemy/issues/4461. Otherwise, the Query object can be subclassed and the update() and delete() methods overridden to apply the criteria.
- Lazy-load queries are cached when no custom MapperOption is used. To prevent this caching (at some cost in performance), we set bake_queries=False on relationships. If a particular filter is not expected to change across loads, this option may be omitted. (This situation is to be improved in upcoming releases.)
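To make the joined eager loading caveat concrete, here is a stdlib-only sketch using sqlite3 that approximates the two SQL shapes involved. It is not SQLAlchemy's literal output (the real statements use different labels and aliases), and the miniature tables, with a single user u2 owning one private and one public address, are assumptions modeled on Example One. With joined eager loading, the hook's criterion lands in the WHERE clause for User only, while the collection arrives through an anonymous alias in the LEFT OUTER JOIN that carries no criterion, so the private address comes back. With selectin loading, the collection is fetched by a separate SELECT against Address, which the hook can filter like any other Address query.

```python
import sqlite3

# Hypothetical miniature of Example One's tables; only the placement of the
# "public" criterion matters here, not SQLAlchemy's exact SQL.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE "user" (id INTEGER PRIMARY KEY, name TEXT,
                         public INTEGER NOT NULL);
    CREATE TABLE address (id INTEGER PRIMARY KEY, email TEXT,
                          user_id INTEGER REFERENCES "user"(id),
                          public INTEGER NOT NULL);
    INSERT INTO "user" VALUES (2, 'u2', 1);
    INSERT INTO address VALUES (3, 'u2a1', 2, 0), (4, 'u2a2', 2, 1);
""")

# Joined eager loading shape: WHERE filters the User entity, but the eager
# LEFT OUTER JOIN goes to an alias that receives no criterion.
joined_rows = conn.execute("""
    SELECT "user".id, address_1.id, address_1.public
    FROM "user" LEFT OUTER JOIN address AS address_1
        ON "user".id = address_1.user_id
    WHERE "user".public = 1
    ORDER BY address_1.id
""").fetchall()

# selectin loading shape: a second SELECT against Address, so the
# criterion can be applied to it like any other Address query.
selectin_rows = conn.execute("""
    SELECT address.user_id, address.id, address.public
    FROM address
    WHERE address.user_id IN (2) AND address.public = 1
    ORDER BY address.id
""").fetchall()

print(joined_rows)    # [(2, 3, 0), (2, 4, 1)]  includes the private address
print(selectin_rows)  # [(2, 4, 1)]  only the public address
```

The first result is the situation that the load-event guard in Example One is meant to catch.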
Example One - the "public" flag. Only load rows whose "public" flag is True. The include_private execution option, set via query.execution_options(), can be used to load all rows.
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm.query import Query
from sqlalchemy import Column, Boolean
@event.listens_for(Query, "before_compile", retval=True)
def before_compile(query):
    """A query compilation rule that will add limiting criteria for every
    subclass of HasPrivate"""
    if query._execution_options.get("include_private", False):
        return query
    for ent in query.column_descriptions:
        entity = ent['entity']
        if entity is None:
            continue
        insp = inspect(ent['entity'])
        mapper = getattr(insp, 'mapper', None)
        if mapper and issubclass(mapper.class_, HasPrivate):
            query = query.enable_assertions(False).filter(
                ent['entity'].public == True)
    return query
class HasPrivate(object):
    """Mixin that identifies a class as having private entities"""
    public = Column(Boolean, nullable=False)
# the recipe has a few holes in it, unfortunately, including that as given,
# it doesn't impact the JOIN added by joined eager loading. As a guard
# against this and other potential scenarios, we can check every object as
# it's loaded and refuse to continue if there's a problem
@event.listens_for(HasPrivate, "load", propagate=True)
def load(obj, context):
    if not obj.public and not \
            context.query._execution_options.get("include_private", False):
        raise TypeError(
            "private object %s was loaded, did you use "
            "joined eager loading?" % obj)
if __name__ == '__main__':
    from sqlalchemy import Integer, Column, String, ForeignKey, Boolean
    from sqlalchemy import create_engine
    from sqlalchemy.orm import relationship, sessionmaker
    from sqlalchemy.ext.declarative import declarative_base
    Base = declarative_base()
    class User(HasPrivate, Base):
        __tablename__ = 'user'
        id = Column(Integer, primary_key=True)
        name = Column(String)
        addresses = relationship("Address", back_populates="user", bake_queries=False)
    class Address(HasPrivate, Base):
        __tablename__ = 'address'
        id = Column(Integer, primary_key=True)
        email = Column(String)
        user_id = Column(Integer, ForeignKey('user.id'))
        user = relationship("User", back_populates="addresses", bake_queries=False)
    engine = create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    sess = Session()
    sess.add_all([
        User(name='u1', public=True,
             addresses=[
                 Address(email='u1a1', public=True),
                 Address(email='u1a2', public=True)]),
        User(name='u2', public=True,
             addresses=[
                 Address(email='u2a1', public=False),
                 Address(email='u2a2', public=True)]),
        User(name='u3', public=False,
             addresses=[
                 Address(email='u3a1', public=False),
                 Address(email='u3a2', public=False)]),
        User(name='u4', public=False,
             addresses=[
                 Address(email='u4a1', public=False),
                 Address(email='u4a2', public=True)]),
        User(name='u5', public=True,
             addresses=[
                 Address(email='u5a1', public=True),
                 Address(email='u5a2', public=False)])
    ])
    sess.commit()
    # now querying Address or User objects only gives us the public ones
    for u1 in sess.query(User):
        assert u1.public
        # the addresses collection will also be "public only", which works
        # for all relationship loaders **except** for joinedload()
        for address in u1.addresses:
            assert address.public
    # works for columns too
    cols = sess.query(User.id, Address.id).join(User.addresses).\
        order_by(User.id, Address.id).all()
    assert cols == [(1, 1), (1, 2), (2, 4), (5, 9)]
    cols = sess.query(User.id, Address.id).join(User.addresses).\
        order_by(User.id, Address.id).execution_options(include_private=True).\
        all()
    assert cols == [
        (1, 1), (1, 2), (2, 3), (2, 4), (3, 5),
        (3, 6), (4, 7), (4, 8), (5, 9), (5, 10)]
    # count all public addresses
    assert sess.query(Address).count() == 5
    # count all addresses public and private
    assert sess.query(Address).execution_options(include_private=True).\
        count() == 10
    # load an Address that is public, but its parent User is private
    a1 = sess.query(Address).filter_by(email='u4a2').first()
    # assuming the User isn't already in the Session, it returns None
    assert a1.user is None
    # however, if that user is present in the session, then a many-to-one
    # does a simple get() and it will be present
    sess.expire(a1, ['user'])
    u1 = sess.query(User).filter_by(name='u4').\
        execution_options(include_private=True).first()
    assert a1.user is u1
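The criteria the hook adds amount to a `public = 1` test for each HasPrivate entity in the query. As a cross-check of the asserted id pairs and counts above, the following stdlib-only sketch replays Example One's rows in sqlite3 with that criterion written by hand; the include_private argument of the hypothetical helper functions mimics the execution option. It assumes the primary keys come out as 1-5 for users and 1-10 for addresses in insertion order, which is what the script's asserts rely on, and its SQL is an approximation rather than what SQLAlchemy actually emits.

```python
import sqlite3

# Example One's rows; ids in insertion order are an assumption that matches
# the asserts in the script above.
conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE "user" (id INTEGER PRIMARY KEY, name TEXT,
                         public INTEGER NOT NULL);
    CREATE TABLE address (id INTEGER PRIMARY KEY, email TEXT,
                          user_id INTEGER REFERENCES "user"(id),
                          public INTEGER NOT NULL);
    INSERT INTO "user" VALUES
        (1, 'u1', 1), (2, 'u2', 1), (3, 'u3', 0), (4, 'u4', 0), (5, 'u5', 1);
    INSERT INTO address VALUES
        (1, 'u1a1', 1, 1), (2, 'u1a2', 1, 1),
        (3, 'u2a1', 2, 0), (4, 'u2a2', 2, 1),
        (5, 'u3a1', 3, 0), (6, 'u3a2', 3, 0),
        (7, 'u4a1', 4, 0), (8, 'u4a2', 4, 1),
        (9, 'u5a1', 5, 1), (10, 'u5a2', 5, 0);
""")


def user_address_ids(include_private=False):
    """Hypothetical stand-in for the User.id / Address.id join query."""
    sql = ('SELECT "user".id, address.id FROM "user" '
           'JOIN address ON "user".id = address.user_id')
    if not include_private:
        # The criterion the hook adds for each HasPrivate entity.
        sql += ' WHERE "user".public = 1 AND address.public = 1'
    sql += ' ORDER BY "user".id, address.id'
    return conn.execute(sql).fetchall()


def count_addresses(include_private=False):
    """Hypothetical stand-in for sess.query(Address).count()."""
    sql = "SELECT count(*) FROM address"
    if not include_private:
        sql += " WHERE address.public = 1"
    return conn.execute(sql).fetchone()[0]


print(user_address_ids())     # [(1, 1), (1, 2), (2, 4), (5, 9)]
print(count_addresses())      # 5
print(count_addresses(True))  # 10
```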
Example Two - temporal range. A custom MapperOption can be added to the query to specify custom criteria that should be applied to the selected entities in the query. This was previously the "GlobalFilter" example, here simplified to use the same before_compile() event as above. Note that using a custom MapperOption also sidesteps the lazy-load query caching issue for now.
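The before_compile hook below locates the option with Python's for/else: the else branch runs only if the loop finishes without reaching break, so a query that carries no TemporalOption is returned untouched. A pure-Python sketch of just that lookup follows; OtherOption, TemporalOption and find_range here are hypothetical stand-ins, not real MapperOption objects.

```python
class OtherOption:
    """Hypothetical stand-in for an unrelated MapperOption."""


class TemporalOption:
    """Hypothetical stand-in carrying a date range, like the recipe's option."""

    def __init__(self, range_lower, range_upper):
        self.range_lower = range_lower
        self.range_upper = range_upper


def find_range(options):
    """Return (lower, upper) from the first TemporalOption, else None."""
    for opt in options:
        if isinstance(opt, TemporalOption):
            lower, upper = opt.range_lower, opt.range_upper
            break
    else:
        # The loop completed without break: no TemporalOption was found,
        # which is where the real hook returns the query unchanged.
        return None
    return lower, upper


print(find_range([OtherOption()]))                        # None
print(find_range([OtherOption(), TemporalOption(1, 2)]))  # (1, 2)
```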
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.interfaces import MapperOption
from sqlalchemy import Column, DateTime
import datetime
@event.listens_for(Query, "before_compile", retval=True)
def before_compile(query):
    """A query compilation rule that will add limiting criteria for every
    subclass of HasTemporal when the TemporalOption is used"""
    for opt in query._with_options:
        if isinstance(opt, TemporalOption):
            lower, upper = opt.range_lower, opt.range_upper
            # tell Query to definitely repopulate everything.
            # this also clears "sticky" MapperOption objects that may
            # be associated with existing objects
            query = query.populate_existing()
            break
    else:
        # no criteria located, do nothing
        return query
    for ent in query.column_descriptions:
        entity = ent['entity']
        if entity is None:
            continue
        insp = inspect(ent['entity'])
        mapper = getattr(insp, 'mapper', None)
        if mapper and issubclass(mapper.class_, HasTemporal):
            query = query.enable_assertions(False).filter(
                ent['entity'].timestamp.between(lower, upper))
    return query
class HasTemporal(object):
    """Mixin that identifies a class as having a timestamp column"""
    timestamp = Column(
        DateTime, default=datetime.datetime.utcnow, nullable=False)
class TemporalOption(MapperOption):
    """A MapperOption that specifies a range of dates to apply to a query."""
    propagate_to_loaders = True
    def __init__(self, range_lower, range_upper):
        self.range_lower = range_lower
        self.range_upper = range_upper
if __name__ == '__main__':
    from sqlalchemy import Integer, Column, ForeignKey
    from sqlalchemy import create_engine
    from sqlalchemy.orm import relationship, sessionmaker, selectinload
    from sqlalchemy.ext.declarative import declarative_base
    Base = declarative_base()
    class Parent(HasTemporal, Base):
        __tablename__ = 'parent'
        id = Column(Integer, primary_key=True)
        children = relationship("Child", bake_queries=False)
    class Child(HasTemporal, Base):
        __tablename__ = 'child'
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey('parent.id'), nullable=False)
    engine = create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    sess = Session()
    c1, c2, c3, c4, c5 = [
        Child(timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 20, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 12, 12, 00, 00)),
        Child(timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00)),
    ]
    p1 = Parent(
        timestamp=datetime.datetime(2009, 10, 15, 12, 00, 00),
        children=[c1, c2, c3]
    )
    p2 = Parent(
        timestamp=datetime.datetime(2009, 10, 17, 12, 00, 00),
        children=[c4, c5]
    )
    sess.add_all([p1, p2])
    sess.commit()
    parents = sess.query(Parent).\
        options(
            TemporalOption(
                datetime.datetime(2009, 10, 16, 12, 00, 00),
                datetime.datetime(2009, 10, 18, 12, 00, 00))
        ).all()
    assert parents[0] == p2
    assert parents[0].children == [c5]
    sess.expire_all()
    # try it with eager load
    parents = sess.query(Parent).\
        options(
            TemporalOption(
                datetime.datetime(2009, 10, 16, 12, 00, 00),
                datetime.datetime(2009, 10, 18, 12, 00, 00))
        ).options(selectinload(Parent.children)).all()
    assert parents[0] == p2
    assert parents[0].children == [c5]
    sess.expire_all()
    parents = sess.query(Parent).\
        options(
            TemporalOption(
                datetime.datetime(2009, 10, 15, 11, 00, 00),
                datetime.datetime(2009, 10, 18, 12, 00, 00))
        ).\
        join(Parent.children).filter(Child.id == 2).\
        all()
    assert parents[0] == p1
    assert parents[0].children == [c1, c2]
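The expected results in the asserts above can be checked by hand, since the TemporalOption criterion is a SQL BETWEEN, which includes both endpoints. This pure-Python sketch (all helper names are hypothetical) applies the same inclusive range to Example Two's data. It assumes, as the script's asserts imply, that the children receive ids 1-5 in creation order, and that the load of Parent.children carries the option along (propagate_to_loaders = True) so the same range filters the collection. In the last query only Parent is a selected entity, so the joined Child rows are limited by the explicit Child.id == 2 filter alone.

```python
from datetime import datetime

# Example Two's data; names follow the script above.
PARENTS = {
    "p1": datetime(2009, 10, 15, 12),
    "p2": datetime(2009, 10, 17, 12),
}
CHILDREN = {  # name -> (child id, parent name, timestamp)
    "c1": (1, "p1", datetime(2009, 10, 15, 12)),
    "c2": (2, "p1", datetime(2009, 10, 17, 12)),
    "c3": (3, "p1", datetime(2009, 10, 20, 12)),
    "c4": (4, "p2", datetime(2009, 10, 12, 12)),
    "c5": (5, "p2", datetime(2009, 10, 17, 12)),
}


def between(ts, lower, upper):
    # SQL BETWEEN is inclusive at both ends.
    return lower <= ts <= upper


def parents_in_range(lower, upper):
    return [name for name, ts in PARENTS.items() if between(ts, lower, upper)]


def children_in_range(parent, lower, upper):
    # The option propagates to the Parent.children load, so the same
    # range filters the collection.
    return [name for name, (_, par, ts) in CHILDREN.items()
            if par == parent and between(ts, lower, upper)]


def parents_with_child(child_id, lower, upper):
    # Only Parent gets the range; the joined Child rows are limited by
    # the explicit child id filter.
    return [p for p in parents_in_range(lower, upper)
            if any(cid == child_id and par == p
                   for cid, par, _ in CHILDREN.values())]


narrow = (datetime(2009, 10, 16, 12), datetime(2009, 10, 18, 12))
wide = (datetime(2009, 10, 15, 11), datetime(2009, 10, 18, 12))
print(parents_in_range(*narrow), children_in_range("p2", *narrow))
print(parents_with_child(2, *wide), children_in_range("p1", *wide))
```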