diff options
Diffstat (limited to 'tests/test_pagination.py')
| -rw-r--r-- | tests/test_pagination.py | 1048 |
1 files changed, 583 insertions, 465 deletions
diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..13bfb627 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,553 +1,671 @@ +# coding: utf-8 from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): - if '?' not in url: - return url +class TestPaginationIntegration: + """ + Integration tests. + """ - path, args = url.split('?') - args = dict(r.split('=') for r in args.split('&')) - return path, args + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item + class EvenItemsOnly(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return [item for item in queryset if item % 2 == 0] + + class BasicPagination(pagination.PageNumberPagination): + paginate_by = 5 + paginate_by_param = 'page_size' + max_paginate_by = 20 + + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + filter_backends=[EvenItemsOnly], + pagination_class=BasicPagination + ) -class BasicSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel + def test_filtered_items_are_paginated(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 50 + } + def test_setting_page_size(self): + """ + When 'paginate_by_param' is set, the client may choose a page size. + """ + request = factory.get('/', {'page_size': 10}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=10', + 'count': 50 + } -class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem + def test_setting_page_size_over_maximum(self): + """ + When page_size parameter exceeds maxiumum allowable, + then it should be capped to the maxiumum. + """ + request = factory.get('/', {'page_size': 1000}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, + 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 + ], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=1000', + 'count': 50 + } + def test_setting_page_size_to_zero(self): + """ + When page_size parameter is invalid it should return to the default. + """ + request = factory.get('/', {'page_size': 0}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=0', + 'count': 50 + } -class RootView(generics.ListCreateAPIView): - """ - Example description for OPTIONS. - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 10 + def test_additional_query_params_are_preserved(self): + request = factory.get('/', {'page': 2, 'filter': 'even'}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/?filter=even', + 'next': 'http://testserver/?filter=even&page=3', + 'count': 50 + } + def test_404_not_found_for_zero_page(self): + request = factory.get('/', {'page': '0'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "0": That page number is less than 1.' + } -class DefaultPageSizeKwargView(generics.ListAPIView): - """ - View for testing default paginate_by_param usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer + def test_404_not_found_for_invalid_page(self): + request = factory.get('/', {'page': 'invalid'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "invalid": That page number is not an integer.' + } -class PaginateByParamView(generics.ListAPIView): +class TestPaginationDisabledIntegration: """ - View for testing custom paginate_by_param usage + Integration tests for disabled pagination. """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by_param = 'page_size' + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item -class MaxPaginateByView(generics.ListAPIView): - """ - View for testing custom max_paginate_by usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 3 - max_paginate_by = 5 - paginate_by_param = 'page_size' + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + pagination_class=None + ) + + def test_unpaginated_list(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == list(range(1, 101)) -class IntegrationTestPagination(TestCase): +class TestDeprecatedStylePagination: """ - Integration tests for paginated list views. + Integration tests for deprecated style of setting pagination + attributes on the view. """ - def setUp(self): - """ - Create 26 BasicModel instances. - """ - for char in 'abcdefghijklmnopqrstuvwxyz': - BasicModel(text=char * 3).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = RootView.as_view() - - def test_get_paginated_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. - """ - request = factory.get('/') - # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[10:20]) - self.assertNotEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[20:]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - - def setUp(self): - """ - Create 50 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(26): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} - for obj in self.objects.all() - ] - - @unittest.skipUnless(django_filters, 'django-filter not installed') - def test_get_django_filter_paginated_filtered_root_view(self): - """ - GET requests to paginated filtered ListCreateAPIView should return - paginated results. The next and previous links should preserve the - filtered parameters. - """ - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_class = DecimalFilter - filter_backends = (filters.DjangoFilterBackend,) - - view = FilterFieldsRootView.as_view() - - EXPECTED_NUM_QUERIES = 2 - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - def test_get_basic_paginated_filtered_root_view(self): - """ - Same as `test_get_django_filter_paginated_filtered_root_view`, - except using a custom filter backend instead of the django-filter - backend, - """ + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item - class DecimalFilterBackend(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - - class BasicFilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_backends = (DecimalFilterBackend,) - - view = BasicFilterFieldsRootView.as_view() - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): - """ - Unit tests for pagination of primitive objects. - """ + class ExampleView(generics.ListAPIView): + serializer_class = PassThroughSerializer + queryset = range(1, 101) + pagination_class = pagination.PageNumberPagination + paginate_by = 20 + page_query_param = 'page_number' - def setUp(self): - self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] - paginator = Paginator(self.objects, 10) - self.first_page = paginator.page(1) - self.last_page = paginator.page(3) - - def test_native_pagination(self): - serializer = pagination.PaginationSerializer(self.first_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], '?page=2') - self.assertEqual(serializer.data['previous'], None) - self.assertEqual(serializer.data['results'], self.objects[:10]) - - serializer = pagination.PaginationSerializer(self.last_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], None) - self.assertEqual(serializer.data['previous'], '?page=2') - self.assertEqual(serializer.data['results'], self.objects[20:]) - - def test_context_available_in_result(self): - """ - Ensure context gets passed through to the object serializer. - """ - serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) - serializer.data - results = serializer.fields[serializer.results_field] - self.assertEqual(serializer.context, results.context) + self.view = ExampleView.as_view() + + def test_paginate_by_attribute_on_view(self): + request = factory.get('/?page_number=2') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 + ], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page_number=3', + 'count': 100 + } -class TestUnpaginated(TestCase): +class TestPageNumberPagination: """ - Tests for list views without pagination. + Unit tests for `pagination.PageNumberPagination`. """ - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = DefaultPageSizeKwargView.as_view() - - def test_unpaginated(self): - """ - Tests the default page size for this view. - no page size --> no limit --> no meta data - """ - request = factory.get('/') - response = self.view(request) - self.assertEqual(response.data, self.data) + def setup(self): + class ExamplePagination(pagination.PageNumberPagination): + paginate_by = 5 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_page_number(self): + request = Request(factory.get('/')) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?page=2', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?page=2', + 'page_links': [ + PageLink('http://testserver/', 1, True, False), + PageLink('http://testserver/?page=2', 2, False, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) + + def test_second_page(self): + request = Request(factory.get('/', {'page': 2})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/', + 'next_url': 'http://testserver/?page=3', + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PageLink('http://testserver/?page=2', 2, True, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + + def test_last_page(self): + request = Request(factory.get('/', {'page': 'last'})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?page=19', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?page=19', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=18', 18, False, False), + PageLink('http://testserver/?page=19', 19, False, False), + PageLink('http://testserver/?page=20', 20, True, False), + ] + } + + def test_invalid_page(self): + request = Request(factory.get('/', {'page': 'invalid'})) + with pytest.raises(exceptions.NotFound): + self.paginate_queryset(request) -class TestCustomPaginateByParam(TestCase): +class TestLimitOffset: """ - Tests for list views with default page size kwarg + Unit tests for `pagination.LimitOffsetPagination`. """ - def setUp(self): + def setup(self): + class ExamplePagination(pagination.LimitOffsetPagination): + default_limit = 10 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_offset(self): + request = Request(factory.get('/', {'limit': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?limit=5&offset=5', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?limit=5&offset=5', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, True, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) + + def test_single_offset(self): """ - Create 13 BasicModel instances. + When the offset is not a multiple of the limit we get some edge cases: + * The first page should still be offset zero. + * We may end up displaying an extra page in the pagination control. """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = PaginateByParamView.as_view() - - def test_default_page_size(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 1})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [2, 3, 4, 5, 6] + assert content == { + 'results': [2, 3, 4, 5, 6], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=6', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=6', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=1', 2, True, False), + PageLink('http://testserver/?limit=5&offset=6', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=96', 21, False, False), + ] + } + + def test_first_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=10', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=10', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, True, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_middle_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 10})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [11, 12, 13, 14, 15] + assert content == { + 'results': [11, 12, 13, 14, 15], + 'previous': 'http://testserver/?limit=5&offset=5', + 'next': 'http://testserver/?limit=5&offset=15', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=5', + 'next_url': 'http://testserver/?limit=5&offset=15', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, True, False), + PageLink('http://testserver/?limit=5&offset=15', 4, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_ending_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 95})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?limit=5&offset=90', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=90', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=85', 18, False, False), + PageLink('http://testserver/?limit=5&offset=90', 19, False, False), + PageLink('http://testserver/?limit=5&offset=95', 20, True, False), + ] + } + + def test_invalid_offset(self): """ - Tests the default page size for this view. - no page size --> no limit --> no meta data + An invalid offset query param should be treated as 0. """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data, self.data) + request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5] - def test_paginate_by_param(self): + def test_invalid_limit(self): """ - If paginate_by_param is set, the new kwarg should limit per view requests. + An invalid limit query param should be ignored in favor of the default. """ - request = factory.get('/', {'page_size': 5}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -class TestMaxPaginateByParam(TestCase): +class TestCursorPagination: """ - Tests for list views with max_paginate_by kwarg + Unit tests for `pagination.CursorPagination`. """ - def setUp(self): + def setup(self): + class MockObject(object): + def __init__(self, idx): + self.created = idx + + class MockQuerySet(object): + def __init__(self, items): + self.items = items + + def filter(self, created__gt=None, created__lt=None): + if created__gt is not None: + return MockQuerySet([ + item for item in self.items + if item.created > int(created__gt) + ]) + + assert created__lt is not None + return MockQuerySet([ + item for item in self.items + if item.created < int(created__lt) + ]) + + def order_by(self, *ordering): + if ordering[0].startswith('-'): + return MockQuerySet(list(reversed(self.items))) + return self + + def __getitem__(self, sliced): + return self.items[sliced] + + class ExamplePagination(pagination.CursorPagination): + page_size = 5 + ordering = 'created' + + self.pagination = ExamplePagination() + self.queryset = MockQuerySet([ + MockObject(idx) for idx in [ + 1, 1, 1, 1, 1, + 1, 2, 3, 4, 4, + 4, 4, 5, 6, 7, + 7, 7, 7, 7, 7, + 7, 7, 7, 8, 9, + 9, 9, 9, 9, 9 + ] + ]) + + def get_pages(self, url): """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = MaxPaginateByView.as_view() - - def test_max_paginate_by(self): - """ - If max_paginate_by is set, it should limit page size for the view. - """ - request = factory.get('/', data={'page_size': 10}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + Given a URL return a tuple of: - def test_max_paginate_by_without_page_size_param(self): + (previous page, current page, next page, previous url, next url) """ - If max_paginate_by is set, but client does not specifiy page_size, - standard `paginate_by` behavior should be used. - """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers + request = Request(factory.get(url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + current = [item.created for item in queryset] -class CustomField(serializers.ReadOnlyField): - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into custom field") - return "value" + next_url = self.pagination.get_next_link() + previous_url = self.pagination.get_previous_link() + if next_url is not None: + request = Request(factory.get(next_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + next = [item.created for item in queryset] + else: + next = None -class BasicModelSerializer(serializers.Serializer): - text = CustomField() - - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into serializer") - return super(BasicSerializer, self).to_native(value) + if previous_url is not None: + request = Request(factory.get(previous_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + previous = [item.created for item in queryset] + else: + previous = None + return (previous, current, next, previous_url, next_url) -class TestContextPassedToCustomField(TestCase): - def setUp(self): - BasicModel.objects.create(text='ala ma kota') + def test_invalid_cursor(self): + request = Request(factory.get('/', {'cursor': '123'})) + with pytest.raises(exceptions.NotFound): + self.pagination.paginate_queryset(self.queryset, request) - def test_with_pagination(self): - class ListView(generics.ListCreateAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - paginate_by = 1 + def test_use_with_ordering_filter(self): + class MockView: + filter_backends = (filters.OrderingFilter,) + ordering_fields = ['username', 'created'] + ordering = 'created' - self.view = ListView.as_view() - request = factory.get('/') - response = self.view(request).render() + request = Request(factory.get('/', {'ordering': 'username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('username',) - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = Request(factory.get('/', {'ordering': '-username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('-username',) + request = Request(factory.get('/', {'ordering': 'invalid'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('created',) -# Tests for custom pagination serializers + def test_cursor_pagination(self): + (previous, current, next, previous_url, next_url) = self.get_pages('/') -class LinksSerializer(serializers.Serializer): - next = pagination.NextPageField(source='*') - prev = pagination.PreviousPageField(source='*') + assert previous is None + assert current == [1, 1, 1, 1, 1] + assert next == [1, 2, 3, 4, 4] + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) -class CustomPaginationSerializer(pagination.BasePaginationSerializer): - links = LinksSerializer(source='*') # Takes the page object as the source - total_results = serializers.ReadOnlyField(source='paginator.count') + assert previous == [1, 1, 1, 1, 1] + assert current == [1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7] - results_field = 'objects' + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + assert previous == [1, 2, 3, 4, 4] + assert current == [4, 4, 5, 6, 7] + assert next == [7, 7, 7, 7, 7] -class CustomFooSerializer(serializers.Serializer): - foo = serializers.CharField() + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + assert previous == [4, 4, 4, 5, 6] # Paging artifact + assert current == [7, 7, 7, 7, 7] + assert next == [7, 7, 7, 8, 9] -class CustomFooPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = CustomFooSerializer + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + assert previous == [7, 7, 7, 7, 7] + assert current == [7, 7, 7, 8, 9] + assert next == [9, 9, 9, 9, 9] -class TestCustomPaginationSerializer(TestCase): - def setUp(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = Paginator(objects, 2) - self.page = paginator.page(1) + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - def test_custom_pagination_serializer(self): - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=self.page, - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page=2', - 'prev': None - }, - 'total_results': 4, - 'objects': ['john', 'paul'] - } - self.assertEqual(serializer.data, expected) + assert previous == [7, 7, 7, 8, 9] + assert current == [9, 9, 9, 9, 9] + assert next is None - def test_custom_pagination_serializer_with_custom_object_serializer(self): - objects = [ - {'foo': 'bar'}, - {'foo': 'spam'} - ] - paginator = Paginator(objects, 1) - page = paginator.page(1) - serializer = CustomFooPaginationSerializer(page) - serializer.data + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + assert previous == [7, 7, 7, 7, 7] + assert current == [7, 7, 7, 8, 9] + assert next == [9, 9, 9, 9, 9] -class NonIntegerPage(object): + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - def __init__(self, paginator, object_list, prev_token, token, next_token): - self.paginator = paginator - self.object_list = object_list - self.prev_token = prev_token - self.token = token - self.next_token = next_token + assert previous == [4, 4, 5, 6, 7] + assert current == [7, 7, 7, 7, 7] + assert next == [8, 9, 9, 9, 9] # Paging artifact - def has_next(self): - return not not self.next_token + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - def next_page_number(self): - return self.next_token + assert previous == [1, 2, 3, 4, 4] + assert current == [4, 4, 5, 6, 7] + assert next == [7, 7, 7, 7, 7] - def has_previous(self): - return not not self.prev_token + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - def previous_page_number(self): - return self.prev_token + assert previous == [1, 1, 1, 1, 1] + assert current == [1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7] + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -class NonIntegerPaginator(object): + assert previous is None + assert current == [1, 1, 1, 1, 1] + assert next == [1, 2, 3, 4, 4] - def __init__(self, object_list, per_page): - self.object_list = object_list - self.per_page = per_page + assert isinstance(self.pagination.to_html(), type('')) - def count(self): - # pretend like we don't know how many pages we have - return None - def page(self, token=None): - if token: - try: - first = self.object_list.index(token) - except ValueError: - first = 0 - else: - first = 0 - n = len(self.object_list) - last = min(first + self.per_page, n) - prev_token = self.object_list[last - (2 * self.per_page)] if first else None - next_token = self.object_list[last] if last < n else None - return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): - def test_custom_pagination_serializer(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = NonIntegerPaginator(objects, 2) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page(), - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), - 'prev': None - }, - 'total_results': None, - 'objects': objects[:2] - } - self.assertEqual(serializer.data, expected) +def test_get_displayed_page_numbers(): + """ + Test our contextual page display function. - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page('george'), - context={'request': request} - ) - expected = { - 'links': { - 'next': None, - 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), - }, - 'total_results': None, - 'objects': objects[2:] - } - self.assertEqual(serializer.data, expected) + This determines which pages to display in a pagination control, + given the current page and the last page. + """ + displayed_page_numbers = pagination._get_displayed_page_numbers + + # At five pages or less, all pages are displayed, always. + assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + + # Between six and either pages we may have a single page break. + assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] + assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + + assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] + assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] + assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] + assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] + assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + + assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] + assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] + assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] + assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] + assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] + assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + + # At nine or more pages we may have two page breaks, one on each side. + assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] + assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] + assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] + assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] + assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] + assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] + assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] |
