diff options
| author | Xavier Ordoquy | 2014-04-17 09:53:44 +0200 | 
|---|---|---|
| committer | Xavier Ordoquy | 2014-04-17 09:53:44 +0200 | 
| commit | 1797a74e828c9fdfbfb46bb4de049100b18db875 (patch) | |
| tree | a91e8f4ff22e8eb63a990a439995632e9198d1c3 /tests/test_pagination.py | |
| parent | 1d404874b3f0f5b16f3b38ba322f31a18c41aad6 (diff) | |
| parent | 971578ca345c3d3bae7fd93b87c41d43483b6f05 (diff) | |
| download | django-rest-framework-1797a74e828c9fdfbfb46bb4de049100b18db875.tar.bz2 | |
Merge remote-tracking branch 'pelme/pytest' into feature/pytest
Conflicts:
	.travis.yml
	rest_framework/runtests/runtests.py
	tests/test_filters.py
	tests/test_pagination.py
	tox.ini
Diffstat (limited to 'tests/test_pagination.py')
| -rw-r--r-- | tests/test_pagination.py | 520 | 
1 files changed, 520 insertions, 0 deletions
| diff --git a/tests/test_pagination.py b/tests/test_pagination.py new file mode 100644 index 00000000..293146c0 --- /dev/null +++ b/tests/test_pagination.py @@ -0,0 +1,520 @@ +from __future__ import unicode_literals +import datetime +from decimal import Decimal +from django.db import models +from django.core.paginator import Paginator +from django.test import TestCase +from django.utils import unittest +from rest_framework import generics, status, pagination, filters, serializers +from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory +from .models import BasicModel, FilterableItem + +factory = APIRequestFactory() + +# Helper function to split arguments out of an url +def split_arguments_from_url(url): +    if '?' not in url: +        return url + +    path, args = url.split('?') +    args = dict(r.split('=') for r in args.split('&')) +    return path, args + + +class RootView(generics.ListCreateAPIView): +    """ +    Example description for OPTIONS. +    """ +    model = BasicModel +    paginate_by = 10 + + +class DefaultPageSizeKwargView(generics.ListAPIView): +    """ +    View for testing default paginate_by_param usage +    """ +    model = BasicModel + + +class PaginateByParamView(generics.ListAPIView): +    """ +    View for testing custom paginate_by_param usage +    """ +    model = BasicModel +    paginate_by_param = 'page_size' + + +class MaxPaginateByView(generics.ListAPIView): +    """ +    View for testing custom max_paginate_by usage +    """ +    model = BasicModel +    paginate_by = 3 +    max_paginate_by = 5 +    paginate_by_param = 'page_size' + + +class IntegrationTestPagination(TestCase): +    """ +    Integration tests for paginated list views. +    """ + +    def setUp(self): +        """ +        Create 26 BasicModel instances. +        """ +        for char in 'abcdefghijklmnopqrstuvwxyz': +            BasicModel(text=char * 3).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = RootView.as_view() + +    def test_get_paginated_root_view(self): +        """ +        GET requests to paginated ListCreateAPIView should return paginated results. +        """ +        request = factory.get('/') +        # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` +        with self.assertNumQueries(2): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 26) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(2): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 26) +        self.assertEqual(response.data['results'], self.data[10:20]) +        self.assertNotEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(2): +            response = self.view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 26) +        self.assertEqual(response.data['results'], self.data[20:]) +        self.assertEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + + +class IntegrationTestPaginationAndFiltering(TestCase): + +    def setUp(self): +        """ +        Create 50 FilterableItem instances. +        """ +        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) +        for i in range(26): +            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. +            decimal = base_data[1] + i +            date = base_data[2] - datetime.timedelta(days=i * 2) +            FilterableItem(text=text, decimal=decimal, date=date).save() + +        self.objects = FilterableItem.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} +            for obj in self.objects.all() +        ] + +    @unittest.skipUnless(django_filters, 'django-filter not installed') +    def test_get_django_filter_paginated_filtered_root_view(self): +        """ +        GET requests to paginated filtered ListCreateAPIView should return +        paginated results. The next and previous links should preserve the +        filtered parameters. +        """ +        class DecimalFilter(django_filters.FilterSet): +            decimal = django_filters.NumberFilter(lookup_type='lt') + +            class Meta: +                model = FilterableItem +                fields = ['text', 'decimal', 'date'] + +        class FilterFieldsRootView(generics.ListCreateAPIView): +            model = FilterableItem +            paginate_by = 10 +            filter_class = DecimalFilter +            filter_backends = (filters.DjangoFilterBackend,) + +        view = FilterFieldsRootView.as_view() + +        EXPECTED_NUM_QUERIES = 2 + +        request = factory.get('/', {'decimal': '15.20'}) +        with self.assertNumQueries(EXPECTED_NUM_QUERIES): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(EXPECTED_NUM_QUERIES): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[10:15]) +        self.assertEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['previous'])) +        with self.assertNumQueries(EXPECTED_NUM_QUERIES): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +    def test_get_basic_paginated_filtered_root_view(self): +        """ +        Same as `test_get_django_filter_paginated_filtered_root_view`, +        except using a custom filter backend instead of the django-filter +        backend, +        """ + +        class DecimalFilterBackend(filters.BaseFilterBackend): +            def filter_queryset(self, request, queryset, view): +                return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) + +        class BasicFilterFieldsRootView(generics.ListCreateAPIView): +            model = FilterableItem +            paginate_by = 10 +            filter_backends = (DecimalFilterBackend,) + +        view = BasicFilterFieldsRootView.as_view() + +        request = factory.get('/', {'decimal': '15.20'}) +        with self.assertNumQueries(2): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['next'])) +        with self.assertNumQueries(2): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[10:15]) +        self.assertEqual(response.data['next'], None) +        self.assertNotEqual(response.data['previous'], None) + +        request = factory.get(*split_arguments_from_url(response.data['previous'])) +        with self.assertNumQueries(2): +            response = view(request).render() +        self.assertEqual(response.status_code, status.HTTP_200_OK) +        self.assertEqual(response.data['count'], 15) +        self.assertEqual(response.data['results'], self.data[:10]) +        self.assertNotEqual(response.data['next'], None) +        self.assertEqual(response.data['previous'], None) + + +class PassOnContextPaginationSerializer(pagination.PaginationSerializer): +    class Meta: +        object_serializer_class = serializers.Serializer + + +class UnitTestPagination(TestCase): +    """ +    Unit tests for pagination of primitive objects. +    """ + +    def setUp(self): +        self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] +        paginator = Paginator(self.objects, 10) +        self.first_page = paginator.page(1) +        self.last_page = paginator.page(3) + +    def test_native_pagination(self): +        serializer = pagination.PaginationSerializer(self.first_page) +        self.assertEqual(serializer.data['count'], 26) +        self.assertEqual(serializer.data['next'], '?page=2') +        self.assertEqual(serializer.data['previous'], None) +        self.assertEqual(serializer.data['results'], self.objects[:10]) + +        serializer = pagination.PaginationSerializer(self.last_page) +        self.assertEqual(serializer.data['count'], 26) +        self.assertEqual(serializer.data['next'], None) +        self.assertEqual(serializer.data['previous'], '?page=2') +        self.assertEqual(serializer.data['results'], self.objects[20:]) + +    def test_context_available_in_result(self): +        """ +        Ensure context gets passed through to the object serializer. +        """ +        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) +        serializer.data +        results = serializer.fields[serializer.results_field] +        self.assertEqual(serializer.context, results.context) + + +class TestUnpaginated(TestCase): +    """ +    Tests for list views without pagination. +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +        {'id': obj.id, 'text': obj.text} +        for obj in self.objects.all() +        ] +        self.view = DefaultPageSizeKwargView.as_view() + +    def test_unpaginated(self): +        """ +        Tests the default page size for this view. +        no page size --> no limit --> no meta data +        """ +        request = factory.get('/') +        response = self.view(request) +        self.assertEqual(response.data, self.data) + + +class TestCustomPaginateByParam(TestCase): +    """ +    Tests for list views with default page size kwarg +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +        {'id': obj.id, 'text': obj.text} +        for obj in self.objects.all() +        ] +        self.view = PaginateByParamView.as_view() + +    def test_default_page_size(self): +        """ +        Tests the default page size for this view. +        no page size --> no limit --> no meta data +        """ +        request = factory.get('/') +        response = self.view(request).render() +        self.assertEqual(response.data, self.data) + +    def test_paginate_by_param(self): +        """ +        If paginate_by_param is set, the new kwarg should limit per view requests. +        """ +        request = factory.get('/', {'page_size': 5}) +        response = self.view(request).render() +        self.assertEqual(response.data['count'], 13) +        self.assertEqual(response.data['results'], self.data[:5]) + + +class TestMaxPaginateByParam(TestCase): +    """ +    Tests for list views with max_paginate_by kwarg +    """ + +    def setUp(self): +        """ +        Create 13 BasicModel instances. +        """ +        for i in range(13): +            BasicModel(text=i).save() +        self.objects = BasicModel.objects +        self.data = [ +            {'id': obj.id, 'text': obj.text} +            for obj in self.objects.all() +        ] +        self.view = MaxPaginateByView.as_view() + +    def test_max_paginate_by(self): +        """ +        If max_paginate_by is set, it should limit page size for the view. +        """ +        request = factory.get('/', data={'page_size': 10}) +        response = self.view(request).render() +        self.assertEqual(response.data['count'], 13) +        self.assertEqual(response.data['results'], self.data[:5]) + +    def test_max_paginate_by_without_page_size_param(self): +        """ +        If max_paginate_by is set, but client does not specifiy page_size, +        standard `paginate_by` behavior should be used. +        """ +        request = factory.get('/') +        response = self.view(request).render() +        self.assertEqual(response.data['results'], self.data[:3]) + + +### Tests for context in pagination serializers + +class CustomField(serializers.Field): +    def to_native(self, value): +        if not 'view' in self.context: +            raise RuntimeError("context isn't getting passed into custom field") +        return "value" + + +class BasicModelSerializer(serializers.Serializer): +    text = CustomField() + +    def __init__(self, *args, **kwargs): +        super(BasicModelSerializer, self).__init__(*args, **kwargs) +        if not 'view' in self.context: +            raise RuntimeError("context isn't getting passed into serializer init") + + +class TestContextPassedToCustomField(TestCase): +    def setUp(self): +        BasicModel.objects.create(text='ala ma kota') + +    def test_with_pagination(self): +        class ListView(generics.ListCreateAPIView): +            model = BasicModel +            serializer_class = BasicModelSerializer +            paginate_by = 1 + +        self.view = ListView.as_view() +        request = factory.get('/') +        response = self.view(request).render() + +        self.assertEqual(response.status_code, status.HTTP_200_OK) + + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): +    next = pagination.NextPageField(source='*') +    prev = pagination.PreviousPageField(source='*') + + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): +    links = LinksSerializer(source='*')  # Takes the page object as the source +    total_results = serializers.Field(source='paginator.count') + +    results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): +    def setUp(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = Paginator(objects, 2) +        self.page = paginator.page(1) + +    def test_custom_pagination_serializer(self): +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=self.page, +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page=2', +                'prev': None +            }, +            'total_results': 4, +            'objects': ['john', 'paul'] +        } +        self.assertEqual(serializer.data, expected) + + +class NonIntegerPage(object): + +    def __init__(self, paginator, object_list, prev_token, token, next_token): +        self.paginator = paginator +        self.object_list = object_list +        self.prev_token = prev_token +        self.token = token +        self.next_token = next_token + +    def has_next(self): +        return not not self.next_token + +    def next_page_number(self): +        return self.next_token + +    def has_previous(self): +        return not not self.prev_token + +    def previous_page_number(self): +        return self.prev_token + + +class NonIntegerPaginator(object): + +    def __init__(self, object_list, per_page): +        self.object_list = object_list +        self.per_page = per_page + +    def count(self): +        # pretend like we don't know how many pages we have +        return None + +    def page(self, token=None): +        if token: +            try: +                first = self.object_list.index(token) +            except ValueError: +                first = 0 +        else: +            first = 0 +        n = len(self.object_list) +        last = min(first + self.per_page, n) +        prev_token = self.object_list[last - (2 * self.per_page)] if first else None +        next_token = self.object_list[last] if last < n else None +        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) + + +class TestNonIntegerPagination(TestCase): + + +    def test_custom_pagination_serializer(self): +        objects = ['john', 'paul', 'george', 'ringo'] +        paginator = NonIntegerPaginator(objects, 2) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page(), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), +                'prev': None +            }, +            'total_results': None, +            'objects': objects[:2] +        } +        self.assertEqual(serializer.data, expected) + +        request = APIRequestFactory().get('/foobar') +        serializer = CustomPaginationSerializer( +            instance=paginator.page('george'), +            context={'request': request} +        ) +        expected = { +            'links': { +                'next': None, +                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), +            }, +            'total_results': None, +            'objects': objects[2:] +        } +        self.assertEqual(serializer.data, expected) | 
