Skip to content

TrackQueryInfoOnOrmObjects

mike bayer edited this page Jun 28, 2023 · 2 revisions

Tracking Query Information on ORM objects as they are loaded

For this use case we use the InstanceEvents event hooks to intercept ORM objects as they are loaded, refreshed, or expired; inside those hooks we can associate information about the originating query with each object.

In the example below, extraction of the "for update" state is illustrated, which can indicate if the row for an object was pessimistically locked (e.g. using SELECT..FOR UPDATE).

from __future__ import annotations

from typing import cast
from typing import NamedTuple
from typing import TYPE_CHECKING

from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Session

if TYPE_CHECKING:
    from sqlalchemy.orm import InstanceState


class Base(DeclarativeBase):
    """Declarative base class for the mapped classes in this example."""


class LockInfo(NamedTuple):
    """Immutable pair of flags describing the recorded locking state."""

    # True when a "for update" argument was recorded for the object.
    locked: bool
    # The ``read`` flag of the recorded "for update" argument; False when
    # nothing was recorded.
    read: bool


class LogLockingInfo:
    """Mixin adding a read-only ``locking_info`` attribute.

    The attribute reports the "for update" argument stored under the
    ``"locking_info"`` key of the instance's ``InstanceState.info``
    dictionary.

    For typing support inside of the locking_info() descriptor, this class
    can extend from Base or from
    sqlalchemy.inspection.Inspectable[sqlalchemy.orm.InstanceState]; that is
    omitted here to avoid the impression that it is required at runtime.

    """

    @property
    def locking_info(self) -> LockInfo:
        """Return the locking state recorded for this instance.

        :return: ``LockInfo(locked=False, read=False)`` when no
            "for update" argument is stored for the instance; otherwise
            ``locked`` is True and ``read`` mirrors the stored argument's
            ``read`` attribute.
        """
        instance_state = cast("InstanceState[LogLockingInfo]", inspect(self))
        stored = instance_state.info.get("locking_info", None)

        # Nothing recorded: the object is reported as not locked.
        if stored is None:
            return LockInfo(locked=False, read=False)

        return LockInfo(locked=True, read=stored.read)


@event.listens_for(LogLockingInfo, "load", raw=True, propagate=True)
@event.listens_for(LogLockingInfo, "refresh", raw=True, propagate=True)
def _save_locking_info(state, context, attrs=None):
    """Record the "for update" argument of the loading statement.

    Registered for both the "load" and "refresh" instance events.  The
    result is stored in ``state.info["locking_info"]``; ``None`` is stored
    when no SELECT statement is available.

    :param state: the InstanceState of the object being loaded/refreshed
        (``raw=True`` delivers the state rather than the instance).
    :param context: the query context whose ``.query`` is the statement
        being executed.  SQLAlchemy documents that this may be ``None``
        when the load does not correspond to a query (for example during
        ``Session.merge()``).
    :param attrs: attribute names passed by the "refresh" event, which is
        dispatched with three arguments; unused, and absent for "load".
    """
    # Guard against loads that carry no query context at all.
    query = context.query if context is not None else None

    if query is not None and query.is_select:
        for_update_arg = query._for_update_arg
    else:
        for_update_arg = None

    state.info["locking_info"] = for_update_arg


@event.listens_for(LogLockingInfo, "expire", raw=True, propagate=True)
def _clear_locking_info(state, attrs):
    """Discard any recorded locking information when the object expires.

    :param state: the InstanceState of the object being expired
        (``raw=True`` delivers the state rather than the instance).
    :param attrs: attribute names being expired; unused.
    """
    # Use a default so that expiring an object which has no
    # "locking_info" entry (it never passed through the load/refresh
    # hooks, e.g. a newly added object expired on commit) does not raise
    # KeyError.
    state.info.pop("locking_info", None)


if __name__ == "__main__":
    # Demonstration script: requires a reachable PostgreSQL database at the
    # URL below.

    class A(LogLockingInfo, Base):
        __tablename__ = "a"

        id: Mapped[int] = mapped_column(primary_key=True)
        data: Mapped[str]

    e = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
    Base.metadata.drop_all(e)
    Base.metadata.create_all(e)

    # Seed two rows in a session of their own so that the objects queried
    # below are loaded from the database rather than reused.
    with Session(e) as s:
        s.add_all([A(id=1, data="a1"), A(id=2, data="a2")])
        s.commit()

    # The context manager closes the session at the end of the block,
    # releasing the connection and the FOR UPDATE row lock.
    with Session(e) as s:
        a1 = s.scalars(select(A).filter(A.id == 1)).one()
        a2 = s.scalars(select(A).with_for_update().filter(A.id == 2)).one()

        # Query the same row a second time, without FOR UPDATE.
        reload_a1 = s.scalars(select(A).filter(A.id == 1)).one()

        # a1 was selected without FOR UPDATE, a2 with it.
        print(a1.locking_info)
        print(a2.locking_info)

        s.expire_all()

        # The "expire" hook removes the recorded information, so both
        # objects now report locked=False.
        print(a1.locking_info)
        print(a2.locking_info)
Clone this wiki locally