Skip to content

RangeQuery and WindowedRangeQuery

mike bayer edited this page Sep 9, 2023 · 6 revisions

RangeQuery and WindowedRangeQuery

Updated for SQLAlchemy 2.0

The goal is to iterate through a number of rows that is too large to fetch all at once. Many DBAPIs pre-buffer result sets fully, and otherwise it can be difficult to keep an active cursor open when using an option like psycopg2's server side cursors. The usual alternative, paging through the results using LIMIT/OFFSET, has the downside that OFFSET scans through all of the previous rows each time in order to get to the requested row. To overcome this, there are two approaches to page through results without using OFFSET.

The simplest is to order the results by a particular unique column (usually the primary key), then fetch chunks using LIMIT only, adding a WHERE clause that ensures we only fetch rows greater than the last one we fetched. Because that WHERE clause uses a strict "greater than", the column must be unique; otherwise rows sharing the last fetched value would be skipped. This will work for basically any database backend and is illustrated below for MySQL. The potential downside is that the database needs to sort the full set of remaining rows for each chunk, which may be inefficient, even though both recipes presented here assume the sort column is indexed. However, the approach is very simple and can likely work for most ordinary use cases for a primary key column on a database that does not support window functions.

from __future__ import annotations

import random
import typing
from typing import Any
from typing import Iterator

from sqlalchemy import create_engine
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Session

if typing.TYPE_CHECKING:
    from sqlalchemy import FrozenResult
    from sqlalchemy import Result
    from sqlalchemy import Select
    from sqlalchemy import SQLColumnExpression


def windowed_query(
    session: Session,
    stmt: Select[Any],
    column: SQLColumnExpression[Any],
    windowsize: int,
) -> Iterator[Result[Any]]:
    """Execute ``stmt`` in ordered chunks of at most ``windowsize`` rows.

    ``column`` is appended to the SELECT list and used for ORDER BY. Every
    chunk after the first is restricted to rows where ``column`` is strictly
    greater than the last value already seen, so no OFFSET is needed.
    Because the comparison is strict, ``column`` should be unique; rows that
    share the boundary value of a chunk would otherwise be skipped.

    Yields one ``Result`` per non-empty chunk, with the extra ``column``
    stripped off again so rows have the same shape as ``stmt``'s own rows.
    """
    ordered = stmt.add_columns(column).order_by(column)
    boundary = None

    while True:
        if boundary is None:
            chunk_stmt = ordered
        else:
            chunk_stmt = ordered.filter(column > boundary)

        # freeze the result so its rows can be inspected here and then
        # replayed as a brand-new, unconsumed Result for the caller
        frozen: FrozenResult = session.execute(
            chunk_stmt.limit(windowsize)
        ).freeze()
        rows = frozen().all()
        if not rows:
            return

        # the window column was appended last, so it is the final element
        boundary = rows[-1][-1]
        keep = range(len(rows[0]) - 1)

        yield frozen().columns(*keep)


if __name__ == "__main__":

    class Base(DeclarativeBase):
        pass

    class Widget(Base):
        __tablename__ = "widget"
        id: Mapped[int] = mapped_column(primary_key=True)
        data: Mapped[int]

    e = create_engine("mysql://scott:tiger@localhost/test", echo=True)

    Base.metadata.drop_all(e)
    Base.metadata.create_all(e)

    # get some random list of unique values; ``data`` is the window column
    # below, and windowed_query needs that column to be unique
    data = {random.randint(1, 1000000) for _ in range(10000)}

    # use the Session as a context manager so it is closed (and its
    # connection released) when the block exits
    with Session(e) as s:
        s.add_all([Widget(id=i, data=j) for i, j in enumerate(data, 1)])
        s.commit()

    with Session(e) as s:
        q = select(Widget)

        for result in windowed_query(s, q, Widget.data, 1000):
            for widget in result.scalars():
                print("data:", widget.data)

A more elaborate approach, which sorts the table rows only once, is to use a window function to establish the exact range of each "chunk" ahead of time, and then yield each chunk as the rows selected within that range. This works only on databases that support window functions. This recipe has been on the SQLAlchemy Wiki for a long time, but it's not clear how much advantage it has over the simpler approach above; both approaches should be evaluated for efficiency for a given use case.

from __future__ import annotations

import random
import typing
from typing import Any
from typing import Iterator

from sqlalchemy import and_
from sqlalchemy import create_engine
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import Session

if typing.TYPE_CHECKING:
    from sqlalchemy import Result
    from sqlalchemy import Select
    from sqlalchemy import SQLColumnExpression


def column_windows(
    session: Session,
    stmt: Select[Any],
    column: SQLColumnExpression[Any],
    windowsize: int,
) -> Iterator[SQLColumnExpression[bool]]:
    """Return a series of WHERE clauses against
    a given column that break it into windows.

    Result is an iterable of WHERE clauses that are packaged with
    the individual ranges to select from.

    Requires a database that supports window functions.

    :param session: Session used to run the range-finding query.
    :param stmt: the SELECT whose rows are being windowed; ``column`` must
     be one of its columns so it can be located inside the subquery.
    :param column: the column to window on.
    :param windowsize: number of rows between consecutive window starts.

    Each yielded clause is ``column >= start AND column < next_start``,
    except the last, which is ``column >= start``. Rows whose ``column``
    is NULL satisfy none of these clauses.
    """

    rownum = func.row_number().over(order_by=column).label("rownum")

    subq = stmt.add_columns(rownum).subquery()
    # rownum was added last, so it is the final column of the subquery
    subq_column = list(subq.columns)[-1]

    target_column = subq.corresponding_column(column)  # type: ignore  # will be fixed by #10326

    # Select the value of ``column`` at the first row of each window, i.e.
    # rows numbered 1, windowsize + 1, 2 * windowsize + 1, ...
    # The ORDER BY is required: without it the database may return the
    # start values in any order, and the ranges built below would be wrong.
    new_stmt = select(target_column).order_by(subq_column)

    if windowsize > 1:
        # for windowsize <= 1 every row starts its own window;
        # ``rownum % 1 == 1`` would never match, hence the guard
        new_stmt = new_stmt.filter(subq_column % windowsize == 1)

    # the SQL statement here is intended to give us a list of ranges,
    # and looks like:
    #
    # SELECT anon_1.data
    # FROM (SELECT widget.id AS id, widget.data AS data,
    #       row_number() OVER (ORDER BY widget.data) AS rownum
    #       FROM widget) AS anon_1
    # WHERE anon_1.rownum %% %(rownum_1)s = %(param_1)s
    # ORDER BY anon_1.rownum

    intervals = list(session.scalars(new_stmt))

    # yield out WHERE clauses for each range: pair every start value with
    # the next one (avoids the O(n) cost of list.pop(0) per window)
    for start, end in zip(intervals, intervals[1:]):
        yield and_(column >= start, column < end)

    # the final window is open-ended
    if intervals:
        yield column >= intervals[-1]


def windowed_query(
    session: Session,
    stmt: Select[Any],
    column: SQLColumnExpression[Any],
    windowsize: int,
) -> Iterator[Result[Any]]:
    """Run ``stmt`` once per range produced by ``column_windows``.

    Each execution adds that range's WHERE criteria plus an ORDER BY on
    ``column``, and one ``Result`` is yielded per range, in range order.
    """
    ordered = stmt.order_by(column)
    windows = column_windows(session, stmt, column, windowsize)
    yield from (
        session.execute(ordered.filter(criteria)) for criteria in windows
    )


if __name__ == "__main__":

    class Base(DeclarativeBase):
        pass

    class Widget(Base):
        __tablename__ = "widget"
        id: Mapped[int] = mapped_column(primary_key=True)
        data: Mapped[int]

    engine = create_engine("postgresql://scott:tiger@localhost/test", echo=True)

    # start from an empty widget table
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)

    # up to 10000 distinct random integers (duplicates collapse in the set)
    unique_values = {random.randint(1, 1000000) for _ in range(10000)}

    with Session(engine) as session:
        session.add_all(
            [Widget(id=pk, data=value) for pk, value in enumerate(unique_values, 1)]
        )
        session.commit()

    # read everything back in windows of about 1000 rows, ordered by data
    with Session(engine) as session:
        for chunk in windowed_query(session, select(Widget), Widget.data, 1000):
            for widget in chunk.scalars():
                print("data:", widget.data)

Here's an example of the kind of SQL this emits:

-- first, it gets a list of ranges, with 1000 values in each bucket
SELECT anon_1.widget_data AS anon_1_widget_data
FROM (SELECT widget.data AS widget_data, row_number() OVER (ORDER BY widget.data) AS rownum
FROM widget) AS anon_1
WHERE rownum %% 1000=1

Col ('anon_1_widget_data',)
Row (4,)
Row (100107,)
Row (200004,)
Row (299526,)
Row (397664,)
Row (502373,)
Row (597853,)
Row (695306,)
Row (798335,)
Row (899000,)

-- then, the original query is run for each window, adding in
-- the extra range criterion

SELECT widget.id AS widget_id, widget.data AS widget_data
FROM widget
WHERE widget.data >= %(data_1)s AND widget.data < %(data_2)s ORDER BY widget.data
-- values: {'data_2': 100107, 'data_1': 4}
Col ('widget_id', 'widget_data')
Row (1, 4)
Row (64, 211)
Row (5415, 554)
Row (168, 568)
Row (203, 672)
Row (275, 914)
Row (343, 1124)
Row (344, 1132)
...


SELECT widget.id AS widget_id, widget.data AS widget_data
FROM widget
WHERE widget.data >= %(data_1)s AND widget.data < %(data_2)s ORDER BY widget.data
-- values: {'data_2': 200004, 'data_1': 100107}
Col ('widget_id', 'widget_data')
Row (544, 100107)
Row (549, 100120)
Row (583, 100225)
Row (3564, 100235)
Row (588, 100241)
Row (594, 100258)
Row (599, 100274)

...
Clone this wiki locally