diff options
| author | Tom Christie | 2015-01-22 16:12:05 +0000 | 
|---|---|---|
| committer | Tom Christie | 2015-01-22 16:12:05 +0000 | 
| commit | 0822c9e55820f8e4737329e38abc2e21718af9e5 (patch) | |
| tree | 5c1a5f1a80e7f28ff1bdb4c94bfa595a7b58cbfd /rest_framework | |
| parent | 408261ee02b176732b7f840f7042e7c24f3ecd27 (diff) | |
| download | django-rest-framework-0822c9e55820f8e4737329e38abc2e21718af9e5.tar.bz2 | |
Cursor pagination now works with OrderingFilter
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/filters.py | 24 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 41 | 
2 files changed, 42 insertions, 23 deletions
| diff --git a/rest_framework/filters.py b/rest_framework/filters.py index d188a2d1..2bcf3699 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):      ordering_param = api_settings.ORDERING_PARAM      ordering_fields = None -    def get_ordering(self, request): +    def get_ordering(self, request, queryset, view):          """          Ordering is set by a comma delimited ?ordering=... query parameter. @@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):          """          params = request.query_params.get(self.ordering_param)          if params: -            return [param.strip() for param in params.split(',')] +            fields = [param.strip() for param in params.split(',')] +            ordering = self.remove_invalid_fields(queryset, fields, view) +            if ordering: +                return ordering + +        # No ordering was included, or all the ordering fields were invalid +        return self.get_default_ordering(view)      def get_default_ordering(self, view):          ordering = getattr(view, 'ordering', None) @@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):              return (ordering,)          return ordering -    def remove_invalid_fields(self, queryset, ordering, view): +    def remove_invalid_fields(self, queryset, fields, view):          valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)          if valid_fields is None: @@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):              valid_fields = [field.name for field in queryset.model._meta.fields]              valid_fields += queryset.query.aggregates.keys() -        return [term for term in ordering if term.lstrip('-') in valid_fields] +        return [term for term in fields if term.lstrip('-') in valid_fields]      def filter_queryset(self, request, queryset, view): -        ordering = self.get_ordering(request) - -        if ordering: -            # Skip any incorrect parameters -            ordering = self.remove_invalid_fields(queryset, ordering, view) - -        if not ordering: -            # Use 'ordering' attribute by default -            ordering = self.get_default_ordering(view) +        ordering = self.get_ordering(request, queryset, view)          if ordering:              return queryset.order_by(*ordering) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 7b28b47f..1b4174bc 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination):  class CursorPagination(BasePagination): -    # Support usage with OrderingFilter -    # Determine how/if True, False and None positions work +    # Determine how/if True, False and None positions work - do the string +    # encodings work with Django queryset filters? +    # Consider a max offset cap.      cursor_query_param = 'cursor'      page_size = api_settings.PAGINATE_BY      invalid_cursor_message = _('Invalid cursor') @@ -436,7 +437,7 @@ class CursorPagination(BasePagination):      def paginate_queryset(self, queryset, request, view=None):          self.base_url = request.build_absolute_uri() -        self.ordering = self.get_ordering(view) +        self.ordering = self.get_ordering(request, queryset, view)          # Determine if we have a cursor, and if so then decode it.          encoded = request.query_params.get(self.cursor_query_param) @@ -600,16 +601,36 @@ class CursorPagination(BasePagination):          encoded = _encode_cursor(cursor)          return replace_query_param(self.base_url, self.cursor_query_param, encoded) -    def get_ordering(self, view): +    def get_ordering(self, request, queryset, view):          """          Return a tuple of strings, that may be used in an `order_by` method.          """ -        ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) +        ordering_filters = [ +            filter_cls for filter_cls in getattr(view, 'filter_backends', []) +            if hasattr(filter_cls, 'get_ordering') +        ] + +        if ordering_filters: +            # If a filter exists on the view that implements `get_ordering` +            # then we defer to that filter to determine the ordering. +            filter_cls = ordering_filters[0] +            filter_instance = filter_cls() +            ordering = filter_instance.get_ordering(request, queryset, view) +            assert ordering is not None, ( +                'Using cursor pagination, but filter class {filter_cls} ' +                'returned a `None` ordering.'.format( +                    filter_cls=filter_cls.__name__ +                ) +            ) +        else: +            # The default case is to check for an `ordering` attribute, +            # first on the view instance, and then on this pagination instance. +            ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) +            assert ordering is not None, ( +                'Using cursor pagination, but no ordering attribute was declared ' +                'on the view or on the pagination class.' +            ) -        assert ordering is not None, ( -            'Using cursor pagination, but no ordering attribute was declared ' -            'on the view or on the pagination class.' -        )          assert isinstance(ordering, (six.string_types, list, tuple)), (              'Invalid ordering. Expected string or tuple, but got {type}'.format(                  type=type(ordering).__name__ @@ -618,7 +639,7 @@ class CursorPagination(BasePagination):          if isinstance(ordering, six.string_types):              return (ordering,) -        return ordering +        return tuple(ordering)      def _get_position_from_instance(self, instance, ordering):          attr = getattr(instance, ordering[0]) | 
