Skip to content

Commit 6bcd2f8

Browse files
committed
Fix broken pagination
1 parent 2375f6c commit 6bcd2f8

File tree

1 file changed

+66
-22
lines changed

1 file changed

+66
-22
lines changed

rest_framework/pagination.py

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,20 @@
22
Pagination serializers determine the structure of the output that should
33
be used for paginated responses.
44
"""
5+
import json
6+
import operator
7+
58
from base64 import b64decode, b64encode
69
from collections import OrderedDict, namedtuple
710
from urllib import parse
11+
from functools import reduce
812

913
from django.core.paginator import InvalidPage
1014
from django.core.paginator import Paginator as DjangoPaginator
1115
from django.template import loader
1216
from django.utils.encoding import force_str
1317
from django.utils.translation import gettext_lazy as _
18+
from django.db.models.query import Q
1419

1520
from rest_framework.compat import coreapi, coreschema
1621
from rest_framework.exceptions import NotFound
@@ -616,25 +621,41 @@ def paginate_queryset(self, queryset, request, view=None):
616621
else:
617622
(offset, reverse, current_position) = self.cursor
618623

619-
# Cursor pagination always enforces an ordering.
620-
if reverse:
621-
queryset = queryset.order_by(*_reverse_ordering(self.ordering))
622-
else:
623-
queryset = queryset.order_by(*self.ordering)
624-
625624
# If we have a cursor with a fixed position then filter by that.
626625
if current_position is not None:
627-
order = self.ordering[0]
628-
is_reversed = order.startswith('-')
629-
order_attr = order.lstrip('-')
626+
current_position_list = json.loads(current_position)
630627

631-
# Test for: (cursor reversed) XOR (queryset reversed)
632-
if self.cursor.reverse != is_reversed:
633-
kwargs = {order_attr + '__lt': current_position}
634-
else:
635-
kwargs = {order_attr + '__gt': current_position}
628+
q_objects_equals = {}
629+
q_objects_compare = {}
630+
631+
for order, position in zip(self.ordering, current_position_list):
632+
is_reversed = order.startswith("-")
633+
order_attr = order.lstrip("-")
634+
635+
q_objects_equals[order] = Q(**{order_attr: position})
636+
637+
# Test for: (cursor reversed) XOR (queryset reversed)
638+
if self.cursor.reverse != is_reversed:
639+
q_objects_compare[order] = Q(
640+
**{(order_attr + "__lt"): position}
641+
)
642+
else:
643+
q_objects_compare[order] = Q(
644+
**{(order_attr + "__gt"): position}
645+
)
646+
647+
filter_list = []
648+
# starting with the second field
649+
for i in range(len(self.ordering)):
650+
# The first operands need to be equals
651+
# the last operands need to be gt
652+
equals = list(self.ordering[:i+2])
653+
greater_than_q = q_objects_compare[equals.pop()]
654+
sub_filters = [q_objects_equals[e] for e in equals]
655+
sub_filters.append(greater_than_q)
656+
filter_list.append(reduce(operator.and_, sub_filters))
636657

637-
queryset = queryset.filter(**kwargs)
658+
queryset = queryset.filter(reduce(operator.or_, filter_list))
638659

639660
# If we have an offset cursor then offset the entire page by that amount.
640661
# We also always fetch an extra item in order to determine if there is a
@@ -839,7 +860,14 @@ def get_ordering(self, request, queryset, view):
839860
)
840861

841862
if isinstance(ordering, str):
842-
return (ordering,)
863+
ordering = (ordering,)
864+
865+
pk_name = queryset.model._meta.pk.name
866+
867+
# Always include a unique key to order by
868+
if not {f"-{pk_name}", pk_name, "pk", "-pk"} & set(ordering):
869+
ordering = ordering + (pk_name)
870+
843871
return tuple(ordering)
844872

845873
def decode_cursor(self, request):
@@ -884,12 +912,28 @@ def encode_cursor(self, cursor):
884912
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
885913

886914
def _get_position_from_instance(self, instance, ordering):
887-
field_name = ordering[0].lstrip('-')
888-
if isinstance(instance, dict):
889-
attr = instance[field_name]
890-
else:
891-
attr = getattr(instance, field_name)
892-
return str(attr)
915+
"""
916+
Overriden from the base class.
917+
This encodes the list data structure that's decoded
918+
on line 154 of this file.
919+
The old method simply return getattr(instnace, ordering[0]).
920+
This only works if the value in ordering[0] is unique.
921+
The value is json encoded here because it's an easy way to
922+
escape and serialize a list. This is then encoded as base64
923+
by encode_cursor, which calls this function.
924+
"""
925+
fields = []
926+
927+
for o in ordering:
928+
field_name = o.lstrip("-")
929+
if isinstance(instance, dict):
930+
attr = instance[field_name]
931+
else:
932+
attr = getattr(instance, field_name)
933+
934+
fields.append(str(attr))
935+
936+
return json.dumps(fields).encode()
893937

894938
def get_paginated_response(self, data):
895939
return Response(OrderedDict([

0 commit comments

Comments
 (0)