diff options
| author | Tom Christie | 2015-02-06 14:35:06 +0000 | 
|---|---|---|
| committer | Tom Christie | 2015-02-06 14:35:06 +0000 | 
| commit | 3dff9a4fe2952cf632ca7f4cd9ecf4221059ca91 (patch) | |
| tree | 0649d42b20b875e97cb551b987644b61e7860e84 /tests | |
| parent | c06a82d0531f4cb290baacee196829c770913eaa (diff) | |
| parent | 1f996128458570a909d13f15c3d739fb12111984 (diff) | |
| download | django-rest-framework-model-serializer-caching.tar.bz2 | |
Resolve merge conflictmodel-serializer-caching
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/browsable_api/auth_urls.py | 1 | ||||
| -rw-r--r-- | tests/test_fields.py | 129 | ||||
| -rw-r--r-- | tests/test_filters.py | 17 | ||||
| -rw-r--r-- | tests/test_generics.py | 8 | ||||
| -rw-r--r-- | tests/test_htmlrenderer.py | 6 | ||||
| -rw-r--r-- | tests/test_metadata.py | 60 | ||||
| -rw-r--r-- | tests/test_model_serializer.py | 54 | ||||
| -rw-r--r-- | tests/test_multitable_inheritance.py | 4 | ||||
| -rw-r--r-- | tests/test_pagination.py | 1048 | ||||
| -rw-r--r-- | tests/test_parsers.py | 4 | ||||
| -rw-r--r-- | tests/test_relations.py | 35 | ||||
| -rw-r--r-- | tests/test_relations_hyperlink.py | 7 | ||||
| -rw-r--r-- | tests/test_renderers.py | 107 | ||||
| -rw-r--r-- | tests/test_routers.py | 151 | ||||
| -rw-r--r-- | tests/test_serializer.py | 79 | ||||
| -rw-r--r-- | tests/test_serializer_bulk_update.py | 4 | ||||
| -rw-r--r-- | tests/test_versioning.py | 62 | ||||
| -rw-r--r-- | tests/utils.py | 35 | 
18 files changed, 1115 insertions, 696 deletions
| diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py index bce7dcf9..97bc1036 100644 --- a/tests/browsable_api/auth_urls.py +++ b/tests/browsable_api/auth_urls.py @@ -3,6 +3,7 @@ from django.conf.urls import patterns, url, include  from .views import MockView +  urlpatterns = patterns(      '',      (r'^$', MockView.as_view()), diff --git a/tests/test_fields.py b/tests/test_fields.py index 04c721d3..48ada780 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -4,6 +4,7 @@ from rest_framework import serializers  import datetime  import django  import pytest +import uuid  # Tests for field keyword arguments and core functionality. @@ -223,8 +224,8 @@ class MockHTMLDict(dict):      getlist = None -class TestCharHTMLInput: -    def test_empty_html_checkbox(self): +class TestHTMLInput: +    def test_empty_html_charfield(self):          class TestSerializer(serializers.Serializer):              message = serializers.CharField(default='happy') @@ -232,23 +233,31 @@ class TestCharHTMLInput:          assert serializer.is_valid()          assert serializer.validated_data == {'message': 'happy'} -    def test_empty_html_checkbox_allow_null(self): +    def test_empty_html_charfield_allow_null(self):          class TestSerializer(serializers.Serializer):              message = serializers.CharField(allow_null=True) -        serializer = TestSerializer(data=MockHTMLDict()) +        serializer = TestSerializer(data=MockHTMLDict({'message': ''}))          assert serializer.is_valid()          assert serializer.validated_data == {'message': None} -    def test_empty_html_checkbox_allow_null_allow_blank(self): +    def test_empty_html_datefield_allow_null(self): +        class TestSerializer(serializers.Serializer): +            expiry = serializers.DateField(allow_null=True) + +        serializer = TestSerializer(data=MockHTMLDict({'expiry': ''})) +        assert serializer.is_valid() +        assert serializer.validated_data == {'expiry': None} + +    def test_empty_html_charfield_allow_null_allow_blank(self):          class TestSerializer(serializers.Serializer):              message = serializers.CharField(allow_null=True, allow_blank=True) -        serializer = TestSerializer(data=MockHTMLDict({})) +        serializer = TestSerializer(data=MockHTMLDict({'message': ''}))          assert serializer.is_valid()          assert serializer.validated_data == {'message': ''} -    def test_empty_html_required_false(self): +    def test_empty_html_charfield_required_false(self):          class TestSerializer(serializers.Serializer):              message = serializers.CharField(required=False) @@ -338,7 +347,7 @@ class TestBooleanField(FieldValues):          False: False,      }      invalid_inputs = { -        'foo': ['`foo` is not a valid boolean.'], +        'foo': ['"foo" is not a valid boolean.'],          None: ['This field may not be null.']      }      outputs = { @@ -368,7 +377,7 @@ class TestNullBooleanField(FieldValues):          None: None      }      invalid_inputs = { -        'foo': ['`foo` is not a valid boolean.'], +        'foo': ['"foo" is not a valid boolean.'],      }      outputs = {          'true': True, @@ -439,7 +448,7 @@ class TestSlugField(FieldValues):          'slug-99': 'slug-99',      }      invalid_inputs = { -        'slug 99': ["Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."] +        'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.']      }      outputs = {}      field = serializers.SlugField() @@ -459,6 +468,23 @@ class TestURLField(FieldValues):      field = serializers.URLField() +class TestUUIDField(FieldValues): +    """ +    Valid and invalid values for `UUIDField`. +    """ +    valid_inputs = { +        '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), +        '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda') +    } +    invalid_inputs = { +        '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.'] +    } +    outputs = { +        uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda' +    } +    field = serializers.UUIDField() + +  # Number types...  class TestIntegerField(FieldValues): @@ -640,8 +666,8 @@ class TestDateField(FieldValues):          datetime.date(2001, 1, 1): datetime.date(2001, 1, 1),      }      invalid_inputs = { -        'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], -        '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]]'], +        'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'], +        '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],          datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'],      }      outputs = { @@ -658,7 +684,7 @@ class TestCustomInputFormatDateField(FieldValues):          '1 Jan 2001': datetime.date(2001, 1, 1),      }      invalid_inputs = { -        '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY'] +        '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.']      }      outputs = {}      field = serializers.DateField(input_formats=['%d %b %Y']) @@ -702,8 +728,8 @@ class TestDateTimeField(FieldValues):          '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())      }      invalid_inputs = { -        'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], -        '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'], +        'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'], +        '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],          datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'],      }      outputs = { @@ -721,7 +747,7 @@ class TestCustomInputFormatDateTimeField(FieldValues):          '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()),      }      invalid_inputs = { -        '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY'] +        '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.']      }      outputs = {}      field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y']) @@ -773,8 +799,8 @@ class TestTimeField(FieldValues):          datetime.time(13, 00): datetime.time(13, 00),      }      invalid_inputs = { -        'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], -        '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]'], +        'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'], +        '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],      }      outputs = {          datetime.time(13, 00): '13:00:00' @@ -790,7 +816,7 @@ class TestCustomInputFormatTimeField(FieldValues):          '1:00pm': datetime.time(13, 00),      }      invalid_inputs = { -        '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]'], +        '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'],      }      outputs = {}      field = serializers.TimeField(input_formats=['%I:%M%p']) @@ -832,7 +858,7 @@ class TestChoiceField(FieldValues):          'good': 'good',      }      invalid_inputs = { -        'amazing': ['`amazing` is not a valid choice.'] +        'amazing': ['"amazing" is not a valid choice.']      }      outputs = {          'good': 'good', @@ -872,8 +898,8 @@ class TestChoiceFieldWithType(FieldValues):          3: 3,      }      invalid_inputs = { -        5: ['`5` is not a valid choice.'], -        'abc': ['`abc` is not a valid choice.'] +        5: ['"5" is not a valid choice.'], +        'abc': ['"abc" is not a valid choice.']      }      outputs = {          '1': 1, @@ -899,7 +925,7 @@ class TestChoiceFieldWithListChoices(FieldValues):          'good': 'good',      }      invalid_inputs = { -        'awful': ['`awful` is not a valid choice.'] +        'awful': ['"awful" is not a valid choice.']      }      outputs = {          'good': 'good' @@ -917,8 +943,8 @@ class TestMultipleChoiceField(FieldValues):          ('aircon', 'manual'): set(['aircon', 'manual']),      }      invalid_inputs = { -        'abc': ['Expected a list of items but got type `str`.'], -        ('aircon', 'incorrect'): ['`incorrect` is not a valid choice.'] +        'abc': ['Expected a list of items but got type "str".'], +        ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.']      }      outputs = [          (['aircon', 'manual'], set(['aircon', 'manual'])) @@ -1021,14 +1047,14 @@ class TestValidImageField(FieldValues):  class TestListField(FieldValues):      """ -    Values for `ListField`. +    Values for `ListField` with IntegerField as child.      """      valid_inputs = [          ([1, 2, 3], [1, 2, 3]),          (['1', '2', '3'], [1, 2, 3])      ]      invalid_inputs = [ -        ('not a list', ['Expected a list of items but got type `str`']), +        ('not a list', ['Expected a list of items but got type "str".']),          ([1, 2, 'error'], ['A valid integer is required.'])      ]      outputs = [ @@ -1038,6 +1064,55 @@ class TestListField(FieldValues):      field = serializers.ListField(child=serializers.IntegerField()) +class TestUnvalidatedListField(FieldValues): +    """ +    Values for `ListField` with no `child` argument. +    """ +    valid_inputs = [ +        ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), +    ] +    invalid_inputs = [ +        ('not a list', ['Expected a list of items but got type "str".']), +    ] +    outputs = [ +        ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]), +    ] +    field = serializers.ListField() + + +class TestDictField(FieldValues): +    """ +    Values for `ListField` with CharField as child. +    """ +    valid_inputs = [ +        ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), +    ] +    invalid_inputs = [ +        ({'a': 1, 'b': None}, ['This field may not be null.']), +        ('not a dict', ['Expected a dictionary of items but got type "str".']), +    ] +    outputs = [ +        ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), +    ] +    field = serializers.DictField(child=serializers.CharField()) + + +class TestUnvalidatedDictField(FieldValues): +    """ +    Values for `ListField` with no `child` argument. +    """ +    valid_inputs = [ +        ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}), +    ] +    invalid_inputs = [ +        ('not a dict', ['Expected a dictionary of items but got type "str".']), +    ] +    outputs = [ +        ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}), +    ] +    field = serializers.DictField() + +  # Tests for FieldField.  # --------------------- diff --git a/tests/test_filters.py b/tests/test_filters.py index dc84dcbd..355f02ce 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,13 +5,15 @@ from django.db import models  from django.conf.urls import patterns, url  from django.core.urlresolvers import reverse  from django.test import TestCase +from django.test.utils import override_settings  from django.utils import unittest  from django.utils.dateparse import parse_date +from django.utils.six.moves import reload_module  from rest_framework import generics, serializers, status, filters  from rest_framework.compat import django_filters  from rest_framework.test import APIRequestFactory  from .models import BaseFilterableItem, FilterableItem, BasicModel -from .utils import temporary_setting +  factory = APIRequestFactory() @@ -404,7 +406,9 @@ class SearchFilterTests(TestCase):          )      def test_search_with_nonstandard_search_param(self): -        with temporary_setting('SEARCH_PARAM', 'query', module=filters): +        with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): +            reload_module(filters) +              class SearchListView(generics.ListAPIView):                  queryset = SearchFilterModel.objects.all()                  serializer_class = SearchFilterSerializer @@ -422,6 +426,8 @@ class SearchFilterTests(TestCase):                  ]              ) +        reload_module(filters) +  class OrderingFilterModel(models.Model):      title = models.CharField(max_length=20) @@ -467,6 +473,7 @@ class DjangoFilterOrderingTests(TestCase):          for d in data:              DjangoFilterOrderingModel.objects.create(**d) +    @unittest.skipUnless(django_filters, 'django-filter not installed')      def test_default_ordering(self):          class DjangoFilterOrderingView(generics.ListAPIView):              serializer_class = DjangoFilterOrderingSerializer @@ -641,7 +648,9 @@ class OrderingFilterTests(TestCase):          )      def test_ordering_with_nonstandard_ordering_param(self): -        with temporary_setting('ORDERING_PARAM', 'order', filters): +        with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): +            reload_module(filters) +              class OrderingListView(generics.ListAPIView):                  queryset = OrderingFilterModel.objects.all()                  serializer_class = OrderingFilterSerializer @@ -661,6 +670,8 @@ class OrderingFilterTests(TestCase):                  ]              ) +        reload_module(filters) +  class SensitiveOrderingFilterModel(models.Model):      username = models.CharField(max_length=20) diff --git a/tests/test_generics.py b/tests/test_generics.py index 94023c30..88e792ce 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -117,7 +117,7 @@ class TestRootView(TestCase):          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'PUT' not allowed."}) +        self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'})      def test_delete_root_view(self):          """ @@ -127,7 +127,7 @@ class TestRootView(TestCase):          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'DELETE' not allowed."}) +        self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'})      def test_post_cannot_set_id(self):          """ @@ -181,7 +181,7 @@ class TestInstanceView(TestCase):          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) -        self.assertEqual(response.data, {"detail": "Method 'POST' not allowed."}) +        self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'})      def test_put_instance_view(self):          """ @@ -483,7 +483,7 @@ class TestFilterBackendAppliedToViews(TestCase):          request = factory.get('/1')          response = instance_view(request, pk=1).render()          self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) -        self.assertEqual(response.data, {'detail': 'Not found'}) +        self.assertEqual(response.data, {'detail': 'Not found.'})      def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):          """ diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py index 2edc6b4b..a33b832f 100644 --- a/tests/test_htmlrenderer.py +++ b/tests/test_htmlrenderer.py @@ -56,7 +56,13 @@ class TemplateHTMLRendererTests(TestCase):                  return Template("example: {{ object }}")              raise TemplateDoesNotExist(template_name) +        def select_template(template_name_list, dirs=None, using=None): +            if template_name_list == ['example.html']: +                return Template("example: {{ object }}") +            raise TemplateDoesNotExist(template_name_list[0]) +          django.template.loader.get_template = get_template +        django.template.loader.select_template = select_template      def tearDown(self):          """ diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 5ff59c72..5031c0f3 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,9 +1,8 @@  from __future__ import unicode_literals - -from rest_framework import exceptions, serializers, views +from rest_framework import exceptions, serializers, status, views, versioning  from rest_framework.request import Request +from rest_framework.renderers import BrowsableAPIRenderer  from rest_framework.test import APIRequestFactory -import pytest  request = Request(APIRequestFactory().options('/')) @@ -17,7 +16,8 @@ class TestMetadata:              """Example view."""              pass -        response = ExampleView().options(request=request) +        view = ExampleView.as_view() +        response = view(request=request)          expected = {              'name': 'Example',              'description': 'Example view.', @@ -31,7 +31,7 @@ class TestMetadata:                  'multipart/form-data'              ]          } -        assert response.status_code == 200 +        assert response.status_code == status.HTTP_200_OK          assert response.data == expected      def test_none_metadata(self): @@ -42,8 +42,10 @@ class TestMetadata:          class ExampleView(views.APIView):              metadata_class = None -        with pytest.raises(exceptions.MethodNotAllowed): -            ExampleView().options(request=request) +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED +        assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}      def test_actions(self):          """ @@ -63,7 +65,8 @@ class TestMetadata:              def get_serializer(self):                  return ExampleSerializer() -        response = ExampleView().options(request=request) +        view = ExampleView.as_view() +        response = view(request=request)          expected = {              'name': 'Example',              'description': 'Example view.', @@ -104,7 +107,7 @@ class TestMetadata:                  }              }          } -        assert response.status_code == 200 +        assert response.status_code == status.HTTP_200_OK          assert response.data == expected      def test_global_permissions(self): @@ -132,8 +135,9 @@ class TestMetadata:                  if request.method == 'POST':                      raise exceptions.PermissionDenied() -        response = ExampleView().options(request=request) -        assert response.status_code == 200 +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_200_OK          assert list(response.data['actions'].keys()) == ['PUT']      def test_object_permissions(self): @@ -161,6 +165,36 @@ class TestMetadata:                  if self.request.method == 'PUT':                      raise exceptions.PermissionDenied() -        response = ExampleView().options(request=request) -        assert response.status_code == 200 +        view = ExampleView.as_view() +        response = view(request=request) +        assert response.status_code == status.HTTP_200_OK          assert list(response.data['actions'].keys()) == ['POST'] + +    def test_bug_2455_clone_request(self): +        class ExampleView(views.APIView): +            renderer_classes = (BrowsableAPIRenderer,) + +            def post(self, request): +                pass + +            def get_serializer(self): +                assert hasattr(self.request, 'version') +                return serializers.Serializer() + +        view = ExampleView.as_view() +        view(request=request) + +    def test_bug_2477_clone_request(self): +        class ExampleView(views.APIView): +            renderer_classes = (BrowsableAPIRenderer,) + +            def post(self, request): +                pass + +            def get_serializer(self): +                assert hasattr(self.request, 'versioning_scheme') +                return serializers.Serializer() + +        scheme = versioning.QueryParameterVersioning +        view = ExampleView.as_view(versioning_class=scheme) +        view(request=request) diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 5c56c8db..bce2008a 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -5,11 +5,14 @@ shortcuts for automatically creating serializers based on a given model class.  These tests deal with ensuring that we correctly map the model fields onto  an appropriate set of serializer fields for each case.  """ +from __future__ import unicode_literals  from django.core.exceptions import ImproperlyConfigured  from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator  from django.db import models  from django.test import TestCase +from django.utils import six  from rest_framework import serializers +from rest_framework.compat import unicode_repr  def dedent(blocktext): @@ -119,12 +122,12 @@ class TestRegularFieldMappings(TestCase):                  positive_small_integer_field = IntegerField()                  slug_field = SlugField(max_length=100)                  small_integer_field = IntegerField() -                text_field = CharField(style={'type': 'textarea'}) +                text_field = CharField(style={'base_template': 'textarea.html'})                  time_field = TimeField()                  url_field = URLField(max_length=100)                  custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>)          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_field_options(self):          class TestSerializer(serializers.ModelSerializer): @@ -142,7 +145,14 @@ class TestRegularFieldMappings(TestCase):                  descriptive_field = IntegerField(help_text='Some help text', label='A label')                  choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')])          """) -        self.assertEqual(repr(TestSerializer()), expected) +        if six.PY2: +            # This particular case is too awkward to resolve fully across +            # both py2 and py3. +            expected = expected.replace( +                "('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')", +                "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')" +            ) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_method_field(self):          """ @@ -206,7 +216,7 @@ class TestRegularFieldMappings(TestCase):          with self.assertRaises(ImproperlyConfigured) as excinfo:              TestSerializer().fields -        expected = 'Field name `invalid` is not valid for model `ModelBase`.' +        expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'          assert str(excinfo.exception) == expected      def test_missing_field(self): @@ -229,6 +239,26 @@ class TestRegularFieldMappings(TestCase):          )          assert str(excinfo.exception) == expected +    def test_missing_superclass_field(self): +        """ +        Fields that have been declared on a parent of the serializer class may +        be excluded from the `Meta.fields` option. +        """ +        class TestSerializer(serializers.ModelSerializer): +            missing = serializers.ReadOnlyField() + +            class Meta: +                model = RegularFieldsModel + +        class ChildSerializer(TestSerializer): +            missing = serializers.ReadOnlyField() + +            class Meta: +                model = RegularFieldsModel +                fields = ('auto_field',) + +        ChildSerializer().fields +  # Tests for relational field mappings.  # ------------------------------------ @@ -276,7 +306,7 @@ class TestRelationalFieldMappings(TestCase):                  many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all())                  through = PrimaryKeyRelatedField(many=True, read_only=True)          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_nested_relations(self):          class TestSerializer(serializers.ModelSerializer): @@ -300,7 +330,7 @@ class TestRelationalFieldMappings(TestCase):                      id = IntegerField(label='ID', read_only=True)                      name = CharField(max_length=100)          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_hyperlinked_relations(self):          class TestSerializer(serializers.HyperlinkedModelSerializer): @@ -315,7 +345,7 @@ class TestRelationalFieldMappings(TestCase):                  many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')                  through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_nested_hyperlinked_relations(self):          class TestSerializer(serializers.HyperlinkedModelSerializer): @@ -339,7 +369,7 @@ class TestRelationalFieldMappings(TestCase):                      url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')                      name = CharField(max_length=100)          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_pk_reverse_foreign_key(self):          class TestSerializer(serializers.ModelSerializer): @@ -353,7 +383,7 @@ class TestRelationalFieldMappings(TestCase):                  name = CharField(max_length=100)                  reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_pk_reverse_one_to_one(self):          class TestSerializer(serializers.ModelSerializer): @@ -367,7 +397,7 @@ class TestRelationalFieldMappings(TestCase):                  name = CharField(max_length=100)                  reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_pk_reverse_many_to_many(self):          class TestSerializer(serializers.ModelSerializer): @@ -381,7 +411,7 @@ class TestRelationalFieldMappings(TestCase):                  name = CharField(max_length=100)                  reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)      def test_pk_reverse_through(self):          class TestSerializer(serializers.ModelSerializer): @@ -395,7 +425,7 @@ class TestRelationalFieldMappings(TestCase):                  name = CharField(max_length=100)                  reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)          """) -        self.assertEqual(repr(TestSerializer()), expected) +        self.assertEqual(unicode_repr(TestSerializer()), expected)  class TestIntegration(TestCase): diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py index e1b40cc7..15627e1d 100644 --- a/tests/test_multitable_inheritance.py +++ b/tests/test_multitable_inheritance.py @@ -48,8 +48,8 @@ class InheritedModelSerializationTests(TestCase):          Assert that a model with a onetoone field that is the primary key is          not treated like a derived model          """ -        parent = ParentModel(name1='parent name') -        associate = AssociatedModel(name='hello', ref=parent) +        parent = ParentModel.objects.create(name1='parent name') +        associate = AssociatedModel.objects.create(name='hello', ref=parent)          serializer = AssociatedModelSerializer(associate)          self.assertEqual(set(serializer.data.keys()),                           set(['name', 'ref'])) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 1fd9cf9c..13bfb627 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,553 +1,671 @@ +# coding: utf-8  from __future__ import unicode_literals -import datetime -from decimal import Decimal -from django.core.paginator import Paginator -from django.test import TestCase -from django.utils import unittest -from rest_framework import generics, serializers, status, pagination, filters -from rest_framework.compat import django_filters +from rest_framework import exceptions, generics, pagination, serializers, status, filters +from rest_framework.request import Request +from rest_framework.pagination import PageLink, PAGE_BREAK  from rest_framework.test import APIRequestFactory -from .models import BasicModel, FilterableItem +import pytest  factory = APIRequestFactory() -# Helper function to split arguments out of an url -def split_arguments_from_url(url): -    if '?' not in url: -        return url +class TestPaginationIntegration: +    """ +    Integration tests. +    """ -    path, args = url.split('?') -    args = dict(r.split('=') for r in args.split('&')) -    return path, args +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item +        class EvenItemsOnly(filters.BaseFilterBackend): +            def filter_queryset(self, request, queryset, view): +                return [item for item in queryset if item % 2 == 0] + +        class BasicPagination(pagination.PageNumberPagination): +            paginate_by = 5 +            paginate_by_param = 'page_size' +            max_paginate_by = 20 + +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            filter_backends=[EvenItemsOnly], +            pagination_class=BasicPagination +        ) -class BasicSerializer(serializers.ModelSerializer): -    class Meta: -        model = BasicModel +    def test_filtered_items_are_paginated(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 50 +        } +    def test_setting_page_size(self): +        """ +        When 'paginate_by_param' is set, the client may choose a page size. +        """ +        request = factory.get('/', {'page_size': 10}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=10', +            'count': 50 +        } -class FilterableItemSerializer(serializers.ModelSerializer): -    class Meta: -        model = FilterableItem +    def test_setting_page_size_over_maximum(self): +        """ +        When page_size parameter exceeds maxiumum allowable, +        then it should be capped to the maxiumum. +        """ +        request = factory.get('/', {'page_size': 1000}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                2, 4, 6, 8, 10, 12, 14, 16, 18, 20, +                22, 24, 26, 28, 30, 32, 34, 36, 38, 40 +            ], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=1000', +            'count': 50 +        } +    def test_setting_page_size_to_zero(self): +        """ +        When page_size parameter is invalid it should return to the default. +        """ +        request = factory.get('/', {'page_size': 0}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [2, 4, 6, 8, 10], +            'previous': None, +            'next': 'http://testserver/?page=2&page_size=0', +            'count': 50 +        } -class RootView(generics.ListCreateAPIView): -    """ -    Example description for OPTIONS. -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by = 10 +    def test_additional_query_params_are_preserved(self): +        request = factory.get('/', {'page': 2, 'filter': 'even'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [12, 14, 16, 18, 20], +            'previous': 'http://testserver/?filter=even', +            'next': 'http://testserver/?filter=even&page=3', +            'count': 50 +        } +    def test_404_not_found_for_zero_page(self): +        request = factory.get('/', {'page': '0'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "0": That page number is less than 1.' +        } -class DefaultPageSizeKwargView(generics.ListAPIView): -    """ -    View for testing default paginate_by_param usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer +    def test_404_not_found_for_invalid_page(self): +        request = factory.get('/', {'page': 'invalid'}) +        response = self.view(request) +        assert response.status_code == status.HTTP_404_NOT_FOUND +        assert response.data == { +            'detail': 'Invalid page "invalid": That page number is not an integer.' +        } -class PaginateByParamView(generics.ListAPIView): +class TestPaginationDisabledIntegration:      """ -    View for testing custom paginate_by_param usage +    Integration tests for disabled pagination.      """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by_param = 'page_size' +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -class MaxPaginateByView(generics.ListAPIView): -    """ -    View for testing custom max_paginate_by usage -    """ -    queryset = BasicModel.objects.all() -    serializer_class = BasicSerializer -    paginate_by = 3 -    max_paginate_by = 5 -    paginate_by_param = 'page_size' +        self.view = generics.ListAPIView.as_view( +            serializer_class=PassThroughSerializer, +            queryset=range(1, 101), +            pagination_class=None +        ) + +    def test_unpaginated_list(self): +        request = factory.get('/', {'page': 2}) +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == list(range(1, 101)) -class IntegrationTestPagination(TestCase): +class TestDeprecatedStylePagination:      """ -    Integration tests for paginated list views. +    Integration tests for deprecated style of setting pagination +    attributes on the view.      """ -    def setUp(self): -        """ -        Create 26 BasicModel instances. -        """ -        for char in 'abcdefghijklmnopqrstuvwxyz': -            BasicModel(text=char * 3).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = RootView.as_view() - -    def test_get_paginated_root_view(self): -        """ -        GET requests to paginated ListCreateAPIView should return paginated results. -        """ -        request = factory.get('/') -        # Note: Database queries are a `SELECT COUNT`, and `SELECT <fields>` -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[10:20]) -        self.assertNotEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = self.view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 26) -        self.assertEqual(response.data['results'], self.data[20:]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - - -class IntegrationTestPaginationAndFiltering(TestCase): - -    def setUp(self): -        """ -        Create 50 FilterableItem instances. -        """ -        base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) -        for i in range(26): -            text = chr(i + ord(base_data[0])) * 3  # Produces string 'aaa', 'bbb', etc. -            decimal = base_data[1] + i -            date = base_data[2] - datetime.timedelta(days=i * 2) -            FilterableItem(text=text, decimal=decimal, date=date).save() - -        self.objects = FilterableItem.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()} -            for obj in self.objects.all() -        ] - -    @unittest.skipUnless(django_filters, 'django-filter not installed') -    def test_get_django_filter_paginated_filtered_root_view(self): -        """ -        GET requests to paginated filtered ListCreateAPIView should return -        paginated results. The next and previous links should preserve the -        filtered parameters. -        """ -        class DecimalFilter(django_filters.FilterSet): -            decimal = django_filters.NumberFilter(lookup_type='lt') - -            class Meta: -                model = FilterableItem -                fields = ['text', 'decimal', 'date'] - -        class FilterFieldsRootView(generics.ListCreateAPIView): -            queryset = FilterableItem.objects.all() -            serializer_class = FilterableItemSerializer -            paginate_by = 10 -            filter_class = DecimalFilter -            filter_backends = (filters.DjangoFilterBackend,) - -        view = FilterFieldsRootView.as_view() - -        EXPECTED_NUM_QUERIES = 2 - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(EXPECTED_NUM_QUERIES): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -    def test_get_basic_paginated_filtered_root_view(self): -        """ -        Same as `test_get_django_filter_paginated_filtered_root_view`, -        except using a custom filter backend instead of the django-filter -        backend, -        """ +    def setup(self): +        class PassThroughSerializer(serializers.BaseSerializer): +            def to_representation(self, item): +                return item -        class DecimalFilterBackend(filters.BaseFilterBackend): -            def filter_queryset(self, request, queryset, view): -                return queryset.filter(decimal__lt=Decimal(request.GET['decimal'])) - -        class BasicFilterFieldsRootView(generics.ListCreateAPIView): -            queryset = FilterableItem.objects.all() -            serializer_class = FilterableItemSerializer -            paginate_by = 10 -            filter_backends = (DecimalFilterBackend,) - -        view = BasicFilterFieldsRootView.as_view() - -        request = factory.get('/', {'decimal': '15.20'}) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['next'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[10:15]) -        self.assertEqual(response.data['next'], None) -        self.assertNotEqual(response.data['previous'], None) - -        request = factory.get(*split_arguments_from_url(response.data['previous'])) -        with self.assertNumQueries(2): -            response = view(request).render() -        self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(response.data['count'], 15) -        self.assertEqual(response.data['results'], self.data[:10]) -        self.assertNotEqual(response.data['next'], None) -        self.assertEqual(response.data['previous'], None) - - -class PassOnContextPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = serializers.Serializer - - -class UnitTestPagination(TestCase): -    """ -    Unit tests for pagination of primitive objects. -    """ +        class ExampleView(generics.ListAPIView): +            serializer_class = PassThroughSerializer +            queryset = range(1, 101) +            pagination_class = pagination.PageNumberPagination +            paginate_by = 20 +            page_query_param = 'page_number' -    def setUp(self): -        self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz'] -        paginator = Paginator(self.objects, 10) -        self.first_page = paginator.page(1) -        self.last_page = paginator.page(3) - -    def test_native_pagination(self): -        serializer = pagination.PaginationSerializer(self.first_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], '?page=2') -        self.assertEqual(serializer.data['previous'], None) -        self.assertEqual(serializer.data['results'], self.objects[:10]) - -        serializer = pagination.PaginationSerializer(self.last_page) -        self.assertEqual(serializer.data['count'], 26) -        self.assertEqual(serializer.data['next'], None) -        self.assertEqual(serializer.data['previous'], '?page=2') -        self.assertEqual(serializer.data['results'], self.objects[20:]) - -    def test_context_available_in_result(self): -        """ -        Ensure context gets passed through to the object serializer. -        """ -        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) -        serializer.data -        results = serializer.fields[serializer.results_field] -        self.assertEqual(serializer.context, results.context) +        self.view = ExampleView.as_view() + +    def test_paginate_by_attribute_on_view(self): +        request = factory.get('/?page_number=2') +        response = self.view(request) +        assert response.status_code == status.HTTP_200_OK +        assert response.data == { +            'results': [ +                21, 22, 23, 24, 25, 26, 27, 28, 29, 30, +                31, 32, 33, 34, 35, 36, 37, 38, 39, 40 +            ], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page_number=3', +            'count': 100 +        } -class TestUnpaginated(TestCase): +class TestPageNumberPagination:      """ -    Tests for list views without pagination. +    Unit tests for `pagination.PageNumberPagination`.      """ -    def setUp(self): -        """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = DefaultPageSizeKwargView.as_view() - -    def test_unpaginated(self): -        """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data -        """ -        request = factory.get('/') -        response = self.view(request) -        self.assertEqual(response.data, self.data) +    def setup(self): +        class ExamplePagination(pagination.PageNumberPagination): +            paginate_by = 5 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_page_number(self): +        request = Request(factory.get('/')) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?page=2', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?page=2', +            'page_links': [ +                PageLink('http://testserver/', 1, True, False), +                PageLink('http://testserver/?page=2', 2, False, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_second_page(self): +        request = Request(factory.get('/', {'page': 2})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/', +            'next': 'http://testserver/?page=3', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/', +            'next_url': 'http://testserver/?page=3', +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PageLink('http://testserver/?page=2', 2, True, False), +                PageLink('http://testserver/?page=3', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=20', 20, False, False), +            ] +        } + +    def test_last_page(self): +        request = Request(factory.get('/', {'page': 'last'})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?page=19', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?page=19', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?page=18', 18, False, False), +                PageLink('http://testserver/?page=19', 19, False, False), +                PageLink('http://testserver/?page=20', 20, True, False), +            ] +        } + +    def test_invalid_page(self): +        request = Request(factory.get('/', {'page': 'invalid'})) +        with pytest.raises(exceptions.NotFound): +            self.paginate_queryset(request) -class TestCustomPaginateByParam(TestCase): +class TestLimitOffset:      """ -    Tests for list views with default page size kwarg +    Unit tests for `pagination.LimitOffsetPagination`.      """ -    def setUp(self): +    def setup(self): +        class ExamplePagination(pagination.LimitOffsetPagination): +            default_limit = 10 +        self.pagination = ExamplePagination() +        self.queryset = range(1, 101) + +    def paginate_queryset(self, request): +        return list(self.pagination.paginate_queryset(self.queryset, request)) + +    def get_paginated_content(self, queryset): +        response = self.pagination.get_paginated_response(queryset) +        return response.data + +    def get_html_context(self): +        return self.pagination.get_html_context() + +    def test_no_offset(self): +        request = Request(factory.get('/', {'limit': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [1, 2, 3, 4, 5] +        assert content == { +            'results': [1, 2, 3, 4, 5], +            'previous': None, +            'next': 'http://testserver/?limit=5&offset=5', +            'count': 100 +        } +        assert context == { +            'previous_url': None, +            'next_url': 'http://testserver/?limit=5&offset=5', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, True, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } +        assert self.pagination.display_page_controls +        assert isinstance(self.pagination.to_html(), type('')) + +    def test_single_offset(self):          """ -        Create 13 BasicModel instances. +        When the offset is not a multiple of the limit we get some edge cases: +        * The first page should still be offset zero. +        * We may end up displaying an extra page in the pagination control.          """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = PaginateByParamView.as_view() - -    def test_default_page_size(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 1})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [2, 3, 4, 5, 6] +        assert content == { +            'results': [2, 3, 4, 5, 6], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=6', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=6', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=1', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=6', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=96', 21, False, False), +            ] +        } + +    def test_first_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 5})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [6, 7, 8, 9, 10] +        assert content == { +            'results': [6, 7, 8, 9, 10], +            'previous': 'http://testserver/?limit=5', +            'next': 'http://testserver/?limit=5&offset=10', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5', +            'next_url': 'http://testserver/?limit=5&offset=10', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, True, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } + +    def test_middle_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 10})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [11, 12, 13, 14, 15] +        assert content == { +            'results': [11, 12, 13, 14, 15], +            'previous': 'http://testserver/?limit=5&offset=5', +            'next': 'http://testserver/?limit=5&offset=15', +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=5', +            'next_url': 'http://testserver/?limit=5&offset=15', +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PageLink('http://testserver/?limit=5&offset=5', 2, False, False), +                PageLink('http://testserver/?limit=5&offset=10', 3, True, False), +                PageLink('http://testserver/?limit=5&offset=15', 4, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=95', 20, False, False), +            ] +        } + +    def test_ending_offset(self): +        request = Request(factory.get('/', {'limit': 5, 'offset': 95})) +        queryset = self.paginate_queryset(request) +        content = self.get_paginated_content(queryset) +        context = self.get_html_context() +        assert queryset == [96, 97, 98, 99, 100] +        assert content == { +            'results': [96, 97, 98, 99, 100], +            'previous': 'http://testserver/?limit=5&offset=90', +            'next': None, +            'count': 100 +        } +        assert context == { +            'previous_url': 'http://testserver/?limit=5&offset=90', +            'next_url': None, +            'page_links': [ +                PageLink('http://testserver/?limit=5', 1, False, False), +                PAGE_BREAK, +                PageLink('http://testserver/?limit=5&offset=85', 18, False, False), +                PageLink('http://testserver/?limit=5&offset=90', 19, False, False), +                PageLink('http://testserver/?limit=5&offset=95', 20, True, False), +            ] +        } + +    def test_invalid_offset(self):          """ -        Tests the default page size for this view. -        no page size --> no limit --> no meta data +        An invalid offset query param should be treated as 0.          """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data, self.data) +        request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5] -    def test_paginate_by_param(self): +    def test_invalid_limit(self):          """ -        If paginate_by_param is set, the new kwarg should limit per view requests. +        An invalid limit query param should be ignored in favor of the default.          """ -        request = factory.get('/', {'page_size': 5}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) +        request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) +        queryset = self.paginate_queryset(request) +        assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] -class TestMaxPaginateByParam(TestCase): +class TestCursorPagination:      """ -    Tests for list views with max_paginate_by kwarg +    Unit tests for `pagination.CursorPagination`.      """ -    def setUp(self): +    def setup(self): +        class MockObject(object): +            def __init__(self, idx): +                self.created = idx + +        class MockQuerySet(object): +            def __init__(self, items): +                self.items = items + +            def filter(self, created__gt=None, created__lt=None): +                if created__gt is not None: +                    return MockQuerySet([ +                        item for item in self.items +                        if item.created > int(created__gt) +                    ]) + +                assert created__lt is not None +                return MockQuerySet([ +                    item for item in self.items +                    if item.created < int(created__lt) +                ]) + +            def order_by(self, *ordering): +                if ordering[0].startswith('-'): +                    return MockQuerySet(list(reversed(self.items))) +                return self + +            def __getitem__(self, sliced): +                return self.items[sliced] + +        class ExamplePagination(pagination.CursorPagination): +            page_size = 5 +            ordering = 'created' + +        self.pagination = ExamplePagination() +        self.queryset = MockQuerySet([ +            MockObject(idx) for idx in [ +                1, 1, 1, 1, 1, +                1, 2, 3, 4, 4, +                4, 4, 5, 6, 7, +                7, 7, 7, 7, 7, +                7, 7, 7, 8, 9, +                9, 9, 9, 9, 9 +            ] +        ]) + +    def get_pages(self, url):          """ -        Create 13 BasicModel instances. -        """ -        for i in range(13): -            BasicModel(text=i).save() -        self.objects = BasicModel.objects -        self.data = [ -            {'id': obj.id, 'text': obj.text} -            for obj in self.objects.all() -        ] -        self.view = MaxPaginateByView.as_view() - -    def test_max_paginate_by(self): -        """ -        If max_paginate_by is set, it should limit page size for the view. -        """ -        request = factory.get('/', data={'page_size': 10}) -        response = self.view(request).render() -        self.assertEqual(response.data['count'], 13) -        self.assertEqual(response.data['results'], self.data[:5]) +        Given a URL return a tuple of: -    def test_max_paginate_by_without_page_size_param(self): +        (previous page, current page, next page, previous url, next url)          """ -        If max_paginate_by is set, but client does not specifiy page_size, -        standard `paginate_by` behavior should be used. -        """ -        request = factory.get('/') -        response = self.view(request).render() -        self.assertEqual(response.data['results'], self.data[:3]) - - -# Tests for context in pagination serializers +        request = Request(factory.get(url)) +        queryset = self.pagination.paginate_queryset(self.queryset, request) +        current = [item.created for item in queryset] -class CustomField(serializers.ReadOnlyField): -    def to_native(self, value): -        if 'view' not in self.context: -            raise RuntimeError("context isn't getting passed into custom field") -        return "value" +        next_url = self.pagination.get_next_link() +        previous_url = self.pagination.get_previous_link() +        if next_url is not None: +            request = Request(factory.get(next_url)) +            queryset = self.pagination.paginate_queryset(self.queryset, request) +            next = [item.created for item in queryset] +        else: +            next = None -class BasicModelSerializer(serializers.Serializer): -    text = CustomField() - -    def to_native(self, value): -        if 'view' not in self.context: -            raise RuntimeError("context isn't getting passed into serializer") -        return super(BasicSerializer, self).to_native(value) +        if previous_url is not None: +            request = Request(factory.get(previous_url)) +            queryset = self.pagination.paginate_queryset(self.queryset, request) +            previous = [item.created for item in queryset] +        else: +            previous = None +        return (previous, current, next, previous_url, next_url) -class TestContextPassedToCustomField(TestCase): -    def setUp(self): -        BasicModel.objects.create(text='ala ma kota') +    def test_invalid_cursor(self): +        request = Request(factory.get('/', {'cursor': '123'})) +        with pytest.raises(exceptions.NotFound): +            self.pagination.paginate_queryset(self.queryset, request) -    def test_with_pagination(self): -        class ListView(generics.ListCreateAPIView): -            queryset = BasicModel.objects.all() -            serializer_class = BasicModelSerializer -            paginate_by = 1 +    def test_use_with_ordering_filter(self): +        class MockView: +            filter_backends = (filters.OrderingFilter,) +            ordering_fields = ['username', 'created'] +            ordering = 'created' -        self.view = ListView.as_view() -        request = factory.get('/') -        response = self.view(request).render() +        request = Request(factory.get('/', {'ordering': 'username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('username',) -        self.assertEqual(response.status_code, status.HTTP_200_OK) +        request = Request(factory.get('/', {'ordering': '-username'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('-username',) +        request = Request(factory.get('/', {'ordering': 'invalid'})) +        ordering = self.pagination.get_ordering(request, [], MockView()) +        assert ordering == ('created',) -# Tests for custom pagination serializers +    def test_cursor_pagination(self): +        (previous, current, next, previous_url, next_url) = self.get_pages('/') -class LinksSerializer(serializers.Serializer): -    next = pagination.NextPageField(source='*') -    prev = pagination.PreviousPageField(source='*') +        assert previous is None +        assert current == [1, 1, 1, 1, 1] +        assert next == [1, 2, 3, 4, 4] +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) -class CustomPaginationSerializer(pagination.BasePaginationSerializer): -    links = LinksSerializer(source='*')  # Takes the page object as the source -    total_results = serializers.ReadOnlyField(source='paginator.count') +        assert previous == [1, 1, 1, 1, 1] +        assert current == [1, 2, 3, 4, 4] +        assert next == [4, 4, 5, 6, 7] -    results_field = 'objects' +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) +        assert previous == [1, 2, 3, 4, 4] +        assert current == [4, 4, 5, 6, 7] +        assert next == [7, 7, 7, 7, 7] -class CustomFooSerializer(serializers.Serializer): -    foo = serializers.CharField() +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) +        assert previous == [4, 4, 4, 5, 6]  # Paging artifact +        assert current == [7, 7, 7, 7, 7] +        assert next == [7, 7, 7, 8, 9] -class CustomFooPaginationSerializer(pagination.PaginationSerializer): -    class Meta: -        object_serializer_class = CustomFooSerializer +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) +        assert previous == [7, 7, 7, 7, 7] +        assert current == [7, 7, 7, 8, 9] +        assert next == [9, 9, 9, 9, 9] -class TestCustomPaginationSerializer(TestCase): -    def setUp(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = Paginator(objects, 2) -        self.page = paginator.page(1) +        (previous, current, next, previous_url, next_url) = self.get_pages(next_url) -    def test_custom_pagination_serializer(self): -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=self.page, -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page=2', -                'prev': None -            }, -            'total_results': 4, -            'objects': ['john', 'paul'] -        } -        self.assertEqual(serializer.data, expected) +        assert previous == [7, 7, 7, 8, 9] +        assert current == [9, 9, 9, 9, 9] +        assert next is None -    def test_custom_pagination_serializer_with_custom_object_serializer(self): -        objects = [ -            {'foo': 'bar'}, -            {'foo': 'spam'} -        ] -        paginator = Paginator(objects, 1) -        page = paginator.page(1) -        serializer = CustomFooPaginationSerializer(page) -        serializer.data +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) +        assert previous == [7, 7, 7, 7, 7] +        assert current == [7, 7, 7, 8, 9] +        assert next == [9, 9, 9, 9, 9] -class NonIntegerPage(object): +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -    def __init__(self, paginator, object_list, prev_token, token, next_token): -        self.paginator = paginator -        self.object_list = object_list -        self.prev_token = prev_token -        self.token = token -        self.next_token = next_token +        assert previous == [4, 4, 5, 6, 7] +        assert current == [7, 7, 7, 7, 7] +        assert next == [8, 9, 9, 9, 9]  # Paging artifact -    def has_next(self): -        return not not self.next_token +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -    def next_page_number(self): -        return self.next_token +        assert previous == [1, 2, 3, 4, 4] +        assert current == [4, 4, 5, 6, 7] +        assert next == [7, 7, 7, 7, 7] -    def has_previous(self): -        return not not self.prev_token +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -    def previous_page_number(self): -        return self.prev_token +        assert previous == [1, 1, 1, 1, 1] +        assert current == [1, 2, 3, 4, 4] +        assert next == [4, 4, 5, 6, 7] +        (previous, current, next, previous_url, next_url) = self.get_pages(previous_url) -class NonIntegerPaginator(object): +        assert previous is None +        assert current == [1, 1, 1, 1, 1] +        assert next == [1, 2, 3, 4, 4] -    def __init__(self, object_list, per_page): -        self.object_list = object_list -        self.per_page = per_page +        assert isinstance(self.pagination.to_html(), type('')) -    def count(self): -        # pretend like we don't know how many pages we have -        return None -    def page(self, token=None): -        if token: -            try: -                first = self.object_list.index(token) -            except ValueError: -                first = 0 -        else: -            first = 0 -        n = len(self.object_list) -        last = min(first + self.per_page, n) -        prev_token = self.object_list[last - (2 * self.per_page)] if first else None -        next_token = self.object_list[last] if last < n else None -        return NonIntegerPage(self, self.object_list[first:last], prev_token, token, next_token) - - -class TestNonIntegerPagination(TestCase): -    def test_custom_pagination_serializer(self): -        objects = ['john', 'paul', 'george', 'ringo'] -        paginator = NonIntegerPaginator(objects, 2) - -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page(), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': 'http://testserver/foobar?page={0}'.format(objects[2]), -                'prev': None -            }, -            'total_results': None, -            'objects': objects[:2] -        } -        self.assertEqual(serializer.data, expected) +def test_get_displayed_page_numbers(): +    """ +    Test our contextual page display function. -        request = APIRequestFactory().get('/foobar') -        serializer = CustomPaginationSerializer( -            instance=paginator.page('george'), -            context={'request': request} -        ) -        expected = { -            'links': { -                'next': None, -                'prev': 'http://testserver/foobar?page={0}'.format(objects[0]), -            }, -            'total_results': None, -            'objects': objects[2:] -        } -        self.assertEqual(serializer.data, expected) +    This determines which pages to display in a pagination control, +    given the current page and the last page. +    """ +    displayed_page_numbers = pagination._get_displayed_page_numbers + +    # At five pages or less, all pages are displayed, always. +    assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5] +    assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5] + +    # Between six and either pages we may have a single page break. +    assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6] +    assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6] +    assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6] +    assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6] + +    assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7] +    assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7] +    assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7] +    assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7] +    assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7] +    assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7] + +    assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8] +    assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8] +    assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8] +    assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8] +    assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8] +    assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8] +    assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8] + +    # At nine or more pages we may have two page breaks, one on each side. +    assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9] +    assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9] +    assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9] +    assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9] +    assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9] +    assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9] +    assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9] +    assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 54455cf6..8816065a 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -101,7 +101,9 @@ class TestFileUploadParser(TestCase):          self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8--ÀĥƦ.txt')          filename = parser.get_filename(self.stream, None, self.parser_context) -        self.assertEqual(filename, 'fallback.txt') +        # Malformed. Either None or 'fallback.txt' will be acceptable. +        # See also https://code.djangoproject.com/ticket/24209 +        self.assertIn(filename, ('fallback.txt', None))      def __replace_content_disposition(self, disposition):          self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition diff --git a/tests/test_relations.py b/tests/test_relations.py index 62353dc2..fbe176e2 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,6 +1,8 @@  from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset  from django.core.exceptions import ImproperlyConfigured +from django.utils.datastructures import MultiValueDict  from rest_framework import serializers +from rest_framework.fields import empty  from rest_framework.test import APISimpleTestCase  import pytest @@ -33,7 +35,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):          with pytest.raises(serializers.ValidationError) as excinfo:              self.field.to_internal_value(4)          msg = excinfo.value.detail[0] -        assert msg == "Invalid pk '4' - object does not exist." +        assert msg == 'Invalid pk "4" - object does not exist.'      def test_pk_related_lookup_invalid_type(self):          with pytest.raises(serializers.ValidationError) as excinfo: @@ -134,3 +136,34 @@ class TestSlugRelatedField(APISimpleTestCase):      def test_representation(self):          representation = self.field.to_representation(self.instance)          assert representation == self.instance.name + + +class TestManyRelatedField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.StringRelatedField(many=True) +        self.field.field_name = 'foo' + +    def test_get_value_regular_dictionary_full(self): +        assert 'bar' == self.field.get_value({'foo': 'bar'}) +        assert empty == self.field.get_value({'baz': 'bar'}) + +    def test_get_value_regular_dictionary_partial(self): +        setattr(self.field.root, 'partial', True) +        assert 'bar' == self.field.get_value({'foo': 'bar'}) +        assert empty == self.field.get_value({'baz': 'bar'}) + +    def test_get_value_multi_dictionary_full(self): +        mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) +        assert ['bar1', 'bar2'] == self.field.get_value(mvd) + +        mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) +        assert [] == self.field.get_value(mvd) + +    def test_get_value_multi_dictionary_partial(self): +        setattr(self.field.root, 'partial', True) +        mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) +        assert ['bar1', 'bar2'] == self.field.get_value(mvd) + +        mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) +        assert empty == self.field.get_value(mvd) diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index f1b882ed..aede61d2 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -1,5 +1,5 @@  from __future__ import unicode_literals -from django.conf.urls import patterns, url +from django.conf.urls import url  from django.test import TestCase  from rest_framework import serializers  from rest_framework.test import APIRequestFactory @@ -14,8 +14,7 @@ request = factory.get('/')  # Just to ensure we have a request in the serializer  dummy_view = lambda request, pk: None -urlpatterns = patterns( -    '', +urlpatterns = [      url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),      url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),      url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), @@ -24,7 +23,7 @@ urlpatterns = patterns(      url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),      url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),      url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), -) +]  # ManyToMany diff --git a/tests/test_renderers.py b/tests/test_renderers.py index 7b78f7ba..f68405f0 100644 --- a/tests/test_renderers.py +++ b/tests/test_renderers.py @@ -1,6 +1,5 @@  # -*- coding: utf-8 -*-  from __future__ import unicode_literals -  from django.conf.urls import patterns, url, include  from django.core.cache import cache  from django.db import models @@ -8,6 +7,7 @@ from django.test import TestCase  from django.utils import six  from django.utils.translation import ugettext_lazy as _  from rest_framework import status, permissions +from rest_framework.compat import OrderedDict  from rest_framework.response import Response  from rest_framework.views import APIView  from rest_framework.renderers import BaseRenderer, JSONRenderer, BrowsableAPIRenderer @@ -15,7 +15,6 @@ from rest_framework.settings import api_settings  from rest_framework.test import APIRequestFactory  from collections import MutableMapping  import json -import pickle  import re @@ -408,84 +407,46 @@ class CacheRenderTest(TestCase):      urls = 'tests.test_renderers' -    cache_key = 'just_a_cache_key' - -    @classmethod -    def _get_pickling_errors(cls, obj, seen=None): -        """ Return any errors that would be raised if `obj' is pickled -        Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 -        """ -        if seen is None: -            seen = [] -        try: -            state = obj.__getstate__() -        except AttributeError: -            return -        if state is None: -            return -        if isinstance(state, tuple): -            if not isinstance(state[0], dict): -                state = state[1] -            else: -                state = state[0].update(state[1]) -        result = {} -        for i in state: -            try: -                pickle.dumps(state[i], protocol=2) -            except pickle.PicklingError: -                if not state[i] in seen: -                    seen.append(state[i]) -                    result[i] = cls._get_pickling_errors(state[i], seen) -        return result - -    def http_resp(self, http_method, url): -        """ -        Simple wrapper for Client http requests -        Removes the `client' and `request' attributes from as they are -        added by django.test.client.Client and not part of caching -        responses outside of tests. -        """ -        method = getattr(self.client, http_method) -        resp = method(url) -        resp._closable_objects = [] -        del resp.client, resp.request -        try: -            del resp.wsgi_request -        except AttributeError: -            pass -        return resp - -    def test_obj_pickling(self): -        """ -        Test that responses are properly pickled -        """ -        resp = self.http_resp('get', '/cache') - -        # Make sure that no pickling errors occurred -        self.assertEqual(self._get_pickling_errors(resp), {}) - -        # Unfortunately LocMem backend doesn't raise PickleErrors but returns -        # None instead. -        cache.set(self.cache_key, resp) -        self.assertTrue(cache.get(self.cache_key) is not None) -      def test_head_caching(self):          """          Test caching of HEAD requests          """ -        resp = self.http_resp('head', '/cache') -        cache.set(self.cache_key, resp) - -        cached_resp = cache.get(self.cache_key) -        self.assertIsInstance(cached_resp, Response) +        response = self.client.head('/cache') +        cache.set('key', response) +        cached_response = cache.get('key') +        assert isinstance(cached_response, Response) +        assert cached_response.content == response.content +        assert cached_response.status_code == response.status_code      def test_get_caching(self):          """          Test caching of GET requests          """ -        resp = self.http_resp('get', '/cache') -        cache.set(self.cache_key, resp) +        response = self.client.get('/cache') +        cache.set('key', response) +        cached_response = cache.get('key') +        assert isinstance(cached_response, Response) +        assert cached_response.content == response.content +        assert cached_response.status_code == response.status_code + + +class TestJSONIndentationStyles: +    def test_indented(self): +        renderer = JSONRenderer() +        data = OrderedDict([('a', 1), ('b', 2)]) +        assert renderer.render(data) == b'{"a":1,"b":2}' -        cached_resp = cache.get(self.cache_key) -        self.assertIsInstance(cached_resp, Response) -        self.assertEqual(cached_resp.content, resp.content) +    def test_compact(self): +        renderer = JSONRenderer() +        data = OrderedDict([('a', 1), ('b', 2)]) +        context = {'indent': 4} +        assert ( +            renderer.render(data, renderer_context=context) == +            b'{\n    "a": 1,\n    "b": 2\n}' +        ) + +    def test_long_form(self): +        renderer = JSONRenderer() +        renderer.compact = False +        data = OrderedDict([('a', 1), ('b', 2)]) +        assert renderer.render(data) == b'{"a": 1, "b": 2}' diff --git a/tests/test_routers.py b/tests/test_routers.py index 34306146..948c69bb 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,17 +1,53 @@  from __future__ import unicode_literals -from django.conf.urls import patterns, url, include +from django.conf.urls import url, include  from django.db import models  from django.test import TestCase  from django.core.exceptions import ImproperlyConfigured -from rest_framework import serializers, viewsets, mixins, permissions +from rest_framework import serializers, viewsets, permissions  from rest_framework.decorators import detail_route, list_route  from rest_framework.response import Response  from rest_framework.routers import SimpleRouter, DefaultRouter  from rest_framework.test import APIRequestFactory +from collections import namedtuple  factory = APIRequestFactory() -urlpatterns = patterns('',) + +class RouterTestModel(models.Model): +    uuid = models.CharField(max_length=20) +    text = models.CharField(max_length=200) + + +class NoteSerializer(serializers.HyperlinkedModelSerializer): +    url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') + +    class Meta: +        model = RouterTestModel +        fields = ('url', 'uuid', 'text') + + +class NoteViewSet(viewsets.ModelViewSet): +    queryset = RouterTestModel.objects.all() +    serializer_class = NoteSerializer +    lookup_field = 'uuid' + + +class MockViewSet(viewsets.ModelViewSet): +    queryset = None +    serializer_class = None + + +notes_router = SimpleRouter() +notes_router.register(r'notes', NoteViewSet) + +namespaced_router = DefaultRouter() +namespaced_router.register(r'example', MockViewSet, base_name='example') + +urlpatterns = [ +    url(r'^non-namespaced/', include(namespaced_router.urls)), +    url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), +    url(r'^example/', include(notes_router.urls)), +]  class BasicViewSet(viewsets.ViewSet): @@ -63,9 +99,26 @@ class TestSimpleRouter(TestCase):                  self.assertEqual(route.mapping[method], endpoint) -class RouterTestModel(models.Model): -    uuid = models.CharField(max_length=20) -    text = models.CharField(max_length=200) +class TestRootView(TestCase): +    urls = 'tests.test_routers' + +    def test_retrieve_namespaced_root(self): +        response = self.client.get('/namespaced/') +        self.assertEqual( +            response.data, +            { +                "example": "http://testserver/namespaced/example/", +            } +        ) + +    def test_retrieve_non_namespaced_root(self): +        response = self.client.get('/non-namespaced/') +        self.assertEqual( +            response.data, +            { +                "example": "http://testserver/non-namespaced/example/", +            } +        )  class TestCustomLookupFields(TestCase): @@ -75,51 +128,29 @@ class TestCustomLookupFields(TestCase):      urls = 'tests.test_routers'      def setUp(self): -        class NoteSerializer(serializers.HyperlinkedModelSerializer): -            url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid') - -            class Meta: -                model = RouterTestModel -                fields = ('url', 'uuid', 'text') - -        class NoteViewSet(viewsets.ModelViewSet): -            queryset = RouterTestModel.objects.all() -            serializer_class = NoteSerializer -            lookup_field = 'uuid' - -        self.router = SimpleRouter() -        self.router.register(r'notes', NoteViewSet) - -        from tests import test_routers -        urls = getattr(test_routers, 'urlpatterns') -        urls += patterns( -            '', -            url(r'^', include(self.router.urls)), -        ) -          RouterTestModel.objects.create(uuid='123', text='foo bar')      def test_custom_lookup_field_route(self): -        detail_route = self.router.urls[-1] +        detail_route = notes_router.urls[-1]          detail_url_pattern = detail_route.regex.pattern          self.assertIn('<uuid>', detail_url_pattern)      def test_retrieve_lookup_field_list_view(self): -        response = self.client.get('/notes/') +        response = self.client.get('/example/notes/')          self.assertEqual(              response.data,              [{ -                "url": "http://testserver/notes/123/", +                "url": "http://testserver/example/notes/123/",                  "uuid": "123", "text": "foo bar"              }]          )      def test_retrieve_lookup_field_detail_view(self): -        response = self.client.get('/notes/123/') +        response = self.client.get('/example/notes/123/')          self.assertEqual(              response.data,              { -                "url": "http://testserver/notes/123/", +                "url": "http://testserver/example/notes/123/",                  "uuid": "123", "text": "foo bar"              }          ) @@ -149,7 +180,7 @@ class TestLookupValueRegex(TestCase):  class TestTrailingSlashIncluded(TestCase):      def setUp(self):          class NoteViewSet(viewsets.ModelViewSet): -            model = RouterTestModel +            queryset = RouterTestModel.objects.all()          self.router = SimpleRouter()          self.router.register(r'notes', NoteViewSet) @@ -164,7 +195,7 @@ class TestTrailingSlashIncluded(TestCase):  class TestTrailingSlashRemoved(TestCase):      def setUp(self):          class NoteViewSet(viewsets.ModelViewSet): -            model = RouterTestModel +            queryset = RouterTestModel.objects.all()          self.router = SimpleRouter(trailing_slash=False)          self.router.register(r'notes', NoteViewSet) @@ -179,7 +210,8 @@ class TestTrailingSlashRemoved(TestCase):  class TestNameableRoot(TestCase):      def setUp(self):          class NoteViewSet(viewsets.ModelViewSet): -            model = RouterTestModel +            queryset = RouterTestModel.objects.all() +          self.router = DefaultRouter()          self.router.root_view_name = 'nameable-root'          self.router.register(r'notes', NoteViewSet) @@ -261,6 +293,14 @@ class DynamicListAndDetailViewSet(viewsets.ViewSet):      def detail_route_get(self, request, *args, **kwargs):          return Response({'method': 'link2'}) +    @list_route(url_path="list_custom-route") +    def list_custom_route_get(self, request, *args, **kwargs): +        return Response({'method': 'link1'}) + +    @detail_route(url_path="detail_custom-route") +    def detail_custom_route_get(self, request, *args, **kwargs): +        return Response({'method': 'link2'}) +  class TestDynamicListAndDetailRouter(TestCase):      def setUp(self): @@ -269,35 +309,30 @@ class TestDynamicListAndDetailRouter(TestCase):      def test_list_and_detail_route_decorators(self):          routes = self.router.get_routes(DynamicListAndDetailViewSet)          decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))] + +        MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')          # Make sure all these endpoints exist and none have been clobbered -        for i, endpoint in enumerate(['list_route_get', 'list_route_post', 'detail_route_get', 'detail_route_post']): +        for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'), +                                      MethodNamesMap('list_route_get', 'list_route_get'), +                                      MethodNamesMap('list_route_post', 'list_route_post'), +                                      MethodNamesMap('detail_custom_route_get', 'detail_custom-route'), +                                      MethodNamesMap('detail_route_get', 'detail_route_get'), +                                      MethodNamesMap('detail_route_post', 'detail_route_post') +                                      ]):              route = decorator_routes[i]              # check url listing -            if endpoint.startswith('list_'): +            method_name = endpoint.method_name +            url_path = endpoint.url_path + +            if method_name.startswith('list_'):                  self.assertEqual(route.url, -                                 '^{{prefix}}/{0}{{trailing_slash}}$'.format(endpoint)) +                                 '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path))              else:                  self.assertEqual(route.url, -                                 '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint)) +                                 '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path))              # check method to function mapping -            if endpoint.endswith('_post'): +            if method_name.endswith('_post'):                  method_map = 'post'              else:                  method_map = 'get' -            self.assertEqual(route.mapping[method_map], endpoint) - - -class TestRootWithAListlessViewset(TestCase): -    def setUp(self): -        class NoteViewSet(mixins.RetrieveModelMixin, -                          viewsets.GenericViewSet): -            model = RouterTestModel - -        self.router = DefaultRouter() -        self.router.register(r'notes', NoteViewSet) -        self.view = self.router.urls[0].callback - -    def test_api_root(self): -        request = factory.get('/') -        response = self.view(request) -        self.assertEqual(response.data, {}) +            self.assertEqual(route.mapping[method_map], method_name) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index c17b6d8c..b7a0484b 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -1,7 +1,9 @@  # coding: utf-8  from __future__ import unicode_literals +from .utils import MockObject  from rest_framework import serializers  from rest_framework.compat import unicode_repr +import pickle  import pytest @@ -216,3 +218,80 @@ class TestUnicodeRepr:          instance = ExampleObject()          serializer = ExampleSerializer(instance)          repr(serializer)  # Should not error. + + +class TestNotRequiredOutput: +    def test_not_required_output_for_dict(self): +        """ +        'required=False' should allow a dictionary key to be missing in output. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(required=False) +            included = serializers.CharField() + +        serializer = ExampleSerializer(data={'included': 'abc'}) +        serializer.is_valid() +        assert serializer.data == {'included': 'abc'} + +    def test_not_required_output_for_object(self): +        """ +        'required=False' should allow an object attribute to be missing in output. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(required=False) +            included = serializers.CharField() + +            def create(self, validated_data): +                return MockObject(**validated_data) + +        serializer = ExampleSerializer(data={'included': 'abc'}) +        serializer.is_valid() +        serializer.save() +        assert serializer.data == {'included': 'abc'} + +    def test_default_required_output_for_dict(self): +        """ +        'default="something"' should require dictionary key. + +        We need to handle this as the field will have an implicit +        'required=False', but it should still have a value. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(default='abc') +            included = serializers.CharField() + +        serializer = ExampleSerializer({'included': 'abc'}) +        with pytest.raises(KeyError): +            serializer.data + +    def test_default_required_output_for_object(self): +        """ +        'default="something"' should require object attribute. + +        We need to handle this as the field will have an implicit +        'required=False', but it should still have a value. +        """ +        class ExampleSerializer(serializers.Serializer): +            omitted = serializers.CharField(default='abc') +            included = serializers.CharField() + +        instance = MockObject(included='abc') +        serializer = ExampleSerializer(instance) +        with pytest.raises(AttributeError): +            serializer.data + + +class TestCacheSerializerData: +    def test_cache_serializer_data(self): +        """ +        Caching serializer data with pickle will drop the serializer info, +        but does preserve the data itself. +        """ +        class ExampleSerializer(serializers.Serializer): +            field1 = serializers.CharField() +            field2 = serializers.CharField() + +        serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'}) +        pickled = pickle.dumps(serializer.data) +        data = pickle.loads(pickled) +        assert data == {'field1': 'a', 'field2': 'b'} diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py index fb881a75..bc955b2e 100644 --- a/tests/test_serializer_bulk_update.py +++ b/tests/test_serializer_bulk_update.py @@ -101,7 +101,7 @@ class BulkCreateSerializerTests(TestCase):          serializer = self.BookSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), False) -        expected_errors = {'non_field_errors': ['Expected a list of items but got type `int`.']} +        expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}          self.assertEqual(serializer.errors, expected_errors) @@ -118,6 +118,6 @@ class BulkCreateSerializerTests(TestCase):          serializer = self.BookSerializer(data=data, many=True)          self.assertEqual(serializer.is_valid(), False) -        expected_errors = {'non_field_errors': ['Expected a list of items but got type `dict`.']} +        expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}          self.assertEqual(serializer.errors, expected_errors) diff --git a/tests/test_versioning.py b/tests/test_versioning.py index c44f727d..553463d1 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,9 +1,13 @@ +from .utils import UsingURLPatterns  from django.conf.urls import include, url +from rest_framework import serializers  from rest_framework import status, versioning  from rest_framework.decorators import APIView  from rest_framework.response import Response  from rest_framework.reverse import reverse  from rest_framework.test import APIRequestFactory, APITestCase +from rest_framework.versioning import NamespaceVersioning +import pytest  class RequestVersionView(APIView): @@ -28,17 +32,8 @@ class RequestInvalidVersionView(APIView):  factory = APIRequestFactory() -mock_view = lambda request: None - -included_patterns = [ -    url(r'^namespaced/$', mock_view, name='another'), -] - -urlpatterns = [ -    url(r'^v1/', include(included_patterns, namespace='v1')), -    url(r'^another/$', mock_view, name='another'), -    url(r'^(?P<version>[^/]+)/another/$', mock_view, name='another') -] +dummy_view = lambda request: None +dummy_pk_view = lambda request, pk: None  class TestRequestVersion: @@ -114,8 +109,17 @@ class TestRequestVersion:          assert response.data == {'version': None} -class TestURLReversing(APITestCase): -    urls = 'tests.test_versioning' +class TestURLReversing(UsingURLPatterns, APITestCase): +    included = [ +        url(r'^namespaced/$', dummy_view, name='another'), +        url(r'^example/(?P<pk>\d+)/$', dummy_pk_view, name='example-detail') +    ] + +    urlpatterns = [ +        url(r'^v1/', include(included, namespace='v1')), +        url(r'^another/$', dummy_view, name='another'), +        url(r'^(?P<version>[^/]+)/another/$', dummy_view, name='another'), +    ]      def test_reverse_unversioned(self):          view = ReverseView.as_view() @@ -221,3 +225,35 @@ class TestInvalidVersion:          request.resolver_match = FakeResolverMatch          response = view(request, version='v3')          assert response.status_code == status.HTTP_404_NOT_FOUND + + +class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase): +    included = [ +        url(r'^namespaced/(?P<pk>\d+)/$', dummy_view, name='namespaced'), +    ] + +    urlpatterns = [ +        url(r'^v1/', include(included, namespace='v1')), +        url(r'^v2/', include(included, namespace='v2')) +    ] + +    def setUp(self): +        super(TestHyperlinkedRelatedField, self).setUp() + +        class MockQueryset(object): +            def get(self, pk): +                return 'object %s' % pk + +        self.field = serializers.HyperlinkedRelatedField( +            view_name='namespaced', +            queryset=MockQueryset() +        ) +        request = factory.get('/') +        request.versioning_scheme = NamespaceVersioning() +        request.version = 'v1' +        self.field._context = {'request': request} + +    def test_bug_2489(self): +        assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3' +        with pytest.raises(serializers.ValidationError): +            self.field.to_internal_value('/v2/namespaced/3/') diff --git a/tests/utils.py b/tests/utils.py index 5e902ba9..b9034996 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,30 +1,29 @@ -from contextlib import contextmanager  from django.core.exceptions import ObjectDoesNotExist  from django.core.urlresolvers import NoReverseMatch -from django.utils import six -from rest_framework.settings import api_settings -@contextmanager -def temporary_setting(setting, value, module=None): +class UsingURLPatterns(object):      """ -    Temporarily change value of setting for test. +    Isolates URL patterns used during testing on the test class itself. +    For example: -    Optionally reload given module, useful when module uses value of setting on -    import. -    """ -    original_value = getattr(api_settings, setting) -    setattr(api_settings, setting, value) - -    if module is not None: -        six.moves.reload_module(module) +    class MyTestCase(UsingURLPatterns, TestCase): +        urlpatterns = [ +            ... +        ] -    yield +        def test_something(self): +            ... +    """ +    urls = __name__ -    setattr(api_settings, setting, original_value) +    def setUp(self): +        global urlpatterns +        urlpatterns = self.urlpatterns -    if module is not None: -        six.moves.reload_module(module) +    def tearDown(self): +        global urlpatterns +        urlpatterns = []  class MockObject(object): | 
