diff options
| -rw-r--r-- | rest_framework/generics.py | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 88 | 
2 files changed, 99 insertions, 3 deletions
| diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4f134bce..6b42a1d5 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -145,10 +145,18 @@ class GenericAPIView(views.APIView):                                           allow_empty_first_page=self.allow_empty)          page_kwarg = self.kwargs.get(self.page_kwarg)          page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) -        page = page_kwarg or page_query_param or 1 +        page = page_kwarg or page_query_param +        if not page: +            # we didn't recieve a page +            if hasattr(paginator, 'default_page_number'): +                # our paginator has a method that will provide a default +                page = paginator.default_page_number() +            else: +                # fall back on the base default value +                page = 1          try: -            page_number = strict_positive_int(page) -        except ValueError: +            page_number = paginator.validate_number(page) +        except InvalidPage:              if page == 'last':                  page_number = paginator.num_pages              else: diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index d6bc7895..a1118f1e 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -430,3 +430,91 @@ class TestCustomPaginationSerializer(TestCase):              'objects': ['john', 'paul']          }          self.assertEqual(serializer.data, expected) + + +class NonIntegerPage(object): + +    def __init__(self, paginator, object_list, prev_token, token, next_token): +        self.paginator = paginator +        self.object_list = object_list +        self.prev_token = prev_token +        self.token = token +        self.next_token = next_token + +    def has_next(self): +        return not not self.next_token + +    def next_page_number(self): +        return self.next_token + +    def has_previous(self): +        return not not self.prev_token + +    def previous_page_number(self): +        return self.prev_token + + +class NonIntegerPaginator(object): + +    def __init__(self, object_list, per_page): +        self.object_list = object_list +        self.per_page = per_page + +    def count(self): +        # pretend like we don't know how many pages we have +        return None + +    def default_page_token(self): +        return None + +    def page(self, token=None): +        if token: +            try: +                first = self.object_list.index(token) +            except ValueError: +                first = 0 +        else: +            first = 0 +        n = len(self.object_list) +        last = min(first + self.per_page, n) +        prev_token = self.object_list[last - (2 * self.per_page)] if first else None +        next_token = self.object_list[last] if last < n else None +        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) + + +class TestNonIntegerPagination(TestCase): + + +    def test_custom_pagination_serializer(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = NonIntegerPaginator(objects, 2) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page(), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), +                'prev': None +            }, +            'total_results': None, +            'objects': objects[:2] +        } +        self.assertEqual(serializer.data, expected) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page('george'), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': None, +                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), +            }, +            'total_results': None, +            'objects': objects[2:] +        } +        self.assertEqual(serializer.data, expected) | 
