-
-
Notifications
You must be signed in to change notification settings - Fork 1.5k
TrackQueryInfoOnOrmObjects
Mike Bayer edited this page on Jun 28, 2023 · 2 revisions
For this use case, we use the InstanceEvents hooks to intercept ORM objects as they are newly loaded, refreshed, or expired, so that we can associate information about the originating query with each object.
The example below extracts the "FOR UPDATE" state of that query, which indicates whether the row for an object was pessimistically locked (e.g. using SELECT ... FOR UPDATE).
from __future__ import annotations
from typing import cast
from typing import NamedTuple
from typing import TYPE_CHECKING
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Session
if TYPE_CHECKING:
    # InstanceState is referenced only inside the string annotation given to
    # typing.cast() in LogLockingInfo.locking_info, so it is imported for type
    # checkers only; TYPE_CHECKING is False at runtime.
    from sqlalchemy.orm import InstanceState
class Base(DeclarativeBase):
    """Declarative base that the example's mapped classes inherit from."""
class LockInfo(NamedTuple):
    """Pair of flags describing whether, and how, an object's row was locked."""

    # True when a FOR UPDATE setting was recorded for the object.
    locked: bool
    # The recorded FOR UPDATE setting's ``read`` flag; False when not locked.
    read: bool
class LogLockingInfo:
    """Mixin adding a read-only ``locking_info`` attribute to ORM objects.

    The value is derived from whatever is stored under the ``"locking_info"``
    key of the object's ``InstanceState.info`` dictionary; in this module that
    key is maintained by the load/refresh/expire event listeners registered
    against this mixin.

    For typing support inside ``locking_info``, this class could extend
    ``Base`` or ``sqlalchemy.inspection.Inspectable[InstanceState]``. Neither
    is done here, to avoid implying that such a base is required at runtime.
    """

    @property
    def locking_info(self) -> LockInfo:
        """Return the lock status recorded for this object.

        ``locked`` is True whenever a non-None FOR UPDATE setting is stored;
        ``read`` is that setting's ``read`` flag, or False when nothing is
        stored.
        """
        instance_state = cast("InstanceState[LogLockingInfo]", inspect(self))
        stored = instance_state.info.get("locking_info", None)
        if stored is None:
            return LockInfo(locked=False, read=False)
        return LockInfo(locked=True, read=stored.read)
@event.listens_for(LogLockingInfo, "load", raw=True, propagate=True)
@event.listens_for(LogLockingInfo, "refresh", raw=True, propagate=True)
def _save_locking_info(state, context, attrs=None):
    """Record the FOR UPDATE setting of the statement that populated ``state``.

    This single function serves two events with different signatures:
    ``load`` is invoked as ``(target, context)`` while ``refresh`` is invoked
    as ``(target, context, attrs)``. The optional ``attrs`` parameter lets it
    accept both; without it every refresh (e.g. lazy-loading expired
    attributes, or ``Session.refresh()``) would raise ``TypeError``.

    :param state: the object's ``InstanceState`` (listener uses ``raw=True``).
    :param context: the ORM ``QueryContext``; ``context.query`` is the
        statement that was executed.
    :param attrs: names of refreshed attributes (``refresh`` only); unused.
    """
    statement = context.query
    if statement.is_select:
        # Defensive lookup: assumes select-like constructs normally define
        # _for_update_arg (None when no FOR UPDATE was requested) — falling
        # back to None if one does not.
        for_update_arg = getattr(statement, "_for_update_arg", None)
    else:
        # Non-SELECT statements are not inspected; record "not locked".
        for_update_arg = None
    # Read back by LogLockingInfo.locking_info; None means "not locked".
    state.info["locking_info"] = for_update_arg
@event.listens_for(LogLockingInfo, "expire", raw=True, propagate=True)
def _clear_locking_info(state, attrs):
    """Forget any recorded lock state when ``state`` is expired.

    :param state: the object's ``InstanceState`` (listener uses ``raw=True``).
    :param attrs: names of the expired attributes, or None for a full expire;
        the lock state is discarded either way.
    """
    # The key may legitimately be absent: an object added via the Session and
    # then expired by commit() was never loaded by a query, and an object that
    # is expired twice already had its entry removed. Popping without a
    # default would raise KeyError in those cases.
    state.info.pop("locking_info", None)
if __name__ == "__main__":

    class A(LogLockingInfo, Base):
        """Minimal mapped class used to demonstrate ``locking_info``."""

        __tablename__ = "a"

        id: Mapped[int] = mapped_column(primary_key=True)
        data: Mapped[str]

    # NOTE(review): the demo expects a reachable local PostgreSQL database;
    # adjust the URL for your environment.
    e = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
    try:
        Base.metadata.drop_all(e)
        Base.metadata.create_all(e)

        # The context manager guarantees the Session is closed on exit.
        with Session(e) as s:
            s.add_all([A(id=1, data="a1"), A(id=2, data="a2")])
            s.commit()

            # Reset the Session so the SELECTs below build fresh objects from
            # the returned rows instead of reusing the instances added above.
            s.close()

            a1 = s.scalars(select(A).filter(A.id == 1)).one()
            a2 = s.scalars(select(A).with_for_update().filter(A.id == 2)).one()

            # Re-selecting id=1 presumably returns the identity-mapped ``a1``
            # rather than a new object; its recorded lock state is unaffected.
            reload_a1 = s.scalars(select(A).filter(A.id == 1)).one()

            print(a1.locking_info)
            print(a2.locking_info)

            # Expiring clears the recorded lock state on both objects.
            s.expire_all()

            print(a1.locking_info)
            print(a2.locking_info)
    finally:
        # Release pooled connections held by the engine.
        e.dispose()