diff options
| -rw-r--r-- | rest_framework/filters.py | 24 | ||||
| -rw-r--r-- | rest_framework/pagination.py | 41 | ||||
| -rw-r--r-- | tests/test_pagination.py | 40 | 
3 files changed, 82 insertions, 23 deletions
| diff --git a/rest_framework/filters.py b/rest_framework/filters.py index d188a2d1..2bcf3699 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):      ordering_param = api_settings.ORDERING_PARAM      ordering_fields = None -    def get_ordering(self, request): +    def get_ordering(self, request, queryset, view):          """          Ordering is set by a comma delimited ?ordering=... query parameter. @@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):          """          params = request.query_params.get(self.ordering_param)          if params: -            return [param.strip() for param in params.split(',')] +            fields = [param.strip() for param in params.split(',')] +            ordering = self.remove_invalid_fields(queryset, fields, view) +            if ordering: +                return ordering + +        # No ordering was included, or all the ordering fields were invalid +        return self.get_default_ordering(view)      def get_default_ordering(self, view):          ordering = getattr(view, 'ordering', None) @@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):              return (ordering,)          return ordering -    def remove_invalid_fields(self, queryset, ordering, view): +    def remove_invalid_fields(self, queryset, fields, view):          valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)          if valid_fields is None: @@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):              valid_fields = [field.name for field in queryset.model._meta.fields]              valid_fields += queryset.query.aggregates.keys() -        return [term for term in ordering if term.lstrip('-') in valid_fields] +        return [term for term in fields if term.lstrip('-') in valid_fields]      def filter_queryset(self, request, queryset, view): -        ordering = self.get_ordering(request) - -        if ordering: -            # Skip any incorrect parameters -            ordering = self.remove_invalid_fields(queryset, ordering, view) - -        if not ordering: -            # Use 'ordering' attribute by default -            ordering = self.get_default_ordering(view) +        ordering = self.get_ordering(request, queryset, view)          if ordering:              return queryset.order_by(*ordering) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 7b28b47f..1b4174bc 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination):  class CursorPagination(BasePagination): -    # Support usage with OrderingFilter -    # Determine how/if True, False and None positions work +    # Determine how/if True, False and None positions work - do the string +    # encodings work with Django queryset filters? +    # Consider a max offset cap.      cursor_query_param = 'cursor'      page_size = api_settings.PAGINATE_BY      invalid_cursor_message = _('Invalid cursor') @@ -436,7 +437,7 @@ class CursorPagination(BasePagination):      def paginate_queryset(self, queryset, request, view=None):          self.base_url = request.build_absolute_uri() -        self.ordering = self.get_ordering(view) +        self.ordering = self.get_ordering(request, queryset, view)          # Determine if we have a cursor, and if so then decode it.          encoded = request.query_params.get(self.cursor_query_param) @@ -600,16 +601,36 @@ class CursorPagination(BasePagination):          encoded = _encode_cursor(cursor)          return replace_query_param(self.base_url, self.cursor_query_param, encoded) -    def get_ordering(self, view): +    def get_ordering(self, request, queryset, view):          """          Return a tuple of strings, that may be used in an `order_by` method.          """ -        ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) +        ordering_filters = [ +            filter_cls for filter_cls in getattr(view, 'filter_backends', []) +            if hasattr(filter_cls, 'get_ordering') +        ] + +        if ordering_filters: +            # If a filter exists on the view that implements `get_ordering` +            # then we defer to that filter to determine the ordering. +            filter_cls = ordering_filters[0] +            filter_instance = filter_cls() +            ordering = filter_instance.get_ordering(request, queryset, view) +            assert ordering is not None, ( +                'Using cursor pagination, but filter class {filter_cls} ' +                'returned a `None` ordering.'.format( +                    filter_cls=filter_cls.__name__ +                ) +            ) +        else: +            # The default case is to check for an `ordering` attribute, +            # first on the view instance, and then on this pagination instance. +            ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) +            assert ordering is not None, ( +                'Using cursor pagination, but no ordering attribute was declared ' +                'on the view or on the pagination class.' +            ) -        assert ordering is not None, ( -            'Using cursor pagination, but no ordering attribute was declared ' -            'on the view or on the pagination class.' -        )          assert isinstance(ordering, (six.string_types, list, tuple)), (              'Invalid ordering. Expected string or tuple, but got {type}'.format(                  type=type(ordering).__name__ @@ -618,7 +639,7 @@ class CursorPagination(BasePagination):          if isinstance(ordering, six.string_types):              return (ordering,) -        return ordering +        return tuple(ordering)      def _get_position_from_instance(self, instance, ordering):          attr = getattr(instance, ordering[0]) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index c05b4aba..338be610 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -77,6 +77,20 @@ class TestPaginationIntegration:              'count': 50          } +    def test_setting_page_size_to_zero(self): +        """ +        When page_size parameter is invalid it should return to the default. +        """ +        request = factory.get('/', {'page_size': 0}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=0', +            'count': 50 +        } +      def test_additional_query_params_are_preserved(self):          request = factory.get('/', {'page': 2, 'filter': 'even'})          response = self.view(request) @@ -88,6 +102,14 @@ class TestPaginationIntegration:              'count': 50          } +    def test_404_not_found_for_zero_page(self): +        request = factory.get('/', {'page': '0'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "0": That page number is less than 1.' +        } +      def test_404_not_found_for_invalid_page(self):          request = factory.get('/', {'page': 'invalid'})          response = self.view(request) @@ -507,6 +529,24 @@ class TestCursorPagination:          with pytest.raises(exceptions.NotFound):              self.pagination.paginate_queryset(self.queryset, request) +    def test_use_with_ordering_filter(self): +        class MockView: +            filter_backends = (filters.OrderingFilter,) +            ordering_fields = ['username', 'created'] +            ordering = 'created' + +        request = Request(factory.get('/', {'ordering': 'username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('username',) + +        request = Request(factory.get('/', {'ordering': '-username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('-username',) + +        request = Request(factory.get('/', {'ordering': 'invalid'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('created',) +      def test_cursor_pagination(self):          (previous, current, next, previous_url, next_url) = self.get_pages('/') | 
