aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2015-01-22 16:12:05 +0000
committerTom Christie2015-01-22 16:12:05 +0000
commit0822c9e55820f8e4737329e38abc2e21718af9e5 (patch)
tree5c1a5f1a80e7f28ff1bdb4c94bfa595a7b58cbfd
parent408261ee02b176732b7f840f7042e7c24f3ecd27 (diff)
downloaddjango-rest-framework-0822c9e55820f8e4737329e38abc2e21718af9e5.tar.bz2
Cursor pagination now works with OrderingFilter
-rw-r--r--rest_framework/filters.py24
-rw-r--r--rest_framework/pagination.py41
-rw-r--r--tests/test_pagination.py40
3 files changed, 82 insertions, 23 deletions
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index d188a2d1..2bcf3699 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
- def get_ordering(self, request):
+ def get_ordering(self, request, queryset, view):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
@@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):
"""
params = request.query_params.get(self.ordering_param)
if params:
- return [param.strip() for param in params.split(',')]
+ fields = [param.strip() for param in params.split(',')]
+ ordering = self.remove_invalid_fields(queryset, fields, view)
+ if ordering:
+ return ordering
+
+ # No ordering was included, or all the ordering fields were invalid
+ return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
@@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):
return (ordering,)
return ordering
- def remove_invalid_fields(self, queryset, ordering, view):
+ def remove_invalid_fields(self, queryset, fields, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
@@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):
valid_fields = [field.name for field in queryset.model._meta.fields]
valid_fields += queryset.query.aggregates.keys()
- return [term for term in ordering if term.lstrip('-') in valid_fields]
+ return [term for term in fields if term.lstrip('-') in valid_fields]
def filter_queryset(self, request, queryset, view):
- ordering = self.get_ordering(request)
-
- if ordering:
- # Skip any incorrect parameters
- ordering = self.remove_invalid_fields(queryset, ordering, view)
-
- if not ordering:
- # Use 'ordering' attribute by default
- ordering = self.get_default_ordering(view)
+ ordering = self.get_ordering(request, queryset, view)
if ordering:
return queryset.order_by(*ordering)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index 7b28b47f..1b4174bc 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination):
class CursorPagination(BasePagination):
- # Support usage with OrderingFilter
- # Determine how/if True, False and None positions work
+ # Determine how/if True, False and None positions work - do the string
+ # encodings work with Django queryset filters?
+ # Consider a max offset cap.
cursor_query_param = 'cursor'
page_size = api_settings.PAGINATE_BY
invalid_cursor_message = _('Invalid cursor')
@@ -436,7 +437,7 @@ class CursorPagination(BasePagination):
def paginate_queryset(self, queryset, request, view=None):
self.base_url = request.build_absolute_uri()
- self.ordering = self.get_ordering(view)
+ self.ordering = self.get_ordering(request, queryset, view)
# Determine if we have a cursor, and if so then decode it.
encoded = request.query_params.get(self.cursor_query_param)
@@ -600,16 +601,36 @@ class CursorPagination(BasePagination):
encoded = _encode_cursor(cursor)
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
- def get_ordering(self, view):
+ def get_ordering(self, request, queryset, view):
"""
Return a tuple of strings, that may be used in an `order_by` method.
"""
- ordering = getattr(view, 'ordering', getattr(self, 'ordering', None))
+ ordering_filters = [
+ filter_cls for filter_cls in getattr(view, 'filter_backends', [])
+ if hasattr(filter_cls, 'get_ordering')
+ ]
+
+ if ordering_filters:
+ # If a filter exists on the view that implements `get_ordering`
+ # then we defer to that filter to determine the ordering.
+ filter_cls = ordering_filters[0]
+ filter_instance = filter_cls()
+ ordering = filter_instance.get_ordering(request, queryset, view)
+ assert ordering is not None, (
+ 'Using cursor pagination, but filter class {filter_cls} '
+ 'returned a `None` ordering.'.format(
+ filter_cls=filter_cls.__name__
+ )
+ )
+ else:
+ # The default case is to check for an `ordering` attribute,
+ # first on the view instance, and then on this pagination instance.
+ ordering = getattr(view, 'ordering', getattr(self, 'ordering', None))
+ assert ordering is not None, (
+ 'Using cursor pagination, but no ordering attribute was declared '
+ 'on the view or on the pagination class.'
+ )
- assert ordering is not None, (
- 'Using cursor pagination, but no ordering attribute was declared '
- 'on the view or on the pagination class.'
- )
assert isinstance(ordering, (six.string_types, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format(
type=type(ordering).__name__
@@ -618,7 +639,7 @@ class CursorPagination(BasePagination):
if isinstance(ordering, six.string_types):
return (ordering,)
- return ordering
+ return tuple(ordering)
def _get_position_from_instance(self, instance, ordering):
attr = getattr(instance, ordering[0])
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index c05b4aba..338be610 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -77,6 +77,20 @@ class TestPaginationIntegration:
'count': 50
}
+ def test_setting_page_size_to_zero(self):
+ """
+ When page_size parameter is invalid it should return to the default.
+ """
+ request = factory.get('/', {'page_size': 0})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=0',
+ 'count': 50
+ }
+
def test_additional_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': 'even'})
response = self.view(request)
@@ -88,6 +102,14 @@ class TestPaginationIntegration:
'count': 50
}
+ def test_404_not_found_for_zero_page(self):
+ request = factory.get('/', {'page': '0'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "0": That page number is less than 1.'
+ }
+
def test_404_not_found_for_invalid_page(self):
request = factory.get('/', {'page': 'invalid'})
response = self.view(request)
@@ -507,6 +529,24 @@ class TestCursorPagination:
with pytest.raises(exceptions.NotFound):
self.pagination.paginate_queryset(self.queryset, request)
+ def test_use_with_ordering_filter(self):
+ class MockView:
+ filter_backends = (filters.OrderingFilter,)
+ ordering_fields = ['username', 'created']
+ ordering = 'created'
+
+ request = Request(factory.get('/', {'ordering': 'username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('username',)
+
+ request = Request(factory.get('/', {'ordering': '-username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('-username',)
+
+ request = Request(factory.get('/', {'ordering': 'invalid'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('created',)
+
def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/')