diff options
| author | Tom Christie | 2015-02-06 14:35:06 +0000 |
|---|---|---|
| committer | Tom Christie | 2015-02-06 14:35:06 +0000 |
| commit | 3dff9a4fe2952cf632ca7f4cd9ecf4221059ca91 (patch) | |
| tree | 0649d42b20b875e97cb551b987644b61e7860e84 /tests | |
| parent | c06a82d0531f4cb290baacee196829c770913eaa (diff) | |
| parent | 1f996128458570a909d13f15c3d739fb12111984 (diff) | |
| download | django-rest-framework-model-serializer-caching.tar.bz2 | |
Resolve merge conflictmodel-serializer-caching
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/browsable_api/auth_urls.py | 1 | ||||
| -rw-r--r-- | tests/test_fields.py | 129 | ||||
| -rw-r--r-- | tests/test_filters.py | 17 | ||||
| -rw-r--r-- | tests/test_generics.py | 8 | ||||
| -rw-r--r-- | tests/test_htmlrenderer.py | 6 | ||||
| -rw-r--r-- | tests/test_metadata.py | 60 | ||||
| -rw-r--r-- | tests/test_model_serializer.py | 54 | ||||
| -rw-r--r-- | tests/test_multitable_inheritance.py | 4 | ||||
| -rw-r--r-- | tests/test_pagination.py | 1048 | ||||
| -rw-r--r-- | tests/test_parsers.py | 4 | ||||
| -rw-r--r-- | tests/test_relations.py | 35 | ||||
| -rw-r--r-- | tests/test_relations_hyperlink.py | 7 | ||||
| -rw-r--r-- | tests/test_renderers.py | 107 | ||||
| -rw-r--r-- | tests/test_routers.py | 151 | ||||
| -rw-r--r-- | tests/test_serializer.py | 79 | ||||
| -rw-r--r-- | tests/test_serializer_bulk_update.py | 4 | ||||
| -rw-r--r-- | tests/test_versioning.py | 62 | ||||
| -rw-r--r-- | tests/utils.py | 35 |
18 files changed, 1115 insertions, 696 deletions
diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py index bce7dcf9..97bc1036 100644 --- a/tests/browsable_api/auth_urls.py +++ b/tests/browsable_api/auth_urls.py @@ -3,6 +3,7 @@ from django.conf.urls import patterns, url, include from .views import MockView + urlpatterns = patterns( '', (r'^$', MockView.as_view()), diff --git a/tests/test_fields.py b/tests/test_fields.py index 04c721d3..48ada780 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -4,6 +4,7 @@ from rest_framework import serializers import datetime import django import pytest +import uuid # Tests for field keyword arguments and core functionality. @@ -223,8 +224,8 @@ class MockHTMLDict(dict): getlist = None -class TestCharHTMLInput: - def test_empty_html_checkbox(self): +class TestHTMLInput: + def test_empty_html_charfield(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(default='happy') @@ -232,23 +233,31 @@ class TestCharHTMLInput: assert serializer.is_valid() assert serializer.validated_data == {'message': 'happy'} - def test_empty_html_checkbox_allow_null(self): + def test_empty_html_charfield_allow_null(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(allow_null=True) - serializer = TestSerializer(data=MockHTMLDict()) + serializer = TestSerializer(data=MockHTMLDict({'message': ''})) assert serializer.is_valid() assert serializer.validated_data == {'message': None} - def test_empty_html_checkbox_allow_null_allow_blank(self): + def test_empty_html_datefield_allow_null(self): + class TestSerializer(serializers.Serializer): + expiry = serializers.DateField(allow_null=True) + + serializer = TestSerializer(data=MockHTMLDict({'expiry': ''})) + assert serializer.is_valid() + assert serializer.validated_data == {'expiry': None} + + def test_empty_html_charfield_allow_null_allow_blank(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(allow_null=True, allow_blank=True) - serializer = TestSerializer(data=MockHTMLDict({})) + serializer = TestSerializer(data=MockHTMLDict({'message': ''})) assert serializer.is_valid() assert serializer.validated_data == {'message': ''} - def test_empty_html_required_false(self): + def test_empty_html_charfield_required_false(self): class TestSerializer(serializers.Serializer): message = serializers.CharField(required=False) @@ -338,7 +347,7 @@ class TestBooleanField(FieldValues): False: False, } invalid_inputs = { - 'foo': ['`foo` is not a valid boolean.'], + 'foo': ['"foo" is not a valid boolean.'], None: ['This field may not be null.'] } outputs = { @@ -368,7 +377,7 @@ class TestNullBooleanField(FieldValues): None: None } invalid_inputs = { - 'foo': ['`foo` is not a valid boolean.'], + 'foo': ['"foo" is not a valid boolean.'], } outputs = { 'true': True, @@ -439,7 +448,7 @@ class TestSlugField(FieldValues): 'slug-99': 'slug-99', } invalid_inputs = { - 'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."] + 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'] } outputs = {} field = serializers.SlugField() @@ -459,6 +468,23 @@ class TestURLField(FieldValues): field = serializers.URLField() +class TestUUIDField(FieldValues): + """ + Valid and invalid values for `UUIDField`. + """ + valid_inputs = { + '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), + '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda') + } + invalid_inputs = { + '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.'] + } + outputs = { + uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda' + } + field = serializers.UUIDField() + + # Number types... class TestIntegerField(FieldValues): @@ -640,8 +666,8 @@ class TestDateField(FieldValues): datetime.date(2001, 1, 1): datetime.date(2001, 1, 1), } invalid_inputs = { - 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], - '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], + 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], + '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'], } outputs = { @@ -658,7 +684,7 @@ class TestCustomInputFormatDateField(FieldValues): '1 Jan 2001': datetime.date(2001, 1, 1), } invalid_inputs = { - '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY'] + '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.'] } outputs = {} field = serializers.DateField(input_formats=['%d %b %Y']) @@ -702,8 +728,8 @@ class TestDateTimeField(FieldValues): '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()) } invalid_inputs = { - 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], - '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], + 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], + '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'], } outputs = { @@ -721,7 +747,7 @@ class TestCustomInputFormatDateTimeField(FieldValues): '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()), } invalid_inputs = { - '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY'] + '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.'] } outputs = {} field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) @@ -773,8 +799,8 @@ class TestTimeField(FieldValues): datetime.time(13, 00): datetime.time(13, 00), } invalid_inputs = { - 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], - '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], + 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], + '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], } outputs = { datetime.time(13, 00): '13:00:00' @@ -790,7 +816,7 @@ class TestCustomInputFormatTimeField(FieldValues): '1:00pm': datetime.time(13, 00), } invalid_inputs = { - '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'], + '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'], } outputs = {} field = serializers.TimeField(input_formats=['%I:%M%p']) @@ -832,7 +858,7 @@ class TestChoiceField(FieldValues): 'good': 'good', } invalid_inputs = { - 'amazing': ['`amazing` is not a valid choice.'] + 'amazing': ['"amazing" is not a valid choice.'] } outputs = { 'good': 'good', @@ -872,8 +898,8 @@ class TestChoiceFieldWithType(FieldValues): 3: 3, } invalid_inputs = { - 5: ['`5` is not a valid choice.'], - 'abc': ['`abc` is not a valid choice.'] + 5: ['"5" is not a valid choice.'], + 'abc': ['"abc" is not a valid choice.'] } outputs = { '1': 1, @@ -899,7 +925,7 @@ class TestChoiceFieldWithListChoices(FieldValues): 'good': 'good', } invalid_inputs = { - 'awful': ['`awful` is not a valid choice.'] + 'awful': ['"awful" is not a valid choice.'] } outputs = { 'good': 'good' @@ -917,8 +943,8 @@ class TestMultipleChoiceField(FieldValues): ('aircon', 'manual'): set(['aircon', 'manual']), } invalid_inputs = { - 'abc': ['Expected a list of items but got type `str`.'], - ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.'] + 'abc': ['Expected a list of items but got type "str".'], + ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.'] } outputs = [ (['aircon', 'manual'], set(['aircon', 'manual'])) @@ -1021,14 +1047,14 @@ class TestValidImageField(FieldValues): class TestListField(FieldValues): """ - Values for `ListField`. + Values for `ListField` with IntegerField as child. """ valid_inputs = [ ([1, 2, 3], [1, 2, 3]), (['1', '2', '3'], [1, 2, 3]) ] invalid_inputs = [ - ('not a list', ['Expected a list of items but got type `str`']), + ('not a list', ['Expected a list of items but got type "str".']), ([1, 2, 'error'], ['A valid integer is required.']) ] outputs = [ @@ -1038,6 +1064,55 @@ class TestListField(FieldValues): field = serializers.ListField(child=serializers.IntegerField()) +class TestUnvalidatedListField(FieldValues): + """ + Values for `ListField` with no `child` argument. + """ + valid_inputs = [ + ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), + ] + invalid_inputs = [ + ('not a list', ['Expected a list of items but got type "str".']), + ] + outputs = [ + ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), + ] + field = serializers.ListField() + + +class TestDictField(FieldValues): + """ + Values for `ListField` with CharField as child. + """ + valid_inputs = [ + ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), + ] + invalid_inputs = [ + ({'a': 1, 'b': None}, ['This field may not be null.']), + ('not a dict', ['Expected a dictionary of items but got type "str".']), + ] + outputs = [ + ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), + ] + field = serializers.DictField(child=serializers.CharField()) + + +class TestUnvalidatedDictField(FieldValues): + """ + Values for `ListField` with no `child` argument. + """ + valid_inputs = [ + ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}), + ] + invalid_inputs = [ + ('not a dict', ['Expected a dictionary of items but got type "str".']), + ] + outputs = [ + ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}), + ] + field = serializers.DictField() + + # Tests for FieldField. # --------------------- diff --git a/tests/test_filters.py b/tests/test_filters.py index dc84dcbd..355f02ce 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,13 +5,15 @@ from django.db import models from django.conf.urls import patterns, url from django.core.urlresolvers import reverse from django.test import TestCase +from django.test.utils import override_settings from django.utils import unittest from django.utils.dateparse import parse_date +from django.utils.six.moves import reload_module from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory from .models import BaseFilterableItem, FilterableItem, BasicModel -from .utils import temporary_setting + factory = APIRequestFactory() @@ -404,7 +406,9 @@ class SearchFilterTests(TestCase): ) def test_search_with_nonstandard_search_param(self): - with temporary_setting('SEARCH_PARAM', 'query', module=filters): + with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): + reload_module(filters) + class SearchListView(generics.ListAPIView): queryset = SearchFilterModel.objects.all() serializer_class = SearchFilterSerializer @@ -422,6 +426,8 @@ class SearchFilterTests(TestCase): ] ) + reload_module(filters) + class OrderingFilterModel(models.Model): title = models.CharField(max_length=20) @@ -467,6 +473,7 @@ class DjangoFilterOrderingTests(TestCase): for d in data: DjangoFilterOrderingModel.objects.create(**d) + @unittest.skipUnless(django_filters, 'django-filter not installed') def test_default_ordering(self): class DjangoFilterOrderingView(generics.ListAPIView): serializer_class = DjangoFilterOrderingSerializer @@ -641,7 +648,9 @@ class OrderingFilterTests(TestCase): ) def test_ordering_with_nonstandard_ordering_param(self): - with temporary_setting('ORDERING_PARAM', 'order', filters): + with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): + reload_module(filters) + class OrderingListView(generics.ListAPIView): queryset = OrderingFilterModel.objects.all() serializer_class = OrderingFilterSerializer @@ -661,6 +670,8 @@ class OrderingFilterTests(TestCase): ] ) + reload_module(filters) + class SensitiveOrderingFilterModel(models.Model): username = models.CharField(max_length=20) diff --git a/tests/test_generics.py b/tests/test_generics.py index 94023c30..88e792ce 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -117,7 +117,7 @@ class TestRootView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'}) def test_delete_root_view(self): """ @@ -127,7 +127,7 @@ class TestRootView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'}) def test_post_cannot_set_id(self): """ @@ -181,7 +181,7 @@ class TestInstanceView(TestCase): with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) - self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) + self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'}) def test_put_instance_view(self): """ @@ -483,7 +483,7 @@ class TestFilterBackendAppliedToViews(TestCase): request = factory.get('/1') response = instance_view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) - self.assertEqual(response.data, {'detail': 'Not found'}) + self.assertEqual(response.data, {'detail': 'Not found.'}) def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): """ diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py index 2edc6b4b..a33b832f 100644 --- a/tests/test_htmlrenderer.py +++ b/tests/test_htmlrenderer.py @@ -56,7 +56,13 @@ class TemplateHTMLRendererTests(TestCase): return Template("example: {{ object }}") raise TemplateDoesNotExist(template_name) + def select_template(template_name_list, dirs=None, using=None): + if template_name_list == ['example.html']: + return Template("example: {{ object }}") + raise TemplateDoesNotExist(template_name_list[0]) + django.template.loader.get_template = get_template + django.template.loader.select_template = select_template def tearDown(self): """ diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 5ff59c72..5031c0f3 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,9 +1,8 @@ from __future__ import unicode_literals - -from rest_framework import exceptions, serializers, views +from rest_framework import exceptions, serializers, status, views, versioning from rest_framework.request import Request +from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.test import APIRequestFactory -import pytest request = Request(APIRequestFactory().options('/')) @@ -17,7 +16,8 @@ class TestMetadata: """Example view.""" pass - response = ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) expected = { 'name': 'Example', 'description': 'Example view.', @@ -31,7 +31,7 @@ class TestMetadata: 'multipart/form-data' ] } - assert response.status_code == 200 + assert response.status_code == status.HTTP_200_OK assert response.data == expected def test_none_metadata(self): @@ -42,8 +42,10 @@ class TestMetadata: class ExampleView(views.APIView): metadata_class = None - with pytest.raises(exceptions.MethodNotAllowed): - ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED + assert response.data == {'detail': 'Method "OPTIONS" not allowed.'} def test_actions(self): """ @@ -63,7 +65,8 @@ class TestMetadata: def get_serializer(self): return ExampleSerializer() - response = ExampleView().options(request=request) + view = ExampleView.as_view() + response = view(request=request) expected = { 'name': 'Example', 'description': 'Example view.', @@ -104,7 +107,7 @@ class TestMetadata: } } } - assert response.status_code == 200 + assert response.status_code == status.HTTP_200_OK assert response.data == expected def test_global_permissions(self): @@ -132,8 +135,9 @@ class TestMetadata: if request.method == 'POST': raise exceptions.PermissionDenied() - response = ExampleView().options(request=request) - assert response.status_code == 200 + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['PUT'] def test_object_permissions(self): @@ -161,6 +165,36 @@ class TestMetadata: if self.request.method == 'PUT': raise exceptions.PermissionDenied() - response = ExampleView().options(request=request) - assert response.status_code == 200 + view = ExampleView.as_view() + response = view(request=request) + assert response.status_code == status.HTTP_200_OK assert list(response.data['actions'].keys()) == ['POST'] + + def test_bug_2455_clone_request(self): + class ExampleView(views.APIView): + renderer_classes = (BrowsableAPIRenderer,) + + def post(self, request): + pass + + def get_serializer(self): + assert hasattr(self.request, 'version') + return serializers.Serializer() + + view = ExampleView.as_view() + view(request=request) + + def test_bug_2477_clone_request(self): + class ExampleView(views.APIView): + renderer_classes = (BrowsableAPIRenderer,) + + def post(self, request): + pass + + def get_serializer(self): + assert hasattr(self.request, 'versioning_scheme') + return serializers.Serializer() + + scheme = versioning.QueryParameterVersioning + view = ExampleView.as_view(versioning_class=scheme) + view(request=request) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 5c56c8db..bce2008a 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -5,11 +5,14 @@ shortcuts for automatically creating serializers based on a given model class. These tests deal with ensuring that we correctly map the model fields onto an appropriate set of serializer fields for each case. """ +from __future__ import unicode_literals from django.core.exceptions import ImproperlyConfigured from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator from django.db import models from django.test import TestCase +from django.utils import six from rest_framework import serializers +from rest_framework.compat import unicode_repr def dedent(blocktext): @@ -119,12 +122,12 @@ class TestRegularFieldMappings(TestCase): positive_small_integer_field = IntegerField() slug_field = SlugField(max_length=100) small_integer_field = IntegerField() - text_field = CharField(style={'type': 'textarea'}) + text_field = CharField(style={'base_template': 'textarea.html'}) time_field = TimeField() url_field = URLField(max_length=100) custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_field_options(self): class TestSerializer(serializers.ModelSerializer): @@ -142,7 +145,14 @@ class TestRegularFieldMappings(TestCase): descriptive_field = IntegerField(help_text='Some help text', label='A label') choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')]) """) - self.assertEqual(repr(TestSerializer()), expected) + if six.PY2: + # This particular case is too awkward to resolve fully across + # both py2 and py3. + expected = expected.replace( + "('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')", + "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')" + ) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_method_field(self): """ @@ -206,7 +216,7 @@ class TestRegularFieldMappings(TestCase): with self.assertRaises(ImproperlyConfigured) as excinfo: TestSerializer().fields - expected = 'Field name `invalid` is not valid for model `ModelBase`.' + expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.' assert str(excinfo.exception) == expected def test_missing_field(self): @@ -229,6 +239,26 @@ class TestRegularFieldMappings(TestCase): ) assert str(excinfo.exception) == expected + def test_missing_superclass_field(self): + """ + Fields that have been declared on a parent of the serializer class may + be excluded from the `Meta.fields` option. + """ + class TestSerializer(serializers.ModelSerializer): + missing = serializers.ReadOnlyField() + + class Meta: + model = RegularFieldsModel + + class ChildSerializer(TestSerializer): + missing = serializers.ReadOnlyField() + + class Meta: + model = RegularFieldsModel + fields = ('auto_field',) + + ChildSerializer().fields + # Tests for relational field mappings. # ------------------------------------ @@ -276,7 +306,7 @@ class TestRelationalFieldMappings(TestCase): many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all()) through = PrimaryKeyRelatedField(many=True, read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_nested_relations(self): class TestSerializer(serializers.ModelSerializer): @@ -300,7 +330,7 @@ class TestRelationalFieldMappings(TestCase): id = IntegerField(label='ID', read_only=True) name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_hyperlinked_relations(self): class TestSerializer(serializers.HyperlinkedModelSerializer): @@ -315,7 +345,7 @@ class TestRelationalFieldMappings(TestCase): many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail') through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail') """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_nested_hyperlinked_relations(self): class TestSerializer(serializers.HyperlinkedModelSerializer): @@ -339,7 +369,7 @@ class TestRelationalFieldMappings(TestCase): url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail') name = CharField(max_length=100) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_foreign_key(self): class TestSerializer(serializers.ModelSerializer): @@ -353,7 +383,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_one_to_one(self): class TestSerializer(serializers.ModelSerializer): @@ -367,7 +397,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_many_to_many(self): class TestSerializer(serializers.ModelSerializer): @@ -381,7 +411,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all()) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) def test_pk_reverse_through(self): class TestSerializer(serializers.ModelSerializer): @@ -395,7 +425,7 @@ class TestRelationalFieldMappings(TestCase): name = CharField(max_length=100) reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) """) - self.assertEqual(repr(TestSerializer()), expected) + self.assertEqual(unicode_repr(TestSerializer()), expected) class TestIntegration(TestCase): diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py index e1b40cc7..15627e1d 100644 --- a/tests/test_multitable_inheritance.py +++ b/tests/test_multitable_inheritance.py @@ -48,8 +48,8 @@ class InheritedModelSerializationTests(TestCase): Assert that a model with a onetoone field that is the primary key is not treated like a derived model """ - parent = ParentModel(name1='parent name') - associate = AssociatedModel(name='hello', ref=parent) + parent = ParentModel.objects.create(name1='parent name') + associate = AssociatedModel.objects.create(name='hello', ref=parent) serializer = AssociatedModelSerializer(associate) self.assertEqual(set(serializer.data.keys()), set(['name', 'ref'])) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..13bfb627 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,553 +1,671 @@ +# coding: utf-8 from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): - if '?' not in url: - return url +class TestPaginationIntegration: + """ + Integration tests. + """ - path, args = url.split('?') - args = dict(r.split('=') for r in args.split('&')) - return path, args + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item + class EvenItemsOnly(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return [item for item in queryset if item % 2 == 0] + + class BasicPagination(pagination.PageNumberPagination): + paginate_by = 5 + paginate_by_param = 'page_size' + max_paginate_by = 20 + + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + filter_backends=[EvenItemsOnly], + pagination_class=BasicPagination + ) -class BasicSerializer(serializers.ModelSerializer): - class Meta: - model = BasicModel + def test_filtered_items_are_paginated(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 50 + } + def test_setting_page_size(self): + """ + When 'paginate_by_param' is set, the client may choose a page size. + """ + request = factory.get('/', {'page_size': 10}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=10', + 'count': 50 + } -class FilterableItemSerializer(serializers.ModelSerializer): - class Meta: - model = FilterableItem + def test_setting_page_size_over_maximum(self): + """ + When page_size parameter exceeds maxiumum allowable, + then it should be capped to the maxiumum. + """ + request = factory.get('/', {'page_size': 1000}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, + 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 + ], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=1000', + 'count': 50 + } + def test_setting_page_size_to_zero(self): + """ + When page_size parameter is invalid it should return to the default. + """ + request = factory.get('/', {'page_size': 0}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=0', + 'count': 50 + } -class RootView(generics.ListCreateAPIView): - """ - Example description for OPTIONS. - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 10 + def test_additional_query_params_are_preserved(self): + request = factory.get('/', {'page': 2, 'filter': 'even'}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [12, 14, 16, 18, 20], + 'previous': 'http://testserver/?filter=even', + 'next': 'http://testserver/?filter=even&page=3', + 'count': 50 + } + def test_404_not_found_for_zero_page(self): + request = factory.get('/', {'page': '0'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "0": That page number is less than 1.' + } -class DefaultPageSizeKwargView(generics.ListAPIView): - """ - View for testing default paginate_by_param usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer + def test_404_not_found_for_invalid_page(self): + request = factory.get('/', {'page': 'invalid'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "invalid": That page number is not an integer.' + } -class PaginateByParamView(generics.ListAPIView): +class TestPaginationDisabledIntegration: """ - View for testing custom paginate_by_param usage + Integration tests for disabled pagination. """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by_param = 'page_size' + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item -class MaxPaginateByView(generics.ListAPIView): - """ - View for testing custom max_paginate_by usage - """ - queryset = BasicModel.objects.all() - serializer_class = BasicSerializer - paginate_by = 3 - max_paginate_by = 5 - paginate_by_param = 'page_size' + self.view = generics.ListAPIView.as_view( + serializer_class=PassThroughSerializer, + queryset=range(1, 101), + pagination_class=None + ) + + def test_unpaginated_list(self): + request = factory.get('/', {'page': 2}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == list(range(1, 101)) -class IntegrationTestPagination(TestCase): +class TestDeprecatedStylePagination: """ - Integration tests for paginated list views. + Integration tests for deprecated style of setting pagination + attributes on the view. """ - def setUp(self): - """ - Create 26 BasicModel instances. - """ - for char in 'abcdefghijklmnopqrstuvwxyz': - BasicModel(text=char * 3).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = RootView.as_view() - - def test_get_paginated_root_view(self): - """ - GET requests to paginated ListCreateAPIView should return paginated results. - """ - request = factory.get('/') - # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[10:20]) - self.assertNotEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = self.view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 26) - self.assertEqual(response.data['results'], self.data[20:]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - - def setUp(self): - """ - Create 50 FilterableItem instances. - """ - base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) - for i in range(26): - text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. - decimal = base_data[1] + i - date = base_data[2] - datetime.timedelta(days=i * 2) - FilterableItem(text=text, decimal=decimal, date=date).save() - - self.objects = FilterableItem.objects - self.data = [ - {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} - for obj in self.objects.all() - ] - - @unittest.skipUnless(django_filters, 'django-filter not installed') - def test_get_django_filter_paginated_filtered_root_view(self): - """ - GET requests to paginated filtered ListCreateAPIView should return - paginated results. The next and previous links should preserve the - filtered parameters. - """ - class DecimalFilter(django_filters.FilterSet): - decimal = django_filters.NumberFilter(lookup_type='lt') - - class Meta: - model = FilterableItem - fields = ['text', 'decimal', 'date'] - - class FilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_class = DecimalFilter - filter_backends = (filters.DjangoFilterBackend,) - - view = FilterFieldsRootView.as_view() - - EXPECTED_NUM_QUERIES = 2 - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(EXPECTED_NUM_QUERIES): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - def test_get_basic_paginated_filtered_root_view(self): - """ - Same as `test_get_django_filter_paginated_filtered_root_view`, - except using a custom filter backend instead of the django-filter - backend, - """ + def setup(self): + class PassThroughSerializer(serializers.BaseSerializer): + def to_representation(self, item): + return item - class DecimalFilterBackend(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - - class BasicFilterFieldsRootView(generics.ListCreateAPIView): - queryset = FilterableItem.objects.all() - serializer_class = FilterableItemSerializer - paginate_by = 10 - filter_backends = (DecimalFilterBackend,) - - view = BasicFilterFieldsRootView.as_view() - - request = factory.get('/', {'decimal': '15.20'}) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['next'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[10:15]) - self.assertEqual(response.data['next'], None) - self.assertNotEqual(response.data['previous'], None) - - request = factory.get(*split_arguments_from_url(response.data['previous'])) - with self.assertNumQueries(2): - response = view(request).render() - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['count'], 15) - self.assertEqual(response.data['results'], self.data[:10]) - self.assertNotEqual(response.data['next'], None) - self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): - """ - Unit tests for pagination of primitive objects. - """ + class ExampleView(generics.ListAPIView): + serializer_class = PassThroughSerializer + queryset = range(1, 101) + pagination_class = pagination.PageNumberPagination + paginate_by = 20 + page_query_param = 'page_number' - def setUp(self): - self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] - paginator = Paginator(self.objects, 10) - self.first_page = paginator.page(1) - self.last_page = paginator.page(3) - - def test_native_pagination(self): - serializer = pagination.PaginationSerializer(self.first_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], '?page=2') - self.assertEqual(serializer.data['previous'], None) - self.assertEqual(serializer.data['results'], self.objects[:10]) - - serializer = pagination.PaginationSerializer(self.last_page) - self.assertEqual(serializer.data['count'], 26) - self.assertEqual(serializer.data['next'], None) - self.assertEqual(serializer.data['previous'], '?page=2') - self.assertEqual(serializer.data['results'], self.objects[20:]) - - def test_context_available_in_result(self): - """ - Ensure context gets passed through to the object serializer. - """ - serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) - serializer.data - results = serializer.fields[serializer.results_field] - self.assertEqual(serializer.context, results.context) + self.view = ExampleView.as_view() + + def test_paginate_by_attribute_on_view(self): + request = factory.get('/?page_number=2') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [ + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40 + ], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page_number=3', + 'count': 100 + } -class TestUnpaginated(TestCase): +class TestPageNumberPagination: """ - Tests for list views without pagination. + Unit tests for `pagination.PageNumberPagination`. """ - def setUp(self): - """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = DefaultPageSizeKwargView.as_view() - - def test_unpaginated(self): - """ - Tests the default page size for this view. - no page size --> no limit --> no meta data - """ - request = factory.get('/') - response = self.view(request) - self.assertEqual(response.data, self.data) + def setup(self): + class ExamplePagination(pagination.PageNumberPagination): + paginate_by = 5 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_page_number(self): + request = Request(factory.get('/')) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?page=2', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?page=2', + 'page_links': [ + PageLink('http://testserver/', 1, True, False), + PageLink('http://testserver/?page=2', 2, False, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) + + def test_second_page(self): + request = Request(factory.get('/', {'page': 2})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/', + 'next': 'http://testserver/?page=3', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/', + 'next_url': 'http://testserver/?page=3', + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PageLink('http://testserver/?page=2', 2, True, False), + PageLink('http://testserver/?page=3', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=20', 20, False, False), + ] + } + + def test_last_page(self): + request = Request(factory.get('/', {'page': 'last'})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?page=19', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?page=19', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?page=18', 18, False, False), + PageLink('http://testserver/?page=19', 19, False, False), + PageLink('http://testserver/?page=20', 20, True, False), + ] + } + + def test_invalid_page(self): + request = Request(factory.get('/', {'page': 'invalid'})) + with pytest.raises(exceptions.NotFound): + self.paginate_queryset(request) -class TestCustomPaginateByParam(TestCase): +class TestLimitOffset: """ - Tests for list views with default page size kwarg + Unit tests for `pagination.LimitOffsetPagination`. """ - def setUp(self): + def setup(self): + class ExamplePagination(pagination.LimitOffsetPagination): + default_limit = 10 + self.pagination = ExamplePagination() + self.queryset = range(1, 101) + + def paginate_queryset(self, request): + return list(self.pagination.paginate_queryset(self.queryset, request)) + + def get_paginated_content(self, queryset): + response = self.pagination.get_paginated_response(queryset) + return response.data + + def get_html_context(self): + return self.pagination.get_html_context() + + def test_no_offset(self): + request = Request(factory.get('/', {'limit': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [1, 2, 3, 4, 5] + assert content == { + 'results': [1, 2, 3, 4, 5], + 'previous': None, + 'next': 'http://testserver/?limit=5&offset=5', + 'count': 100 + } + assert context == { + 'previous_url': None, + 'next_url': 'http://testserver/?limit=5&offset=5', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, True, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + assert self.pagination.display_page_controls + assert isinstance(self.pagination.to_html(), type('')) + + def test_single_offset(self): """ - Create 13 BasicModel instances. + When the offset is not a multiple of the limit we get some edge cases: + * The first page should still be offset zero. + * We may end up displaying an extra page in the pagination control. """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = PaginateByParamView.as_view() - - def test_default_page_size(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 1})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [2, 3, 4, 5, 6] + assert content == { + 'results': [2, 3, 4, 5, 6], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=6', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=6', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=1', 2, True, False), + PageLink('http://testserver/?limit=5&offset=6', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=96', 21, False, False), + ] + } + + def test_first_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 5})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [6, 7, 8, 9, 10] + assert content == { + 'results': [6, 7, 8, 9, 10], + 'previous': 'http://testserver/?limit=5', + 'next': 'http://testserver/?limit=5&offset=10', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5', + 'next_url': 'http://testserver/?limit=5&offset=10', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, True, False), + PageLink('http://testserver/?limit=5&offset=10', 3, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_middle_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 10})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [11, 12, 13, 14, 15] + assert content == { + 'results': [11, 12, 13, 14, 15], + 'previous': 'http://testserver/?limit=5&offset=5', + 'next': 'http://testserver/?limit=5&offset=15', + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=5', + 'next_url': 'http://testserver/?limit=5&offset=15', + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PageLink('http://testserver/?limit=5&offset=5', 2, False, False), + PageLink('http://testserver/?limit=5&offset=10', 3, True, False), + PageLink('http://testserver/?limit=5&offset=15', 4, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=95', 20, False, False), + ] + } + + def test_ending_offset(self): + request = Request(factory.get('/', {'limit': 5, 'offset': 95})) + queryset = self.paginate_queryset(request) + content = self.get_paginated_content(queryset) + context = self.get_html_context() + assert queryset == [96, 97, 98, 99, 100] + assert content == { + 'results': [96, 97, 98, 99, 100], + 'previous': 'http://testserver/?limit=5&offset=90', + 'next': None, + 'count': 100 + } + assert context == { + 'previous_url': 'http://testserver/?limit=5&offset=90', + 'next_url': None, + 'page_links': [ + PageLink('http://testserver/?limit=5', 1, False, False), + PAGE_BREAK, + PageLink('http://testserver/?limit=5&offset=85', 18, False, False), + PageLink('http://testserver/?limit=5&offset=90', 19, False, False), + PageLink('http://testserver/?limit=5&offset=95', 20, True, False), + ] + } + + def test_invalid_offset(self): """ - Tests the default page size for this view. - no page size --> no limit --> no meta data + An invalid offset query param should be treated as 0. """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data, self.data) + request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5] - def test_paginate_by_param(self): + def test_invalid_limit(self): """ - If paginate_by_param is set, the new kwarg should limit per view requests. + An invalid limit query param should be ignored in favor of the default. """ - request = factory.get('/', {'page_size': 5}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) + queryset = self.paginate_queryset(request) + assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -class TestMaxPaginateByParam(TestCase): +class TestCursorPagination: """ - Tests for list views with max_paginate_by kwarg + Unit tests for `pagination.CursorPagination`. """ - def setUp(self): + def setup(self): + class MockObject(object): + def __init__(self, idx): + self.created = idx + + class MockQuerySet(object): + def __init__(self, items): + self.items = items + + def filter(self, created__gt=None, created__lt=None): + if created__gt is not None: + return MockQuerySet([ + item for item in self.items + if item.created > int(created__gt) + ]) + + assert created__lt is not None + return MockQuerySet([ + item for item in self.items + if item.created < int(created__lt) + ]) + + def order_by(self, *ordering): + if ordering[0].startswith('-'): + return MockQuerySet(list(reversed(self.items))) + return self + + def __getitem__(self, sliced): + return self.items[sliced] + + class ExamplePagination(pagination.CursorPagination): + page_size = 5 + ordering = 'created' + + self.pagination = ExamplePagination() + self.queryset = MockQuerySet([ + MockObject(idx) for idx in [ + 1, 1, 1, 1, 1, + 1, 2, 3, 4, 4, + 4, 4, 5, 6, 7, + 7, 7, 7, 7, 7, + 7, 7, 7, 8, 9, + 9, 9, 9, 9, 9 + ] + ]) + + def get_pages(self, url): """ - Create 13 BasicModel instances. - """ - for i in range(13): - BasicModel(text=i).save() - self.objects = BasicModel.objects - self.data = [ - {'id': obj.id, 'text': obj.text} - for obj in self.objects.all() - ] - self.view = MaxPaginateByView.as_view() - - def test_max_paginate_by(self): - """ - If max_paginate_by is set, it should limit page size for the view. - """ - request = factory.get('/', data={'page_size': 10}) - response = self.view(request).render() - self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:5]) + Given a URL return a tuple of: - def test_max_paginate_by_without_page_size_param(self): + (previous page, current page, next page, previous url, next url) """ - If max_paginate_by is set, but client does not specifiy page_size, - standard `paginate_by` behavior should be used. - """ - request = factory.get('/') - response = self.view(request).render() - self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers + request = Request(factory.get(url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + current = [item.created for item in queryset] -class CustomField(serializers.ReadOnlyField): - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into custom field") - return "value" + next_url = self.pagination.get_next_link() + previous_url = self.pagination.get_previous_link() + if next_url is not None: + request = Request(factory.get(next_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + next = [item.created for item in queryset] + else: + next = None -class BasicModelSerializer(serializers.Serializer): - text = CustomField() - - def to_native(self, value): - if 'view' not in self.context: - raise RuntimeError("context isn't getting passed into serializer") - return super(BasicSerializer, self).to_native(value) + if previous_url is not None: + request = Request(factory.get(previous_url)) + queryset = self.pagination.paginate_queryset(self.queryset, request) + previous = [item.created for item in queryset] + else: + previous = None + return (previous, current, next, previous_url, next_url) -class TestContextPassedToCustomField(TestCase): - def setUp(self): - BasicModel.objects.create(text='ala ma kota') + def test_invalid_cursor(self): + request = Request(factory.get('/', {'cursor': '123'})) + with pytest.raises(exceptions.NotFound): + self.pagination.paginate_queryset(self.queryset, request) - def test_with_pagination(self): - class ListView(generics.ListCreateAPIView): - queryset = BasicModel.objects.all() - serializer_class = BasicModelSerializer - paginate_by = 1 + def test_use_with_ordering_filter(self): + class MockView: + filter_backends = (filters.OrderingFilter,) + ordering_fields = ['username', 'created'] + ordering = 'created' - self.view = ListView.as_view() - request = factory.get('/') - response = self.view(request).render() + request = Request(factory.get('/', {'ordering': 'username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('username',) - self.assertEqual(response.status_code, status.HTTP_200_OK) + request = Request(factory.get('/', {'ordering': '-username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('-username',) + request = Request(factory.get('/', {'ordering': 'invalid'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('created',) -# Tests for custom pagination serializers + def test_cursor_pagination(self): + (previous, current, next, previous_url, next_url) = self.get_pages('/') -class LinksSerializer(serializers.Serializer): - next = pagination.NextPageField(source='*') - prev = pagination.PreviousPageField(source='*') + assert previous is None + assert current == [1, 1, 1, 1, 1] + assert next == [1, 2, 3, 4, 4] + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) -class CustomPaginationSerializer(pagination.BasePaginationSerializer): - links = LinksSerializer(source='*') # Takes the page object as the source - total_results = serializers.ReadOnlyField(source='paginator.count') + assert previous == [1, 1, 1, 1, 1] + assert current == [1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7] - results_field = 'objects' + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + assert previous == [1, 2, 3, 4, 4] + assert current == [4, 4, 5, 6, 7] + assert next == [7, 7, 7, 7, 7] -class CustomFooSerializer(serializers.Serializer): - foo = serializers.CharField() + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + assert previous == [4, 4, 4, 5, 6] # Paging artifact + assert current == [7, 7, 7, 7, 7] + assert next == [7, 7, 7, 8, 9] -class CustomFooPaginationSerializer(pagination.PaginationSerializer): - class Meta: - object_serializer_class = CustomFooSerializer + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) + assert previous == [7, 7, 7, 7, 7] + assert current == [7, 7, 7, 8, 9] + assert next == [9, 9, 9, 9, 9] -class TestCustomPaginationSerializer(TestCase): - def setUp(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = Paginator(objects, 2) - self.page = paginator.page(1) + (previous, current, next, previous_url, next_url) = self.get_pages(next_url) - def test_custom_pagination_serializer(self): - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=self.page, - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page=2', - 'prev': None - }, - 'total_results': 4, - 'objects': ['john', 'paul'] - } - self.assertEqual(serializer.data, expected) + assert previous == [7, 7, 7, 8, 9] + assert current == [9, 9, 9, 9, 9] + assert next is None - def test_custom_pagination_serializer_with_custom_object_serializer(self): - objects = [ - {'foo': 'bar'}, - {'foo': 'spam'} - ] - paginator = Paginator(objects, 1) - page = paginator.page(1) - serializer = CustomFooPaginationSerializer(page) - serializer.data + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) + assert previous == [7, 7, 7, 7, 7] + assert current == [7, 7, 7, 8, 9] + assert next == [9, 9, 9, 9, 9] -class NonIntegerPage(object): + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - def __init__(self, paginator, object_list, prev_token, token, next_token): - self.paginator = paginator - self.object_list = object_list - self.prev_token = prev_token - self.token = token - self.next_token = next_token + assert previous == [4, 4, 5, 6, 7] + assert current == [7, 7, 7, 7, 7] + assert next == [8, 9, 9, 9, 9] # Paging artifact - def has_next(self): - return not not self.next_token + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - def next_page_number(self): - return self.next_token + assert previous == [1, 2, 3, 4, 4] + assert current == [4, 4, 5, 6, 7] + assert next == [7, 7, 7, 7, 7] - def has_previous(self): - return not not self.prev_token + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) - def previous_page_number(self): - return self.prev_token + assert previous == [1, 1, 1, 1, 1] + assert current == [1, 2, 3, 4, 4] + assert next == [4, 4, 5, 6, 7] + (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -class NonIntegerPaginator(object): + assert previous is None + assert current == [1, 1, 1, 1, 1] + assert next == [1, 2, 3, 4, 4] - def __init__(self, object_list, per_page): - self.object_list = object_list - self.per_page = per_page + assert isinstance(self.pagination.to_html(), type('')) - def count(self): - # pretend like we don't know how many pages we have - return None - def page(self, token=None): - if token: - try: - first = self.object_list.index(token) - except ValueError: - first = 0 - else: - first = 0 - n = len(self.object_list) - last = min(first + self.per_page, n) - prev_token = self.object_list[last - (2 * self.per_page)] if first else None - next_token = self.object_list[last] if last < n else None - return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): - def test_custom_pagination_serializer(self): - objects = ['john', 'paul', 'george', 'ringo'] - paginator = NonIntegerPaginator(objects, 2) - - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page(), - context={'request': request} - ) - expected = { - 'links': { - 'next': 'http://testserver/foobar?page={0}'.format(objects[2]), - 'prev': None - }, - 'total_results': None, - 'objects': objects[:2] - } - self.assertEqual(serializer.data, expected) +def test_get_displayed_page_numbers(): + """ + Test our contextual page display function. - request = APIRequestFactory().get('/foobar') - serializer = CustomPaginationSerializer( - instance=paginator.page('george'), - context={'request': request} - ) - expected = { - 'links': { - 'next': None, - 'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), - }, - 'total_results': None, - 'objects': objects[2:] - } - self.assertEqual(serializer.data, expected) + This determines which pages to display in a pagination control, + given the current page and the last page. + """ + displayed_page_numbers = pagination._get_displayed_page_numbers + + # At five pages or less, all pages are displayed, always. + assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] + assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + + # Between six and either pages we may have a single page break. + assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] + assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] + assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] + assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + + assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] + assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] + assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] + assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] + assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] + assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + + assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] + assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] + assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] + assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] + assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] + assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] + assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + + # At nine or more pages we may have two page breaks, one on each side. + assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] + assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] + assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] + assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] + assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] + assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] + assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] + assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 54455cf6..8816065a 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -101,7 +101,9 @@ class TestFileUploadParser(TestCase): self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8--ÀĥƦ.txt') filename = parser.get_filename(self.stream, None, self.parser_context) - self.assertEqual(filename, 'fallback.txt') + # Malformed. Either None or 'fallback.txt' will be acceptable. + # See also https://code.djangoproject.com/ticket/24209 + self.assertIn(filename, ('fallback.txt', None)) def __replace_content_disposition(self, disposition): self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition diff --git a/tests/test_relations.py b/tests/test_relations.py index 62353dc2..fbe176e2 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,6 +1,8 @@ from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset from django.core.exceptions import ImproperlyConfigured +from django.utils.datastructures import MultiValueDict from rest_framework import serializers +from rest_framework.fields import empty from rest_framework.test import APISimpleTestCase import pytest @@ -33,7 +35,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase): with pytest.raises(serializers.ValidationError) as excinfo: self.field.to_internal_value(4) msg = excinfo.value.detail[0] - assert msg == "Invalid pk '4' - object does not exist." + assert msg == 'Invalid pk "4" - object does not exist.' def test_pk_related_lookup_invalid_type(self): with pytest.raises(serializers.ValidationError) as excinfo: @@ -134,3 +136,34 @@ class TestSlugRelatedField(APISimpleTestCase): def test_representation(self): representation = self.field.to_representation(self.instance) assert representation == self.instance.name + + +class TestManyRelatedField(APISimpleTestCase): + def setUp(self): + self.instance = MockObject(pk=1, name='foo') + self.field = serializers.StringRelatedField(many=True) + self.field.field_name = 'foo' + + def test_get_value_regular_dictionary_full(self): + assert 'bar' == self.field.get_value({'foo': 'bar'}) + assert empty == self.field.get_value({'baz': 'bar'}) + + def test_get_value_regular_dictionary_partial(self): + setattr(self.field.root, 'partial', True) + assert 'bar' == self.field.get_value({'foo': 'bar'}) + assert empty == self.field.get_value({'baz': 'bar'}) + + def test_get_value_multi_dictionary_full(self): + mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) + assert ['bar1', 'bar2'] == self.field.get_value(mvd) + + mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) + assert [] == self.field.get_value(mvd) + + def test_get_value_multi_dictionary_partial(self): + setattr(self.field.root, 'partial', True) + mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) + assert ['bar1', 'bar2'] == self.field.get_value(mvd) + + mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) + assert empty == self.field.get_value(mvd) diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index f1b882ed..aede61d2 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -1,5 +1,5 @@ from __future__ import unicode_literals -from django.conf.urls import patterns, url +from django.conf.urls import url from django.test import TestCase from rest_framework import serializers from rest_framework.test import APIRequestFactory @@ -14,8 +14,7 @@ request = factory.get('/') # Just to ensure we have a request in the serializer dummy_view = lambda request, pk: None -urlpatterns = patterns( - '', +urlpatterns = [ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), @@ -24,7 +23,7 @@ urlpatterns = patterns( url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), -) +] # ManyToMany diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 7b78f7ba..f68405f0 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals - from django.conf.urls import patterns, url, include from django.core.cache import cache from django.db import models @@ -8,6 +7,7 @@ from django.test import TestCase from django.utils import six from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions +from rest_framework.compat import OrderedDict from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer @@ -15,7 +15,6 @@ from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory from collections import MutableMapping import json -import pickle import re @@ -408,84 +407,46 @@ class CacheRenderTest(TestCase): urls = 'tests.test_renderers' - cache_key = 'just_a_cache_key' - - @classmethod - def _get_pickling_errors(cls, obj, seen=None): - """ Return any errors that would be raised if `obj' is pickled - Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 - """ - if seen is None: - seen = [] - try: - state = obj.__getstate__() - except AttributeError: - return - if state is None: - return - if isinstance(state, tuple): - if not isinstance(state[0], dict): - state = state[1] - else: - state = state[0].update(state[1]) - result = {} - for i in state: - try: - pickle.dumps(state[i], protocol=2) - except pickle.PicklingError: - if not state[i] in seen: - seen.append(state[i]) - result[i] = cls._get_pickling_errors(state[i], seen) - return result - - def http_resp(self, http_method, url): - """ - Simple wrapper for Client http requests - Removes the `client' and `request' attributes from as they are - added by django.test.client.Client and not part of caching - responses outside of tests. - """ - method = getattr(self.client, http_method) - resp = method(url) - resp._closable_objects = [] - del resp.client, resp.request - try: - del resp.wsgi_request - except AttributeError: - pass - return resp - - def test_obj_pickling(self): - """ - Test that responses are properly pickled - """ - resp = self.http_resp('get', '/cache') - - # Make sure that no pickling errors occurred - self.assertEqual(self._get_pickling_errors(resp), {}) - - # Unfortunately LocMem backend doesn't raise PickleErrors but returns - # None instead. - cache.set(self.cache_key, resp) - self.assertTrue(cache.get(self.cache_key) is not None) - def test_head_caching(self): """ Test caching of HEAD requests """ - resp = self.http_resp('head', '/cache') - cache.set(self.cache_key, resp) - - cached_resp = cache.get(self.cache_key) - self.assertIsInstance(cached_resp, Response) + response = self.client.head('/cache') + cache.set('key', response) + cached_response = cache.get('key') + assert isinstance(cached_response, Response) + assert cached_response.content == response.content + assert cached_response.status_code == response.status_code def test_get_caching(self): """ Test caching of GET requests """ - resp = self.http_resp('get', '/cache') - cache.set(self.cache_key, resp) + response = self.client.get('/cache') + cache.set('key', response) + cached_response = cache.get('key') + assert isinstance(cached_response, Response) + assert cached_response.content == response.content + assert cached_response.status_code == response.status_code + + +class TestJSONIndentationStyles: + def test_indented(self): + renderer = JSONRenderer() + data = OrderedDict([('a', 1), ('b', 2)]) + assert renderer.render(data) == b'{"a":1,"b":2}' - cached_resp = cache.get(self.cache_key) - self.assertIsInstance(cached_resp, Response) - self.assertEqual(cached_resp.content, resp.content) + def test_compact(self): + renderer = JSONRenderer() + data = OrderedDict([('a', 1), ('b', 2)]) + context = {'indent': 4} + assert ( + renderer.render(data, renderer_context=context) == + b'{\n "a": 1,\n "b": 2\n}' + ) + + def test_long_form(self): + renderer = JSONRenderer() + renderer.compact = False + data = OrderedDict([('a', 1), ('b', 2)]) + assert renderer.render(data) == b'{"a": 1, "b": 2}' diff --git a/tests/test_routers.py b/tests/test_routers.py index 34306146..948c69bb 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,17 +1,53 @@ from __future__ import unicode_literals -from django.conf.urls import patterns, url, include +from django.conf.urls import url, include from django.db import models from django.test import TestCase from django.core.exceptions import ImproperlyConfigured -from rest_framework import serializers, viewsets, mixins, permissions +from rest_framework import serializers, viewsets, permissions from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import SimpleRouter, DefaultRouter from rest_framework.test import APIRequestFactory +from collections import namedtuple factory = APIRequestFactory() -urlpatterns = patterns('',) + +class RouterTestModel(models.Model): + uuid = models.CharField(max_length=20) + text = models.CharField(max_length=200) + + +class NoteSerializer(serializers.HyperlinkedModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') + + class Meta: + model = RouterTestModel + fields = ('url', 'uuid', 'text') + + +class NoteViewSet(viewsets.ModelViewSet): + queryset = RouterTestModel.objects.all() + serializer_class = NoteSerializer + lookup_field = 'uuid' + + +class MockViewSet(viewsets.ModelViewSet): + queryset = None + serializer_class = None + + +notes_router = SimpleRouter() +notes_router.register(r'notes', NoteViewSet) + +namespaced_router = DefaultRouter() +namespaced_router.register(r'example', MockViewSet, base_name='example') + +urlpatterns = [ + url(r'^non-namespaced/', include(namespaced_router.urls)), + url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), + url(r'^example/', include(notes_router.urls)), +] class BasicViewSet(viewsets.ViewSet): @@ -63,9 +99,26 @@ class TestSimpleRouter(TestCase): self.assertEqual(route.mapping[method], endpoint) -class RouterTestModel(models.Model): - uuid = models.CharField(max_length=20) - text = models.CharField(max_length=200) +class TestRootView(TestCase): + urls = 'tests.test_routers' + + def test_retrieve_namespaced_root(self): + response = self.client.get('/namespaced/') + self.assertEqual( + response.data, + { + "example": "http://testserver/namespaced/example/", + } + ) + + def test_retrieve_non_namespaced_root(self): + response = self.client.get('/non-namespaced/') + self.assertEqual( + response.data, + { + "example": "http://testserver/non-namespaced/example/", + } + ) class TestCustomLookupFields(TestCase): @@ -75,51 +128,29 @@ class TestCustomLookupFields(TestCase): urls = 'tests.test_routers' def setUp(self): - class NoteSerializer(serializers.HyperlinkedModelSerializer): - url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') - - class Meta: - model = RouterTestModel - fields = ('url', 'uuid', 'text') - - class NoteViewSet(viewsets.ModelViewSet): - queryset = RouterTestModel.objects.all() - serializer_class = NoteSerializer - lookup_field = 'uuid' - - self.router = SimpleRouter() - self.router.register(r'notes', NoteViewSet) - - from tests import test_routers - urls = getattr(test_routers, 'urlpatterns') - urls += patterns( - '', - url(r'^', include(self.router.urls)), - ) - RouterTestModel.objects.create(uuid='123', text='foo bar') def test_custom_lookup_field_route(self): - detail_route = self.router.urls[-1] + detail_route = notes_router.urls[-1] detail_url_pattern = detail_route.regex.pattern self.assertIn('<uuid>', detail_url_pattern) def test_retrieve_lookup_field_list_view(self): - response = self.client.get('/notes/') + response = self.client.get('/example/notes/') self.assertEqual( response.data, [{ - "url": "http://testserver/notes/123/", + "url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar" }] ) def test_retrieve_lookup_field_detail_view(self): - response = self.client.get('/notes/123/') + response = self.client.get('/example/notes/123/') self.assertEqual( response.data, { - "url": "http://testserver/notes/123/", + "url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar" } ) @@ -149,7 +180,7 @@ class TestLookupValueRegex(TestCase): class TestTrailingSlashIncluded(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): - model = RouterTestModel + queryset = RouterTestModel.objects.all() self.router = SimpleRouter() self.router.register(r'notes', NoteViewSet) @@ -164,7 +195,7 @@ class TestTrailingSlashIncluded(TestCase): class TestTrailingSlashRemoved(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): - model = RouterTestModel + queryset = RouterTestModel.objects.all() self.router = SimpleRouter(trailing_slash=False) self.router.register(r'notes', NoteViewSet) @@ -179,7 +210,8 @@ class TestTrailingSlashRemoved(TestCase): class TestNameableRoot(TestCase): def setUp(self): class NoteViewSet(viewsets.ModelViewSet): - model = RouterTestModel + queryset = RouterTestModel.objects.all() + self.router = DefaultRouter() self.router.root_view_name = 'nameable-root' self.router.register(r'notes', NoteViewSet) @@ -261,6 +293,14 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet): def detail_route_get(self, request, *args, **kwargs): return Response({'method': 'link2'}) + @list_route(url_path="list_custom-route") + def list_custom_route_get(self, request, *args, **kwargs): + return Response({'method': 'link1'}) + + @detail_route(url_path="detail_custom-route") + def detail_custom_route_get(self, request, *args, **kwargs): + return Response({'method': 'link2'}) + class TestDynamicListAndDetailRouter(TestCase): def setUp(self): @@ -269,35 +309,30 @@ class TestDynamicListAndDetailRouter(TestCase): def test_list_and_detail_route_decorators(self): routes = self.router.get_routes(DynamicListAndDetailViewSet) decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] + + MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path') # Make sure all these endpoints exist and none have been clobbered - for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']): + for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), + MethodNamesMap('list_route_get', 'list_route_get'), + MethodNamesMap('list_route_post', 'list_route_post'), + MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), + MethodNamesMap('detail_route_get', 'detail_route_get'), + MethodNamesMap('detail_route_post', 'detail_route_post') + ]): route = decorator_routes[i] # check url listing - if endpoint.startswith('list_'): + method_name = endpoint.method_name + url_path = endpoint.url_path + + if method_name.startswith('list_'): self.assertEqual(route.url, - '^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint)) + '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)) else: self.assertEqual(route.url, - '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) + '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)) # check method to function mapping - if endpoint.endswith('_post'): + if method_name.endswith('_post'): method_map = 'post' else: method_map = 'get' - self.assertEqual(route.mapping[method_map], endpoint) - - -class TestRootWithAListlessViewset(TestCase): - def setUp(self): - class NoteViewSet(mixins.RetrieveModelMixin, - viewsets.GenericViewSet): - model = RouterTestModel - - self.router = DefaultRouter() - self.router.register(r'notes', NoteViewSet) - self.view = self.router.urls[0].callback - - def test_api_root(self): - request = factory.get('/') - response = self.view(request) - self.assertEqual(response.data, {}) + self.assertEqual(route.mapping[method_map], method_name) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index c17b6d8c..b7a0484b 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,7 +1,9 @@ # coding: utf-8 from __future__ import unicode_literals +from .utils import MockObject from rest_framework import serializers from rest_framework.compat import unicode_repr +import pickle import pytest @@ -216,3 +218,80 @@ class TestUnicodeRepr: instance = ExampleObject() serializer = ExampleSerializer(instance) repr(serializer) # Should not error. + + +class TestNotRequiredOutput: + def test_not_required_output_for_dict(self): + """ + 'required=False' should allow a dictionary key to be missing in output. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(required=False) + included = serializers.CharField() + + serializer = ExampleSerializer(data={'included': 'abc'}) + serializer.is_valid() + assert serializer.data == {'included': 'abc'} + + def test_not_required_output_for_object(self): + """ + 'required=False' should allow an object attribute to be missing in output. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(required=False) + included = serializers.CharField() + + def create(self, validated_data): + return MockObject(**validated_data) + + serializer = ExampleSerializer(data={'included': 'abc'}) + serializer.is_valid() + serializer.save() + assert serializer.data == {'included': 'abc'} + + def test_default_required_output_for_dict(self): + """ + 'default="something"' should require dictionary key. + + We need to handle this as the field will have an implicit + 'required=False', but it should still have a value. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(default='abc') + included = serializers.CharField() + + serializer = ExampleSerializer({'included': 'abc'}) + with pytest.raises(KeyError): + serializer.data + + def test_default_required_output_for_object(self): + """ + 'default="something"' should require object attribute. + + We need to handle this as the field will have an implicit + 'required=False', but it should still have a value. + """ + class ExampleSerializer(serializers.Serializer): + omitted = serializers.CharField(default='abc') + included = serializers.CharField() + + instance = MockObject(included='abc') + serializer = ExampleSerializer(instance) + with pytest.raises(AttributeError): + serializer.data + + +class TestCacheSerializerData: + def test_cache_serializer_data(self): + """ + Caching serializer data with pickle will drop the serializer info, + but does preserve the data itself. + """ + class ExampleSerializer(serializers.Serializer): + field1 = serializers.CharField() + field2 = serializers.CharField() + + serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'}) + pickled = pickle.dumps(serializer.data) + data = pickle.loads(pickled) + assert data == {'field1': 'a', 'field2': 'b'} diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index fb881a75..bc955b2e 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']} self.assertEqual(serializer.errors, expected_errors) @@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase): serializer = self.BookSerializer(data=data, many=True) self.assertEqual(serializer.is_valid(), False) - expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']} + expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']} self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_versioning.py b/tests/test_versioning.py index c44f727d..553463d1 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,9 +1,13 @@ +from .utils import UsingURLPatterns from django.conf.urls import include, url +from rest_framework import serializers from rest_framework import status, versioning from rest_framework.decorators import APIView from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory, APITestCase +from rest_framework.versioning import NamespaceVersioning +import pytest class RequestVersionView(APIView): @@ -28,17 +32,8 @@ class RequestInvalidVersionView(APIView): factory = APIRequestFactory() -mock_view = lambda request: None - -included_patterns = [ - url(r'^namespaced/$', mock_view, name='another'), -] - -urlpatterns = [ - url(r'^v1/', include(included_patterns, namespace='v1')), - url(r'^another/$', mock_view, name='another'), - url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another') -] +dummy_view = lambda request: None +dummy_pk_view = lambda request, pk: None class TestRequestVersion: @@ -114,8 +109,17 @@ class TestRequestVersion: assert response.data == {'version': None} -class TestURLReversing(APITestCase): - urls = 'tests.test_versioning' +class TestURLReversing(UsingURLPatterns, APITestCase): + included = [ + url(r'^namespaced/$', dummy_view, name='another'), + url(r'^example/(?P<pk>\d+)/$', dummy_pk_view, name='example-detail') + ] + + urlpatterns = [ + url(r'^v1/', include(included, namespace='v1')), + url(r'^another/$', dummy_view, name='another'), + url(r'^(?P<version>[^/]+)/another/$', dummy_view, name='another'), + ] def test_reverse_unversioned(self): view = ReverseView.as_view() @@ -221,3 +225,35 @@ class TestInvalidVersion: request.resolver_match = FakeResolverMatch response = view(request, version='v3') assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): + included = [ + url(r'^namespaced/(?P<pk>\d+)/$', dummy_view, name='namespaced'), + ] + + urlpatterns = [ + url(r'^v1/', include(included, namespace='v1')), + url(r'^v2/', include(included, namespace='v2')) + ] + + def setUp(self): + super(TestHyperlinkedRelatedField, self).setUp() + + class MockQueryset(object): + def get(self, pk): + return 'object %s' % pk + + self.field = serializers.HyperlinkedRelatedField( + view_name='namespaced', + queryset=MockQueryset() + ) + request = factory.get('/') + request.versioning_scheme = NamespaceVersioning() + request.version = 'v1' + self.field._context = {'request': request} + + def test_bug_2489(self): + assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3' + with pytest.raises(serializers.ValidationError): + self.field.to_internal_value('/v2/namespaced/3/') diff --git a/tests/utils.py b/tests/utils.py index 5e902ba9..b9034996 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,30 +1,29 @@ -from contextlib import contextmanager from django.core.exceptions import ObjectDoesNotExist from django.core.urlresolvers import NoReverseMatch -from django.utils import six -from rest_framework.settings import api_settings -@contextmanager -def temporary_setting(setting, value, module=None): +class UsingURLPatterns(object): """ - Temporarily change value of setting for test. + Isolates URL patterns used during testing on the test class itself. + For example: - Optionally reload given module, useful when module uses value of setting on - import. - """ - original_value = getattr(api_settings, setting) - setattr(api_settings, setting, value) - - if module is not None: - six.moves.reload_module(module) + class MyTestCase(UsingURLPatterns, TestCase): + urlpatterns = [ + ... + ] - yield + def test_something(self): + ... + """ + urls = __name__ - setattr(api_settings, setting, original_value) + def setUp(self): + global urlpatterns + urlpatterns = self.urlpatterns - if module is not None: - six.moves.reload_module(module) + def tearDown(self): + global urlpatterns + urlpatterns = [] class MockObject(object): |
