diff options
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/generics.py | 9 | ||||
| -rw-r--r-- | rest_framework/mixins.py | 3 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 85 |
3 files changed, 93 insertions, 4 deletions
diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4f134bce..6d204cf5 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -54,6 +54,7 @@ class GenericAPIView(views.APIView): # If you want to use object lookups other than pk, set this attribute. # For more complex lookup requirements override `get_object()`. lookup_field = 'pk' + lookup_url_kwarg = None # Pagination settings paginate_by = api_settings.PAGINATE_BY @@ -147,8 +148,8 @@ class GenericAPIView(views.APIView): page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg) page = page_kwarg or page_query_param or 1 try: - page_number = strict_positive_int(page) - except ValueError: + page_number = paginator.validate_number(page) + except InvalidPage: if page == 'last': page_number = paginator.num_pages else: @@ -278,9 +279,11 @@ class GenericAPIView(views.APIView): pass # Deprecation warning # Perform the lookup filtering. + # Note that `pk` and `slug` are deprecated styles of lookup filtering. + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup = self.kwargs.get(lookup_url_kwarg, None) pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) - lookup = self.kwargs.get(self.lookup_field, None) if lookup is not None: filter_kwargs = {self.lookup_field: lookup} diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 426865ff..4606c78b 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -158,7 +158,8 @@ class UpdateModelMixin(object): Set any attributes on the object that are implicit in the request. """ # pk and/or slug attributes are implicit in the URL. - lookup = self.kwargs.get(self.lookup_field, None) + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup = self.kwargs.get(lookup_url_kwarg, None) pk = self.kwargs.get(self.pk_url_kwarg, None) slug = self.kwargs.get(self.slug_url_kwarg, None) slug_field = slug and self.slug_field or None diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index d6bc7895..cadb515f 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -430,3 +430,88 @@ class TestCustomPaginationSerializer(TestCase): 'objects': ['john', 'paul'] } self.assertEqual(serializer.data, expected) + + +class NonIntegerPage(object): + + def __init__(self, paginator, object_list, prev_token, token, next_token): + self.paginator = paginator + self.object_list = object_list + self.prev_token = prev_token + self.token = token + self.next_token = next_token + + def has_next(self): + return not not self.next_token + + def next_page_number(self): + return self.next_token + + def has_previous(self): + return not not self.prev_token + + def previous_page_number(self): + return self.prev_token + + +class NonIntegerPaginator(object): + + def __init__(self, object_list, per_page): + self.object_list = object_list + self.per_page = per_page + + def count(self): + # pretend like we don't know how many pages we have + return None + + def page(self, token=None): + if token: + try: + first = self.object_list.index(token) + except ValueError: + first = 0 + else: + first = 0 + n = len(self.object_list) + last = min(first + self.per_page, n) + prev_token = self.object_list[last - (2 * self.per_page)] if first else None + next_token = self.object_list[last] if last < n else None + return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) + + +class TestNonIntegerPagination(TestCase): + + + def test_custom_pagination_serializer(self): + objects = ['john', 'paul', 'george', 'ringo'] + paginator = NonIntegerPaginator(objects, 2) + + request = APIRequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=paginator.page(), + context={'request': request} + ) + expected = { + 'links': { + 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), + 'prev': None + }, + 'total_results': None, + 'objects': objects[:2] + } + self.assertEqual(serializer.data, expected) + + request = APIRequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=paginator.page('george'), + context={'request': request} + ) + expected = { + 'links': { + 'next': None, + 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), + }, + 'total_results': None, + 'objects': objects[2:] + } + self.assertEqual(serializer.data, expected) |
