diff options
| author | Ross McFarland | 2013-10-19 20:43:23 -0700 | 
|---|---|---|
| committer | Ross McFarland | 2013-10-19 21:11:27 -0700 | 
| commit | 63e6a3b4925bf54e80ae63502a0353136e846b31 (patch) | |
| tree | d3f95c8d4f1a8926c749a81e2ff07451778354bb /rest_framework | |
| parent | c3aeb16557f2cbb1c1218b5af7bab646e4958234 (diff) | |
| download | django-rest-framework-63e6a3b4925bf54e80ae63502a0353136e846b31.tar.bz2 | |
paginator should validate page and provide default
- use the standard paginator.validate_number method rather
  strict_postive_int.
- support optional paginator method, default_page_number, to get the default
  page number rather than hard-coding it to 1
- this allows supporting non-integer based pagination which can be an
  important performance tweak on extermely large datasets or high request
  loads
- relatively thorough unit tests of the changes
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/generics.py | 14 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 88 | 
2 files changed, 99 insertions, 3 deletions
| diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4f134bce..6b42a1d5 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -145,10 +145,18 @@ class GenericAPIView(views.APIView):                                           allow_empty_first_page=self.allow_empty)          page_kwarg = self.kwargs.get(self.page_kwarg)          page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) -        page = page_kwarg or page_query_param or 1 +        page = page_kwarg or page_query_param +        if not page: +            # we didn't recieve a page +            if hasattr(paginator, 'default_page_number'): +                # our paginator has a method that will provide a default +                page = paginator.default_page_number() +            else: +                # fall back on the base default value +                page = 1          try: -            page_number = strict_positive_int(page) -        except ValueError: +            page_number = paginator.validate_number(page) +        except InvalidPage:              if page == 'last':                  page_number = paginator.num_pages              else: diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index d6bc7895..a1118f1e 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -430,3 +430,91 @@ class TestCustomPaginationSerializer(TestCase):              'objects': ['john', 'paul']          }          self.assertEqual(serializer.data, expected) + + +class NonIntegerPage(object): + +    def __init__(self, paginator, object_list, prev_token, token, next_token): +        self.paginator = paginator +        self.object_list = object_list +        self.prev_token = prev_token +        self.token = token +        self.next_token = next_token + +    def has_next(self): +        return not not self.next_token + +    def next_page_number(self): +        return self.next_token + +    def has_previous(self): +        return not not self.prev_token + +    def previous_page_number(self): +        return self.prev_token + + +class NonIntegerPaginator(object): + +    def __init__(self, object_list, per_page): +        self.object_list = object_list +        self.per_page = per_page + +    def count(self): +        # pretend like we don't know how many pages we have +        return None + +    def default_page_token(self): +        return None + +    def page(self, token=None): +        if token: +            try: +                first = self.object_list.index(token) +            except ValueError: +                first = 0 +        else: +            first = 0 +        n = len(self.object_list) +        last = min(first + self.per_page, n) +        prev_token = self.object_list[last - (2 * self.per_page)] if first else None +        next_token = self.object_list[last] if last < n else None +        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) + + +class TestNonIntegerPagination(TestCase): + + +    def test_custom_pagination_serializer(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = NonIntegerPaginator(objects, 2) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page(), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), +                'prev': None +            }, +            'total_results': None, +            'objects': objects[:2] +        } +        self.assertEqual(serializer.data, expected) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page('george'), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': None, +                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), +            }, +            'total_results': None, +            'objects': objects[2:] +        } +        self.assertEqual(serializer.data, expected) | 
