From 0822c9e55820f8e4737329e38abc2e21718af9e5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 16:12:05 +0000 Subject: Cursor pagination now works with OrderingFilter --- rest_framework/filters.py | 24 +++++++++++------------- rest_framework/pagination.py | 41 +++++++++++++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 23 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index d188a2d1..2bcf3699 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend): ordering_param = api_settings.ORDERING_PARAM ordering_fields = None - def get_ordering(self, request): + def get_ordering(self, request, queryset, view): """ Ordering is set by a comma delimited ?ordering=... query parameter. @@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend): """ params = request.query_params.get(self.ordering_param) if params: - return [param.strip() for param in params.split(',')] + fields = [param.strip() for param in params.split(',')] + ordering = self.remove_invalid_fields(queryset, fields, view) + if ordering: + return ordering + + # No ordering was included, or all the ordering fields were invalid + return self.get_default_ordering(view) def get_default_ordering(self, view): ordering = getattr(view, 'ordering', None) @@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend): return (ordering,) return ordering - def remove_invalid_fields(self, queryset, ordering, view): + def remove_invalid_fields(self, queryset, fields, view): valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) if valid_fields is None: @@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend): valid_fields = [field.name for field in queryset.model._meta.fields] valid_fields += queryset.query.aggregates.keys() - return [term for term in ordering if term.lstrip('-') in valid_fields] + return [term for term in fields if term.lstrip('-') in valid_fields] def filter_queryset(self, request, queryset, view): - ordering = self.get_ordering(request) - - if ordering: - # Skip any incorrect parameters - ordering = self.remove_invalid_fields(queryset, ordering, view) - - if not ordering: - # Use 'ordering' attribute by default - ordering = self.get_default_ordering(view) + ordering = self.get_ordering(request, queryset, view) if ordering: return queryset.order_by(*ordering) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 7b28b47f..1b4174bc 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination): class CursorPagination(BasePagination): - # Support usage with OrderingFilter - # Determine how/if True, False and None positions work + # Determine how/if True, False and None positions work - do the string + # encodings work with Django queryset filters? + # Consider a max offset cap. cursor_query_param = 'cursor' page_size = api_settings.PAGINATE_BY invalid_cursor_message = _('Invalid cursor') @@ -436,7 +437,7 @@ class CursorPagination(BasePagination): def paginate_queryset(self, queryset, request, view=None): self.base_url = request.build_absolute_uri() - self.ordering = self.get_ordering(view) + self.ordering = self.get_ordering(request, queryset, view) # Determine if we have a cursor, and if so then decode it. encoded = request.query_params.get(self.cursor_query_param) @@ -600,16 +601,36 @@ class CursorPagination(BasePagination): encoded = _encode_cursor(cursor) return replace_query_param(self.base_url, self.cursor_query_param, encoded) - def get_ordering(self, view): + def get_ordering(self, request, queryset, view): """ Return a tuple of strings, that may be used in an `order_by` method. """ - ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) + ordering_filters = [ + filter_cls for filter_cls in getattr(view, 'filter_backends', []) + if hasattr(filter_cls, 'get_ordering') + ] + + if ordering_filters: + # If a filter exists on the view that implements `get_ordering` + # then we defer to that filter to determine the ordering. + filter_cls = ordering_filters[0] + filter_instance = filter_cls() + ordering = filter_instance.get_ordering(request, queryset, view) + assert ordering is not None, ( + 'Using cursor pagination, but filter class {filter_cls} ' + 'returned a `None` ordering.'.format( + filter_cls=filter_cls.__name__ + ) + ) + else: + # The default case is to check for an `ordering` attribute, + # first on the view instance, and then on this pagination instance. + ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) + assert ordering is not None, ( + 'Using cursor pagination, but no ordering attribute was declared ' + 'on the view or on the pagination class.' + ) - assert ordering is not None, ( - 'Using cursor pagination, but no ordering attribute was declared ' - 'on the view or on the pagination class.' - ) assert isinstance(ordering, (six.string_types, list, tuple)), ( 'Invalid ordering. Expected string or tuple, but got {type}'.format( type=type(ordering).__name__ @@ -618,7 +639,7 @@ class CursorPagination(BasePagination): if isinstance(ordering, six.string_types): return (ordering,) - return ordering + return tuple(ordering) def _get_position_from_instance(self, instance, ordering): attr = getattr(instance, ordering[0]) -- cgit v1.2.3