diff options
Diffstat (limited to 'tests/test_pagination.py')
| -rw-r--r-- | tests/test_pagination.py | 926 | 
1 files changed, 424 insertions, 502 deletions
| diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..7cc92347 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,553 +1,475 @@  from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK  from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest  factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): -    if '?' not in url: -        return url - -    path, args = url.split('?') -    args = dict(r.split('=') for r in args.split('&')) -    return path, args - - -class BasicSerializer(serializers.ModelSerializer): -    class Meta: -        model = BasicModel - - -class FilterableItemSerializer(serializers.ModelSerializer): -    class Meta: -        model = FilterableItem - - -class RootView(generics.ListCreateAPIView): +class TestPaginationIntegration:      """ -    Example description for OPTIONS. +    Integration tests.      """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by = 10 +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -class DefaultPageSizeKwargView(generics.ListAPIView): -    """ -    View for testing default paginate_by_param usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer - - -class PaginateByParamView(generics.ListAPIView): -    """ -    View for testing custom paginate_by_param usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by_param = 'page_size' - - -class MaxPaginateByView(generics.ListAPIView): -    """ -    View for testing custom max_paginate_by usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by = 3 -    max_paginate_by = 5 -    paginate_by_param = 'page_size' - +        class EvenItemsOnly(filters.BaseFilterBackend): +            def filter_queryset(self, request, queryset, view): +                return [item for item in queryset if item % 2 == 0] + +        class BasicPagination(pagination.PageNumberPagination): +            paginate_by = 5 +            paginate_by_param = 'page_size' +            max_paginate_by = 20 + +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            filter_backends=[EvenItemsOnly], +            pagination_class=BasicPagination +        ) -class IntegrationTestPagination(TestCase): -    """ -    Integration tests for paginated list views. -    """ +    def test_filtered_items_are_paginated(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 50 +        } -    def setUp(self): -        """ -        Create 26 BasicModel instances. -        """ -        for char in 'abcdefghijklmnopqrstuvwxyz': -            BasicModel(text=char * 3).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = RootView.as_view() - -    def test_get_paginated_root_view(self): -        """ -        GET requests to paginated ListCreateAPIView should return paginated results. +    def test_setting_page_size(self):          """ -        request = factory.get('/') -        # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[10:20]) -        self.assertNotEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[20:]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - -    def setUp(self): +        When 'paginate_by_param' is set, the client may choose a page size.          """ -        Create 50 FilterableItem instances. -        """ -        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) -        for i in range(26): -            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. -            decimal = base_data[1] + i -            date = base_data[2] - datetime.timedelta(days=i * 2) -            FilterableItem(text=text, decimal=decimal, date=date).save() - -        self.objects = FilterableItem.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} -            for obj in self.objects.all() -        ] - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_django_filter_paginated_filtered_root_view(self): -        """ -        GET requests to paginated filtered ListCreateAPIView should return -        paginated results. The next and previous links should preserve the -        filtered parameters. -        """ -        class DecimalFilter(django_filters.FilterSet): -            decimal = django_filters.NumberFilter(lookup_type='lt') - -            class Meta: -                model = FilterableItem -                fields = ['text', 'decimal', 'date'] - -        class FilterFieldsRootView(generics.ListCreateAPIView): -            queryset = FilterableItem.objects.all() -            serializer_class = FilterableItemSerializer -            paginate_by = 10 -            filter_class = DecimalFilter -            filter_backends = (filters.DjangoFilterBackend,) - -        view = FilterFieldsRootView.as_view() - -        EXPECTED_NUM_QUERIES = 2 - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -    def test_get_basic_paginated_filtered_root_view(self): -        """ -        Same as `test_get_django_filter_paginated_filtered_root_view`, -        except using a custom filter backend instead of the django-filter -        backend, -        """ - -        class DecimalFilterBackend(filters.BaseFilterBackend): -            def filter_queryset(self, request, queryset, view): -                return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - -        class BasicFilterFieldsRootView(generics.ListCreateAPIView): -            queryset = FilterableItem.objects.all() -            serializer_class = FilterableItemSerializer -            paginate_by = 10 -            filter_backends = (DecimalFilterBackend,) - -        view = BasicFilterFieldsRootView.as_view() - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): -    """ -    Unit tests for pagination of primitive objects. -    """ +        request = factory.get('/', {'page_size': 10}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=10', +            'count': 50 +        } -    def setUp(self): -        self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] -        paginator = Paginator(self.objects, 10) -        self.first_page = paginator.page(1) -        self.last_page = paginator.page(3) - -    def test_native_pagination(self): -        serializer = pagination.PaginationSerializer(self.first_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], '?page=2') -        self.assertEqual(serializer.data['previous'], None) -        self.assertEqual(serializer.data['results'], self.objects[:10]) - -        serializer = pagination.PaginationSerializer(self.last_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], None) -        self.assertEqual(serializer.data['previous'], '?page=2') -        self.assertEqual(serializer.data['results'], self.objects[20:]) - -    def test_context_available_in_result(self): +    def test_setting_page_size_over_maximum(self):          """ -        Ensure context gets passed through to the object serializer. +        When page_size parameter exceeds maxiumum allowable, +        then it should be capped to the maxiumum.          """ -        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) -        serializer.data -        results = serializer.fields[serializer.results_field] -        self.assertEqual(serializer.context, results.context) - +        request = factory.get('/', {'page_size': 1000}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                2, 4, 6, 8, 10, 12, 14, 16, 18, 20, +                22, 24, 26, 28, 30, 32, 34, 36, 38, 40 +            ], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=1000', +            'count': 50 +        } -class TestUnpaginated(TestCase): -    """ -    Tests for list views without pagination. -    """ +    def test_additional_query_params_are_preserved(self): +        request = factory.get('/', {'page': 2, 'filter': 'even'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/?filter=even', +            'next': 'http://testserver/?filter=even&page=3', +            'count': 50 +        } -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = DefaultPageSizeKwargView.as_view() - -    def test_unpaginated(self): -        """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data -        """ -        request = factory.get('/') +    def test_404_not_found_for_invalid_page(self): +        request = factory.get('/', {'page': 'invalid'})          response = self.view(request) -        self.assertEqual(response.data, self.data) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "invalid": That page number is not an integer.' +        } -class TestCustomPaginateByParam(TestCase): +class TestPaginationDisabledIntegration:      """ -    Tests for list views with default page size kwarg +    Integration tests for disabled pagination.      """ -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = PaginateByParamView.as_view() - -    def test_default_page_size(self): -        """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data -        """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data, self.data) +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -    def test_paginate_by_param(self): -        """ -        If paginate_by_param is set, the new kwarg should limit per view requests. -        """ -        request = factory.get('/', {'page_size': 5}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            pagination_class=None +        ) +    def test_unpaginated_list(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == list(range(1, 101)) -class TestMaxPaginateByParam(TestCase): + +class TestDeprecatedStylePagination:      """ -    Tests for list views with max_paginate_by kwarg +    Integration tests for deprecated style of setting pagination +    attributes on the view.      """ -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = MaxPaginateByView.as_view() - -    def test_max_paginate_by(self): -        """ -        If max_paginate_by is set, it should limit page size for the view. -        """ -        request = factory.get('/', data={'page_size': 10}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) - -    def test_max_paginate_by_without_page_size_param(self): -        """ -        If max_paginate_by is set, but client does not specifiy page_size, -        standard `paginate_by` behavior should be used. -        """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -class CustomField(serializers.ReadOnlyField): -    def to_native(self, value): -        if 'view' not in self.context: -            raise RuntimeError("context isn't getting passed into custom field") -        return "value" +        class ExampleView(generics.ListAPIView): +            serializer_class = PassThroughSerializer +            queryset = range(1, 101) +            pagination_class = pagination.PageNumberPagination +            paginate_by = 20 +            page_query_param = 'page_number' +        self.view = ExampleView.as_view() -class BasicModelSerializer(serializers.Serializer): -    text = CustomField() - -    def to_native(self, value): -        if 'view' not in self.context: -            raise RuntimeError("context isn't getting passed into serializer") -        return super(BasicSerializer, self).to_native(value) - - -class TestContextPassedToCustomField(TestCase): -    def setUp(self): -        BasicModel.objects.create(text='ala ma kota') +    def test_paginate_by_attribute_on_view(self): +        request = factory.get('/?page_number=2') +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                21, 22, 23, 24, 25, 26, 27, 28, 29, 30, +                31, 32, 33, 34, 35, 36, 37, 38, 39, 40 +            ], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page_number=3', +            'count': 100 +        } -    def test_with_pagination(self): -        class ListView(generics.ListCreateAPIView): -            queryset = BasicModel.objects.all() -            serializer_class = BasicModelSerializer -            paginate_by = 1 -        self.view = ListView.as_view() -        request = factory.get('/') -        response = self.view(request).render() +class TestPageNumberPagination: +    """ +    Unit tests for `pagination.PageNumberPagination`. +    """ -        self.assertEqual(response.status_code, status.HTTP_200_OK) +    def setup(self): +        class ExamplePagination(pagination.PageNumberPagination): +            paginate_by = 5 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_page_number(self): +        request = Request(factory.get('/')) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?page=2', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?page=2', +            'page_links': [ +                PageLink('http://testserver/', 1, True, False), +                PageLink('http://testserver/?page=2', 2, False, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_second_page(self): +        request = Request(factory.get('/', {'page': 2})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/', +            'next_url': 'http://testserver/?page=3', +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PageLink('http://testserver/?page=2', 2, True, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } +    def test_last_page(self): +        request = Request(factory.get('/', {'page': 'last'})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?page=19', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?page=19', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=18', 18, False, False), +                PageLink('http://testserver/?page=19', 19, False, False), +                PageLink('http://testserver/?page=20', 20, True, False), +            ] +        } -# Tests for custom pagination serializers +    def test_invalid_page(self): +        request = Request(factory.get('/', {'page': 'invalid'})) +        with pytest.raises(exceptions.NotFound): +            self.paginate_queryset(request) -class LinksSerializer(serializers.Serializer): -    next = pagination.NextPageField(source='*') -    prev = pagination.PreviousPageField(source='*') +class TestLimitOffset: +    """ +    Unit tests for `pagination.LimitOffsetPagination`. +    """ -class CustomPaginationSerializer(pagination.BasePaginationSerializer): -    links = LinksSerializer(source='*')  # Takes the page object as the source -    total_results = serializers.ReadOnlyField(source='paginator.count') +    def setup(self): +        class ExamplePagination(pagination.LimitOffsetPagination): +            default_limit = 10 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_offset(self): +        request = Request(factory.get('/', {'limit': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?limit=5&offset=5', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?limit=5&offset=5', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, True, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) -    results_field = 'objects' +    def test_single_offset(self): +        """ +        When the offset is not a multiple of the limit we get some edge cases: +        * The first page should still be offset zero. +        * We may end up displaying an extra page in the pagination control. +        """ +        request = Request(factory.get('/', {'limit': 5, 'offset': 1})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [2, 3, 4, 5, 6] +        assert content == { +            'results': [2, 3, 4, 5, 6], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=6', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=6', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=1', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=6', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=96', 21, False, False), +            ] +        } +    def test_first_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=10', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=10', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } -class CustomFooSerializer(serializers.Serializer): -    foo = serializers.CharField() +    def test_middle_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 10})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [11, 12, 13, 14, 15] +        assert content == { +            'results': [11, 12, 13, 14, 15], +            'previous': 'http://testserver/?limit=5&offset=5', +            'next': 'http://testserver/?limit=5&offset=15', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=5', +            'next_url': 'http://testserver/?limit=5&offset=15', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, True, False), +                PageLink('http://testserver/?limit=5&offset=15', 4, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } +    def test_ending_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 95})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?limit=5&offset=90', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=90', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=85', 18, False, False), +                PageLink('http://testserver/?limit=5&offset=90', 19, False, False), +                PageLink('http://testserver/?limit=5&offset=95', 20, True, False), +            ] +        } -class CustomFooPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = CustomFooSerializer +    def test_invalid_offset(self): +        """ +        An invalid offset query param should be treated as 0. +        """ +        request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5] +    def test_invalid_limit(self): +        """ +        An invalid limit query param should be ignored in favor of the default. +        """ +        request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -class TestCustomPaginationSerializer(TestCase): -    def setUp(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = Paginator(objects, 2) -        self.page = paginator.page(1) -    def test_custom_pagination_serializer(self): -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=self.page, -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page=2', -                'prev': None -            }, -            'total_results': 4, -            'objects': ['john', 'paul'] -        } -        self.assertEqual(serializer.data, expected) - -    def test_custom_pagination_serializer_with_custom_object_serializer(self): -        objects = [ -            {'foo': 'bar'}, -            {'foo': 'spam'} -        ] -        paginator = Paginator(objects, 1) -        page = paginator.page(1) -        serializer = CustomFooPaginationSerializer(page) -        serializer.data - - -class NonIntegerPage(object): - -    def __init__(self, paginator, object_list, prev_token, token, next_token): -        self.paginator = paginator -        self.object_list = object_list -        self.prev_token = prev_token -        self.token = token -        self.next_token = next_token - -    def has_next(self): -        return not not self.next_token - -    def next_page_number(self): -        return self.next_token - -    def has_previous(self): -        return not not self.prev_token - -    def previous_page_number(self): -        return self.prev_token - - -class NonIntegerPaginator(object): - -    def __init__(self, object_list, per_page): -        self.object_list = object_list -        self.per_page = per_page - -    def count(self): -        # pretend like we don't know how many pages we have -        return None - -    def page(self, token=None): -        if token: -            try: -                first = self.object_list.index(token) -            except ValueError: -                first = 0 -        else: -            first = 0 -        n = len(self.object_list) -        last = min(first + self.per_page, n) -        prev_token = self.object_list[last - (2 * self.per_page)] if first else None -        next_token = self.object_list[last] if last < n else None -        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): -    def test_custom_pagination_serializer(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = NonIntegerPaginator(objects, 2) - -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page(), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), -                'prev': None -            }, -            'total_results': None, -            'objects': objects[:2] -        } -        self.assertEqual(serializer.data, expected) +def test_get_displayed_page_numbers(): +    """ +    Test our contextual page display function. -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page('george'), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': None, -                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), -            }, -            'total_results': None, -            'objects': objects[2:] -        } -        self.assertEqual(serializer.data, expected) +    This determines which pages to display in a pagination control, +    given the current page and the last page. +    """ +    displayed_page_numbers = pagination._get_displayed_page_numbers + +    # At five pages or less, all pages are displayed, always. +    assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + +    # Between six and either pages we may have a single page break. +    assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] +    assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + +    assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] +    assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] +    assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] +    assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] +    assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + +    assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] +    assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] +    assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] +    assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] +    assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] +    assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + +    # At nine or more pages we may have two page breaks, one on each side. +    assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] +    assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] +    assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] +    assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] +    assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] +    assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] +    assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] | 
