Skip to content

Demo and fix broken pagination #7889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 59 additions & 16 deletions rest_framework/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
Pagination serializers determine the structure of the output that should
be used for paginated responses.
"""
import operator
from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple
from functools import reduce
from urllib import parse

from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator
from django.db.models.query import Q
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
Expand All @@ -16,6 +19,7 @@
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.utils import json
from rest_framework.utils.urls import remove_query_param, replace_query_param


Expand Down Expand Up @@ -624,17 +628,43 @@ def paginate_queryset(self, queryset, request, view=None):

# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
order = self.ordering[0]
is_reversed = order.startswith('-')
order_attr = order.lstrip('-')
current_position_list = json.loads(current_position)

# Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed:
kwargs = {order_attr + '__lt': current_position}
else:
kwargs = {order_attr + '__gt': current_position}
q_objects_equals = {}
q_objects_compare = {}

for order, position in zip(self.ordering, current_position_list):
is_reversed = order.startswith("-")
order_attr = order.lstrip("-")

q_objects_equals[order] = Q(**{order_attr: position})

# Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed:
q_objects_compare[order] = Q(
**{(order_attr + "__lt"): position}
)
else:
q_objects_compare[order] = Q(
**{(order_attr + "__gt"): position}
)

filter_list = [q_objects_compare[self.ordering[0]]]

ordering = self.ordering

queryset = queryset.filter(**kwargs)
# starting with the second field
for i in range(len(ordering)):
# The first operands need to be equals
# the last operands need to be gt
equals = list(ordering[:i + 2])
greater_than_q = q_objects_compare[equals.pop()]
sub_filters = [q_objects_equals[e] for e in equals]
sub_filters.append(greater_than_q)
filter_list.append(reduce(operator.and_, sub_filters))

q_object = reduce(operator.or_, filter_list)
queryset = queryset.filter(q_object)

# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
Expand Down Expand Up @@ -839,7 +869,14 @@ def get_ordering(self, request, queryset, view):
)

if isinstance(ordering, str):
return (ordering,)
ordering = (ordering,)

pk_name = queryset.model._meta.pk.name

# Always include a unique key to order by
if not {"-{}".format(pk_name), pk_name, "pk", "-pk"} & set(ordering):
ordering = tuple(ordering) + (pk_name,)

return tuple(ordering)

def decode_cursor(self, request):
Expand Down Expand Up @@ -884,12 +921,18 @@ def encode_cursor(self, cursor):
return replace_query_param(self.base_url, self.cursor_query_param, encoded)

def _get_position_from_instance(self, instance, ordering):
field_name = ordering[0].lstrip('-')
if isinstance(instance, dict):
attr = instance[field_name]
else:
attr = getattr(instance, field_name)
return str(attr)
fields = []

for o in ordering:
field_name = o.lstrip("-")
if isinstance(instance, dict):
attr = instance[field_name]
else:
attr = getattr(instance, field_name)

fields.append(str(attr))

return json.dumps(fields)

def get_paginated_response(self, data):
return Response(OrderedDict([
Expand Down
10 changes: 9 additions & 1 deletion tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class RESTFrameworkModel(models.Model):
"""

class Meta:
app_label = 'tests'
app_label = "tests"
abstract = True


Expand Down Expand Up @@ -119,3 +119,11 @@ class OneToOnePKSource(RESTFrameworkModel):
target = models.OneToOneField(
OneToOneTarget, primary_key=True,
related_name='required_source', on_delete=models.CASCADE)


class ExamplePaginationModel(models.Model):
    """
    Model with an explicit integer primary key and two integer columns.
    """

    # The primary key is a plain integer instead of an auto field: sequences
    # can't be reset, and the test using this model needs that.
    id = models.IntegerField(primary_key=True)
    field = models.IntegerField()
    timestamp = models.IntegerField()
197 changes: 197 additions & 0 deletions tests/test_cursor_pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import base64
import itertools
from base64 import b64encode
from urllib import parse

import pytest

from rest_framework import generics
from rest_framework.filters import OrderingFilter
from rest_framework.pagination import Cursor, CursorPagination
from rest_framework.permissions import AllowAny
from rest_framework.serializers import ModelSerializer
from rest_framework.test import APIRequestFactory

from .models import ExamplePaginationModel

# Module-level request factory; used below to build the GET requests that are
# passed to the view under test.
factory = APIRequestFactory()


class SerializerCls(ModelSerializer):
    """
    Serializer exposing every field of `ExamplePaginationModel`.
    """

    class Meta:
        model = ExamplePaginationModel
        fields = "__all__"


def create_cursor(offset, reverse, position):
    """
    Encode cursor components into a base64 cursor token.

    The encoding was taken from rest_framework.pagination (per the original
    note): the components are written as a querystring and base64-encoded.
    A zero offset, a falsy `reverse` and a `None` position are all omitted.
    """
    tokens = {}
    if offset != 0:
        tokens["o"] = str(offset)
    if reverse:
        tokens["r"] = "1"
    if position is not None:
        tokens["p"] = position

    raw_querystring = parse.urlencode(tokens, doseq=True).encode("ascii")
    return b64encode(raw_querystring).decode("ascii")


def decode_cursor(response):
    """
    Extract the next/previous cursors from a paginated response.

    Returns a holder exposing `.next` and `.prev`. Each is a `Cursor` built
    from the corresponding link in `response.data`, or `None` when the
    response has no link in that direction.
    """
    # The backwards link is read from the response's 'previous' key and
    # exposed under the shorter `prev` name. The same key must be used for
    # storing and for looking up, otherwise `.prev` is always None.
    links = {
        "next": response.data.get("next"),
        "prev": response.data.get("previous"),
    }

    cursors = {}

    for rel, link in links.items():
        if not link:
            continue

        # The link's `cursor` query parameter is itself a base64-encoded
        # querystring; decoding it yields bytes keys and bytes values.
        encoded = parse.parse_qs(parse.urlparse(link).query)["cursor"][0]
        tokens = dict(parse.parse_qsl(base64.decodebytes(encoded.encode())))

        cursors[rel] = Cursor(
            offset=int(tokens.get(b"o", 0)),
            reverse=bool(int(tokens.get(b"r", 0))),
            # Left as bytes (or None); urlencode accepts bytes when the
            # cursor is re-encoded for the follow-up request.
            position=tokens.get(b"p"),
        )

    return type(
        "prev_next_stuct",
        (object,),
        {"next": cursors.get("next"), "prev": cursors.get("prev")},
    )


@pytest.mark.django_db
@pytest.mark.parametrize(
    "page_size,offset",
    [
        (6, 2), (2, 6), (5, 3), (3, 5), (5, 5)
    ],
    ids=[
        'page_size_divisor_of_offset',
        'page_size_multiple_of_offset',
        'page_size_uneven_divisor_of_offset',
        'page_size_uneven_multiple_of_offset',
        'page_size_same_as_offset',
    ]
)
def test_filtered_items_are_paginated(page_size, offset):
    """
    Walk a cursor-paginated list forwards and then backwards, checking that
    every page matches the corresponding slice of the expected ordering.

    The rows are ordered by a non-unique `timestamp` followed by `id`, for
    several combinations of page size and offset cutoff.
    """
    PaginationCls = type('PaginationCls', (CursorPagination,), dict(
        page_size=page_size,
        offset_cutoff=offset,
        max_page_size=20,
    ))

    # field is a unique range from 1-10 inclusive; timestamp is 1 or 2.
    # Primary keys are assigned manually, starting at 1.
    example_models = [
        ExamplePaginationModel(id=id_, field=field_1, timestamp=field_2)
        for id_, (field_1, field_2) in enumerate(
            itertools.product(range(1, 11), range(1, 3)), start=1
        )
    ]

    ExamplePaginationModel.objects.bulk_create(example_models)

    view = generics.ListAPIView.as_view(
        serializer_class=SerializerCls,
        queryset=ExamplePaginationModel.objects.all(),
        pagination_class=PaginationCls,
        permission_classes=(AllowAny,),
        filter_backends=[OrderingFilter],
    )

    def _request(offset, reverse, position):
        # Issue a GET for the page identified by the given cursor parts.
        return view(
            factory.get(
                "/",
                {
                    PaginationCls.cursor_query_param: create_cursor(
                        offset, reverse, position
                    ),
                    "ordering": "timestamp,id",
                },
            )
        )

    # This is the result we would expect
    expected_result = list(
        ExamplePaginationModel.objects.order_by("timestamp", "id").values(
            "timestamp",
            "id",
            "field",
        )
    )
    assert expected_result == [
        {"field": 1, "id": 1, "timestamp": 1},
        {"field": 2, "id": 3, "timestamp": 1},
        {"field": 3, "id": 5, "timestamp": 1},
        {"field": 4, "id": 7, "timestamp": 1},
        {"field": 5, "id": 9, "timestamp": 1},
        {"field": 6, "id": 11, "timestamp": 1},
        {"field": 7, "id": 13, "timestamp": 1},
        {"field": 8, "id": 15, "timestamp": 1},
        {"field": 9, "id": 17, "timestamp": 1},
        {"field": 10, "id": 19, "timestamp": 1},
        {"field": 1, "id": 2, "timestamp": 2},
        {"field": 2, "id": 4, "timestamp": 2},
        {"field": 3, "id": 6, "timestamp": 2},
        {"field": 4, "id": 8, "timestamp": 2},
        {"field": 5, "id": 10, "timestamp": 2},
        {"field": 6, "id": 12, "timestamp": 2},
        {"field": 7, "id": 14, "timestamp": 2},
        {"field": 8, "id": 16, "timestamp": 2},
        {"field": 9, "id": 18, "timestamp": 2},
        {"field": 10, "id": 20, "timestamp": 2},
    ]

    # Forward walk. Every fetched page is asserted *before* looking for the
    # next cursor, so the final page (which has no next link) is checked too.
    response = _request(0, False, None)
    position = 0

    while True:
        results = response.data['results']
        assert expected_result[position: position + len(results)] == results
        position += len(results)

        next_cursor = decode_cursor(response).next
        if not next_cursor:
            break
        response = _request(*next_cursor)

    # The forward walk must have covered every row exactly once.
    assert position == len(expected_result)

    # Backward walk, starting from the last page reached above. Again each
    # page is asserted before following the previous cursor.
    position = len(expected_result)

    while True:
        results = response.data['results']
        assert expected_result[position - len(results): position] == results
        position -= len(results)

        prev_cursor = decode_cursor(response).prev
        if not prev_cursor:
            break
        response = _request(*prev_cursor)
Loading