diff options
Diffstat (limited to 'rest_framework/tests/pagination.py')
| -rw-r--r-- | rest_framework/tests/pagination.py | 246 |
1 files changed, 177 insertions, 69 deletions
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 3b550877..1a2d68a6 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -1,5 +1,7 @@ +from __future__ import unicode_literals import datetime from decimal import Decimal +import django from django.core.paginator import Paginator from django.test import TestCase from django.test.client import RequestFactory @@ -19,21 +21,6 @@ class RootView(generics.ListCreateAPIView): paginate_by = 10 -if django_filters: - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - model = FilterableItem - paginate_by = 10 - filter_class = DecimalFilter - filter_backend = filters.DjangoFilterBackend - - class DefaultPageSizeKwargView(generics.ListAPIView): """ View for testing default paginate_by_param usage @@ -72,28 +59,32 @@ class IntegrationTestPagination(TestCase): GET requests to paginated ListCreateAPIView should return paginated results. """ request = factory.get('/') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[10:20]) - self.assertNotEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[10:20]) + self.assertNotEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 26) - self.assertEquals(response.data['results'], self.data[20:]) - self.assertEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(2): + response = self.view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 26) + self.assertEqual(response.data['results'], self.data[20:]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) class IntegrationTestPaginationAndFiltering(TestCase): @@ -111,41 +102,115 @@ class IntegrationTestPaginationAndFiltering(TestCase): self.objects = FilterableItem.objects self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} - for obj in self.objects.all() + {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date.isoformat()} + for obj in self.objects.all() ] - self.view = FilterFieldsRootView.as_view() @unittest.skipUnless(django_filters, 'django-filters not installed') - def test_get_paginated_filtered_root_view(self): + def test_get_django_filter_paginated_filtered_root_view(self): """ GET requests to paginated filtered ListCreateAPIView should return paginated results. The next and previous links should preserve the filtered parameters. """ + class DecimalFilter(django_filters.FilterSet): + decimal = django_filters.NumberFilter(lookup_type='lt') + + class Meta: + model = FilterableItem + fields = ['text', 'decimal', 'date'] + + class FilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_class = DecimalFilter + filter_backend = filters.DjangoFilterBackend + + view = FilterFieldsRootView.as_view() + + EXPECTED_NUM_QUERIES = 2 + if django.VERSION < (1, 4): + # On Django 1.3 we need to use django-filter 0.5.4 + # + # The filter objects there don't expose a `.count()` method, + # which means we only make a single query *but* it's a single + # query across *all* of the queryset, instead of a COUNT and then + # a SELECT with a LIMIT. + # + # Although this is fewer queries, it's actually a regression. + EXPECTED_NUM_QUERIES = 1 + request = factory.get('/?decimal=15.20') - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) request = factory.get(response.data['next']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[10:15]) - self.assertEquals(response.data['next'], None) - self.assertNotEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) request = factory.get(response.data['previous']) - response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) - self.assertEquals(response.data['count'], 15) - self.assertEquals(response.data['results'], self.data[:10]) - self.assertNotEquals(response.data['next'], None) - self.assertEquals(response.data['previous'], None) + with self.assertNumQueries(EXPECTED_NUM_QUERIES): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + def test_get_basic_paginated_filtered_root_view(self): + """ + Same as `test_get_django_filter_paginated_filtered_root_view`, + except using a custom filter backend instead of the django-filter + backend, + """ + + class DecimalFilterBackend(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) + + class BasicFilterFieldsRootView(generics.ListCreateAPIView): + model = FilterableItem + paginate_by = 10 + filter_backend = DecimalFilterBackend + + view = BasicFilterFieldsRootView.as_view() + + request = factory.get('/?decimal=15.20') + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) + + request = factory.get(response.data['next']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[10:15]) + self.assertEqual(response.data['next'], None) + self.assertNotEqual(response.data['previous'], None) + + request = factory.get(response.data['previous']) + with self.assertNumQueries(2): + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['count'], 15) + self.assertEqual(response.data['results'], self.data[:10]) + self.assertNotEqual(response.data['next'], None) + self.assertEqual(response.data['previous'], None) class PassOnContextPaginationSerializer(pagination.PaginationSerializer): @@ -166,16 +231,16 @@ class UnitTestPagination(TestCase): def test_native_pagination(self): serializer = pagination.PaginationSerializer(self.first_page) - self.assertEquals(serializer.data['count'], 26) - self.assertEquals(serializer.data['next'], '?page=2') - self.assertEquals(serializer.data['previous'], None) - self.assertEquals(serializer.data['results'], self.objects[:10]) + self.assertEqual(serializer.data['count'], 26) + self.assertEqual(serializer.data['next'], '?page=2') + self.assertEqual(serializer.data['previous'], None) + self.assertEqual(serializer.data['results'], self.objects[:10]) serializer = pagination.PaginationSerializer(self.last_page) - self.assertEquals(serializer.data['count'], 26) - self.assertEquals(serializer.data['next'], None) - self.assertEquals(serializer.data['previous'], '?page=2') - self.assertEquals(serializer.data['results'], self.objects[20:]) + self.assertEqual(serializer.data['count'], 26) + self.assertEqual(serializer.data['next'], None) + self.assertEqual(serializer.data['previous'], '?page=2') + self.assertEqual(serializer.data['results'], self.objects[20:]) def test_context_available_in_result(self): """ @@ -184,7 +249,7 @@ class UnitTestPagination(TestCase): serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer.data results = serializer.fields[serializer.results_field] - self.assertEquals(serializer.context, results.context) + self.assertEqual(serializer.context, results.context) class TestUnpaginated(TestCase): @@ -212,7 +277,7 @@ class TestUnpaginated(TestCase): """ request = factory.get('/') response = self.view(request) - self.assertEquals(response.data, self.data) + self.assertEqual(response.data, self.data) class TestCustomPaginateByParam(TestCase): @@ -240,7 +305,7 @@ class TestCustomPaginateByParam(TestCase): """ request = factory.get('/') response = self.view(request).render() - self.assertEquals(response.data, self.data) + self.assertEqual(response.data, self.data) def test_paginate_by_param(self): """ @@ -248,9 +313,11 @@ class TestCustomPaginateByParam(TestCase): """ request = factory.get('/?page_size=5') response = self.view(request).render() - self.assertEquals(response.data['count'], 13) - self.assertEquals(response.data['results'], self.data[:5]) + self.assertEqual(response.data['count'], 13) + self.assertEqual(response.data['results'], self.data[:5]) + +### Tests for context in pagination serializers class CustomField(serializers.Field): def to_native(self, value): @@ -262,6 +329,11 @@ class CustomField(serializers.Field): class BasicModelSerializer(serializers.Serializer): text = CustomField() + def __init__(self, *args, **kwargs): + super(BasicModelSerializer, self).__init__(*args, **kwargs) + if not 'view' in self.context: + raise RuntimeError("context isn't getting passed into serializer init") + class TestContextPassedToCustomField(TestCase): def setUp(self): @@ -277,5 +349,41 @@ class TestContextPassedToCustomField(TestCase): request = factory.get('/') response = self.view(request).render() - self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +### Tests for custom pagination serializers + +class LinksSerializer(serializers.Serializer): + next = pagination.NextPageField(source='*') + prev = pagination.PreviousPageField(source='*') + +class CustomPaginationSerializer(pagination.BasePaginationSerializer): + links = LinksSerializer(source='*') # Takes the page object as the source + total_results = serializers.Field(source='paginator.count') + + results_field = 'objects' + + +class TestCustomPaginationSerializer(TestCase): + def setUp(self): + objects = ['john', 'paul', 'george', 'ringo'] + paginator = Paginator(objects, 2) + self.page = paginator.page(1) + + def test_custom_pagination_serializer(self): + request = RequestFactory().get('/foobar') + serializer = CustomPaginationSerializer( + instance=self.page, + context={'request': request} + ) + expected = { + 'links': { + 'next': 'http://testserver/foobar?page=2', + 'prev': None + }, + 'total_results': 4, + 'objects': ['john', 'paul'] + } + self.assertEqual(serializer.data, expected) |
