diff options
| author | Tom Christie | 2013-01-22 09:11:38 +0000 | 
|---|---|---|
| committer | Tom Christie | 2013-01-22 09:11:38 +0000 | 
| commit | b7ab2aee46c718f683b19eefba1b48f233da40e4 (patch) | |
| tree | 1af09c7dbcc939c749d30adf25b14d232200f44f /rest_framework | |
| parent | 65b62d64ec54b528b62a1500b8f6ffe216d45c09 (diff) | |
| parent | e29ba356f054222893655901923811bd9675d4cc (diff) | |
| download | django-rest-framework-b7ab2aee46c718f683b19eefba1b48f233da40e4.tar.bz2 | |
Merge branch 'master' into unauthenticated_response
Conflicts:
	docs/api-guide/authentication.md
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/__init__.py | 2 | ||||
| -rw-r--r-- | rest_framework/authtoken/views.py | 3 | ||||
| -rw-r--r-- | rest_framework/decorators.py | 9 | ||||
| -rw-r--r-- | rest_framework/relations.py | 41 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 25 | ||||
| -rw-r--r-- | rest_framework/templates/rest_framework/base.html | 2 | ||||
| -rw-r--r-- | rest_framework/tests/decorators.py | 23 | ||||
| -rw-r--r-- | rest_framework/tests/fields.py | 49 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 11 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/relations_hyperlink.py | 50 | ||||
| -rw-r--r-- | rest_framework/tests/relations_nested.py | 34 | ||||
| -rw-r--r-- | rest_framework/tests/relations_pk.py | 39 | ||||
| -rw-r--r-- | rest_framework/tests/relations_slug.py | 257 | ||||
| -rw-r--r-- | rest_framework/tests/urlpatterns.py | 78 | ||||
| -rw-r--r-- | rest_framework/urlpatterns.py | 45 | ||||
| -rw-r--r-- | rest_framework/utils/encoders.py | 4 | 
17 files changed, 614 insertions, 62 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 1d25ee7f..bc267fad 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,3 +1,3 @@ -__version__ = '2.1.15' +__version__ = '2.1.16'  VERSION = __version__  # synonym diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index d318c723..7c03cb76 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -12,10 +12,11 @@ class ObtainAuthToken(APIView):      permission_classes = ()      parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)      renderer_classes = (renderers.JSONRenderer,) +    serializer_class = AuthTokenSerializer      model = Token      def post(self, request): -        serializer = AuthTokenSerializer(data=request.DATA) +        serializer = self.serializer_class(data=request.DATA)          if serializer.is_valid():              token, created = Token.objects.get_or_create(user=serializer.object['user'])              return Response({'token': token.key}) diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 1b710a03..7a4103e1 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -1,4 +1,5 @@  from rest_framework.views import APIView +import types  def api_view(http_method_names): @@ -23,6 +24,14 @@ def api_view(http_method_names):          #         pass          #     WrappedAPIView.__doc__ = func.doc    <--- Not possible to do this +        # api_view applied without (method_names) +        assert not(isinstance(http_method_names, types.FunctionType)), \ +            '@api_view missing list of allowed HTTP methods' + +        # api_view applied with eg. string instead of list of strings +        assert isinstance(http_method_names, (list, tuple)), \ +            '@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ +          allowed_methods = set(http_method_names) | set(('options',))          WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 0d93f448..af63ceaa 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -101,7 +101,13 @@ class RelatedField(WritableField):      ### Regular serializer stuff...      def field_to_native(self, obj, field_name): -        value = getattr(obj, self.source or field_name) +        try: +            value = getattr(obj, self.source or field_name) +        except ObjectDoesNotExist: +            return None + +        if value is None: +            return None          return self.to_native(value)      def field_from_native(self, data, files, field_name, into): @@ -171,7 +177,7 @@ class PrimaryKeyRelatedField(RelatedField):      default_error_messages = {          'does_not_exist': _("Invalid pk '%s' - object does not exist."), -        'invalid': _('Invalid value.'), +        'incorrect_type': _('Incorrect type.  Expected pk value, received %s.'),      }      # TODO: Remove these field hacks... @@ -202,7 +208,8 @@ class PrimaryKeyRelatedField(RelatedField):              msg = self.error_messages['does_not_exist'] % smart_unicode(data)              raise ValidationError(msg)          except (TypeError, ValueError): -            msg = self.error_messages['invalid'] +            received = type(data).__name__ +            msg = self.error_messages['incorrect_type'] % received              raise ValidationError(msg)      def field_to_native(self, obj, field_name): @@ -211,7 +218,10 @@ class PrimaryKeyRelatedField(RelatedField):              pk = obj.serializable_value(self.source or field_name)          except AttributeError:              # RelatedObject (reverse relationship) -            obj = getattr(obj, self.source or field_name) +            try: +                obj = getattr(obj, self.source or field_name) +            except ObjectDoesNotExist: +                return None              return self.to_native(obj.pk)          # Forward relationship          return self.to_native(pk) @@ -226,7 +236,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):      default_error_messages = {          'does_not_exist': _("Invalid pk '%s' - object does not exist."), -        'invalid': _('Invalid value.'), +        'incorrect_type': _('Incorrect type.  Expected pk value, received %s.'),      }      def prepare_value(self, obj): @@ -266,7 +276,8 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):              msg = self.error_messages['does_not_exist'] % smart_unicode(data)              raise ValidationError(msg)          except (TypeError, ValueError): -            msg = self.error_messages['invalid'] +            received = type(data).__name__ +            msg = self.error_messages['incorrect_type'] % received              raise ValidationError(msg)  ### Slug relationships @@ -324,7 +335,7 @@ class HyperlinkedRelatedField(RelatedField):          'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),          'configuration_error': _('Invalid hyperlink due to configuration error'),          'does_not_exist': _("Invalid hyperlink - object does not exist."), -        'invalid': _('Invalid value.'), +        'incorrect_type': _('Incorrect type.  Expected url string, received %s.'),      }      def __init__(self, *args, **kwargs): @@ -367,13 +378,13 @@ class HyperlinkedRelatedField(RelatedField):          kwargs = {self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass          kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass @@ -388,8 +399,8 @@ class HyperlinkedRelatedField(RelatedField):          try:              http_prefix = value.startswith('http:') or value.startswith('https:')          except AttributeError: -            msg = self.error_messages['invalid'] -            raise ValidationError(msg) +            msg = self.error_messages['incorrect_type'] +            raise ValidationError(msg % type(value).__name__)          if http_prefix:              # If needed convert absolute URLs to relative path @@ -425,8 +436,8 @@ class HyperlinkedRelatedField(RelatedField):          except ObjectDoesNotExist:              raise ValidationError(self.error_messages['does_not_exist'])          except (TypeError, ValueError): -            msg = self.error_messages['invalid'] -            raise ValidationError(msg) +            msg = self.error_messages['incorrect_type'] +            raise ValidationError(msg % type(value).__name__)          return obj @@ -490,13 +501,13 @@ class HyperlinkedIdentityField(Field):          kwargs = {self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass          kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}          try: -            return reverse(self.view_name, kwargs=kwargs, request=request, format=format) +            return reverse(view_name, kwargs=kwargs, request=request, format=format)          except:              pass diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index fa92838b..27458f96 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -298,15 +298,18 @@ class BaseSerializer(Field):          Override default so that we can apply ModelSerializer as a nested          field to relationships.          """ -        if self.source: -            for component in self.source.split('.'): -                obj = getattr(obj, component) +        try: +            if self.source: +                for component in self.source.split('.'): +                    obj = getattr(obj, component) +                    if is_simple_callable(obj): +                        obj = obj() +            else: +                obj = getattr(obj, field_name)                  if is_simple_callable(obj):                      obj = obj() -        else: -            obj = getattr(obj, field_name) -            if is_simple_callable(obj): -                obj = value() +        except ObjectDoesNotExist: +            return None          # If the object has an "all" method, assume it's a relationship          if is_simple_callable(getattr(obj, 'all', None)): @@ -412,7 +415,7 @@ class ModelSerializer(Serializer):          """          Returns a default instance of the pk field.          """ -        return Field() +        return self.get_field(model_field)      def get_nested_field(self, model_field):          """ @@ -430,7 +433,7 @@ class ModelSerializer(Serializer):          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to)          kwargs = { -            'null': model_field.null, +            'null': model_field.null or model_field.blank,              'queryset': model_field.rel.to._default_manager          } @@ -449,6 +452,9 @@ class ModelSerializer(Serializer):          if model_field.null or model_field.blank:              kwargs['required'] = False +        if isinstance(model_field, models.AutoField) or not model_field.editable: +            kwargs['read_only'] = True +          if model_field.has_default():              kwargs['required'] = False              kwargs['default'] = model_field.get_default() @@ -462,6 +468,7 @@ class ModelSerializer(Serializer):              return ChoiceField(**kwargs)          field_mapping = { +            models.AutoField: IntegerField,              models.FloatField: FloatField,              models.IntegerField: IntegerField,              models.PositiveIntegerField: IntegerField, diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 42e49cb9..092bf2e4 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -112,7 +112,7 @@              <div class="request-info">                  <pre class="prettyprint"><b>{{ request.method }}</b> {{ request.get_full_path }}</pre> -            <div> +            </div>              <div class="response-info">                  <pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}  {% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span> diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py index bc44a45b..82f912e9 100644 --- a/rest_framework/tests/decorators.py +++ b/rest_framework/tests/decorators.py @@ -1,5 +1,4 @@  from django.test import TestCase -from django.test.client import RequestFactory  from rest_framework import status  from rest_framework.response import Response  from rest_framework.renderers import JSONRenderer @@ -29,13 +28,27 @@ class DecoratorTestCase(TestCase):          response.request = request          return APIView.finalize_response(self, request, response, *args, **kwargs) -    def test_wrap_view(self): +    def test_api_view_incorrect(self): +        """ +        If @api_view is not applied correct, we should raise an assertion. +        """ -        @api_view(['GET']) +        @api_view          def view(request): -            return Response({}) +            return Response() + +        request = self.factory.get('/') +        self.assertRaises(AssertionError, view, request) + +    def test_api_view_incorrect_arguments(self): +        """ +        If @api_view is missing arguments, we should raise an assertion. +        """ -        self.assertTrue(isinstance(view.cls_instance, APIView)) +        with self.assertRaises(AssertionError): +            @api_view('GET') +            def view(request): +                return Response()      def test_calling_method(self): diff --git a/rest_framework/tests/fields.py b/rest_framework/tests/fields.py new file mode 100644 index 00000000..8068272d --- /dev/null +++ b/rest_framework/tests/fields.py @@ -0,0 +1,49 @@ +""" +General serializer field tests. +""" + +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class TimestampedModel(models.Model): +    added = models.DateTimeField(auto_now_add=True) +    updated = models.DateTimeField(auto_now=True) + + +class CharPrimaryKeyModel(models.Model): +    id = models.CharField(max_length=20, primary_key=True) + + +class TimestampedModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = TimestampedModel + + +class CharPrimaryKeyModelSerializer(serializers.ModelSerializer): +    class Meta: +        model = CharPrimaryKeyModel + + +class ReadOnlyFieldTests(TestCase): +    def test_auto_now_fields_read_only(self): +        """ +        auto_now and auto_now_add fields should be read_only by default. +        """ +        serializer = TimestampedModelSerializer() +        self.assertEquals(serializer.fields['added'].read_only, True) + +    def test_auto_pk_fields_read_only(self): +        """ +        AutoField fields should be read_only by default. +        """ +        serializer = TimestampedModelSerializer() +        self.assertEquals(serializer.fields['id'].read_only, True) + +    def test_non_auto_pk_fields_not_read_only(self): +        """ +        PK fields other than AutoField fields should not be read_only by default. +        """ +        serializer = CharPrimaryKeyModelSerializer() +        self.assertEquals(serializer.fields['id'].read_only, False) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 59c35074..93f09761 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -205,3 +205,14 @@ class NullableForeignKeySource(RESTFrameworkModel):      name = models.CharField(max_length=100)      target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,                                 related_name='nullable_sources') + + +# OneToOne +class OneToOneTarget(RESTFrameworkModel): +    name = models.CharField(max_length=100) + + +class NullableOneToOneSource(RESTFrameworkModel): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, null=True, blank=True, +                                  related_name='nullable_source') diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 81d297a1..3b550877 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -181,10 +181,10 @@ class UnitTestPagination(TestCase):          """          Ensure context gets passed through to the object serializer.          """ -        serializer = PassOnContextPaginationSerializer(self.first_page) +        serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'})          serializer.data          results = serializer.fields[serializer.results_field] -        self.assertTrue(serializer.context is results.context) +        self.assertEquals(serializer.context, results.context)  class TestUnpaginated(TestCase): diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py index a7f8a035..6d137f68 100644 --- a/rest_framework/tests/relations_hyperlink.py +++ b/rest_framework/tests/relations_hyperlink.py @@ -1,8 +1,8 @@ -from django.db import models  from django.test import TestCase  from rest_framework import serializers  from rest_framework.compat import patterns, url -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource +  def dummy_view(request, pk):      pass @@ -13,8 +13,11 @@ urlpatterns = patterns('',      url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),      url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),      url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), +    url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), +    url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),  ) +  class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):      sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') @@ -40,16 +43,17 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):  # Nullable ForeignKey +class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): +    class Meta: +        model = NullableForeignKeySource -class NullableForeignKeySource(models.Model): -    name = models.CharField(max_length=100) -    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, -                               related_name='nullable_sources') +# OneToOne +class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): +    nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') -class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):      class Meta: -        model = NullableForeignKeySource +        model = OneToOneTarget  # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -211,6 +215,13 @@ class HyperlinkedForeignKeyTests(TestCase):          ]          self.assertEquals(serializer.data, expected) +    def test_foreign_key_update_incorrect_type(self): +        data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Incorrect type.  Expected url string, received int.']}) +      def test_reverse_foreign_key_update(self):          data = {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}          instance = ForeignKeyTarget.objects.get(pk=2) @@ -223,7 +234,7 @@ class HyperlinkedForeignKeyTests(TestCase):          expected = [              {'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},              {'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, -        ]         +        ]          self.assertEquals(new_serializer.data, expected)          serializer.save() @@ -409,3 +420,24 @@ class HyperlinkedNullableForeignKeyTests(TestCase):      #         {'id': 2, 'name': u'target-2', 'sources': []},      #     ]      #     self.assertEquals(serializer.data, expected) + + +class HyperlinkedNullableOneToOneTests(TestCase): +    urls = 'rest_framework.tests.relations_hyperlink' + +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset) +        expected = [ +            {'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, +            {'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 5710c1ef..0e129fae 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,7 +1,6 @@ -from django.db import models  from django.test import TestCase  from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource +from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource  class ForeignKeySourceSerializer(serializers.ModelSerializer): @@ -28,6 +27,18 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):          model = NullableForeignKeySource +class NullableOneToOneSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableOneToOneSource + + +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): +    nullable_source = NullableOneToOneSourceSerializer() + +    class Meta: +        model = OneToOneTarget + +  class ReverseForeignKeyTests(TestCase):      def setUp(self):          target = ForeignKeyTarget(name='target-1') @@ -82,3 +93,22 @@ class NestedNullableForeignKeyTests(TestCase):              {'id': 3, 'name': u'source-3', 'target': None},          ]          self.assertEquals(serializer.data, expected) + + +class NestedNullableOneToOneTests(TestCase): +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, +            {'id': 2, 'name': u'target-2', 'nullable_source': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py index af6da2c0..3391e60a 100644 --- a/rest_framework/tests/relations_pk.py +++ b/rest_framework/tests/relations_pk.py @@ -1,7 +1,6 @@ -from django.db import models  from django.test import TestCase  from rest_framework import serializers -from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource +from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource  class ManyToManyTargetSerializer(serializers.ModelSerializer): @@ -33,6 +32,14 @@ class NullableForeignKeySourceSerializer(serializers.ModelSerializer):          model = NullableForeignKeySource +# OneToOne +class NullableOneToOneTargetSerializer(serializers.ModelSerializer): +    nullable_source = serializers.PrimaryKeyRelatedField() + +    class Meta: +        model = OneToOneTarget + +  # TODO: Add test that .data cannot be accessed prior to .is_valid  class PKManyToManyTests(TestCase): @@ -187,6 +194,13 @@ class PKForeignKeyTests(TestCase):          ]          self.assertEquals(serializer.data, expected) +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': u'source-1', 'target': 'foo'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Incorrect type.  Expected pk value, received str.']}) +      def test_reverse_foreign_key_update(self):          data = {'id': 2, 'name': u'target-2', 'sources': [1, 3]}          instance = ForeignKeyTarget.objects.get(pk=2) @@ -199,7 +213,7 @@ class PKForeignKeyTests(TestCase):          expected = [              {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},              {'id': 2, 'name': u'target-2', 'sources': []}, -        ]         +        ]          self.assertEquals(new_serializer.data, expected)          serializer.save() @@ -383,3 +397,22 @@ class PKNullableForeignKeyTests(TestCase):      #         {'id': 2, 'name': u'target-2', 'sources': []},      #     ]      #     self.assertEquals(serializer.data, expected) + + +class PKNullableOneToOneTests(TestCase): +    def setUp(self): +        target = OneToOneTarget(name='target-1') +        target.save() +        new_target = OneToOneTarget(name='target-2') +        new_target.save() +        source = NullableOneToOneSource(name='source-1', target=target) +        source.save() + +    def test_reverse_foreign_key_retrieve_with_null(self): +        queryset = OneToOneTarget.objects.all() +        serializer = NullableOneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'nullable_source': 1}, +            {'id': 2, 'name': u'target-2', 'nullable_source': None}, +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/relations_slug.py b/rest_framework/tests/relations_slug.py new file mode 100644 index 00000000..37ccc75e --- /dev/null +++ b/rest_framework/tests/relations_slug.py @@ -0,0 +1,257 @@ +from django.test import TestCase +from rest_framework import serializers +from rest_framework.tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    sources = serializers.ManySlugRelatedField(slug_field='name') + +    class Meta: +        model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField(slug_field='name') + +    class Meta: +        model = ForeignKeySource + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    target = serializers.SlugRelatedField(slug_field='name', null=True) + +    class Meta: +        model = NullableForeignKeySource + + +# TODO: M2M Tests, FKTests (Non-nulable), One2One +class PKForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': 'target-1'} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': u'source-1', 'target': 'target-2'} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-2'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': 'target-1'} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_incorrect_type(self): +        data = {'id': 1, 'name': u'source-1', 'target': 123} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']}) + +    def test_reverse_foreign_key_update(self): +        data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']} +        instance = ForeignKeyTarget.objects.get(pk=2) +        serializer = ForeignKeyTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        # We shouldn't have saved anything to the db yet since save +        # hasn't been called. +        queryset = ForeignKeyTarget.objects.all() +        new_serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +        ] +        self.assertEquals(new_serializer.data, expected) + +        serializer.save() +        self.assertEquals(serializer.data, data) + +        # Ensure target 2 is update, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create(self): +        data = {'id': 4, 'name': u'source-4', 'target': 'target-2'} +        serializer = ForeignKeySourceSerializer(data=data) +        serializer.is_valid() +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is added, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': 'target-1'}, +            {'id': 4, 'name': u'source-4', 'target': 'target-2'}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_create(self): +        data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']} +        serializer = ForeignKeyTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'target-3') + +        # Ensure target 3 is added, and everything else is as expected +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': ['source-2']}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +            {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class SlugNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            if idx == 3: +                target = None +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve_with_null(self): +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create_with_valid_null(self): +        data = {'id': 4, 'name': u'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None}, +            {'id': 4, 'name': u'source-4', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_create_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 4, 'name': u'source-4', 'target': ''} +        expected_data = {'id': 4, 'name': u'source-4', 'target': None} +        serializer = NullableForeignKeySourceSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEquals(serializer.data, expected_data) +        self.assertEqual(obj.name, u'source-4') + +        # Ensure source 4 is created, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 'target-1'}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None}, +            {'id': 4, 'name': u'source-4', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': None}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update_with_valid_emptystring(self): +        """ +        The emptystring should be interpreted as null in the context +        of relationships. +        """ +        data = {'id': 1, 'name': u'source-1', 'target': ''} +        expected_data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, expected_data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': None}, +            {'id': 2, 'name': u'source-2', 'target': 'target-1'}, +            {'id': 3, 'name': u'source-3', 'target': None} +        ] +        self.assertEquals(serializer.data, expected) diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py new file mode 100644 index 00000000..43e8ef69 --- /dev/null +++ b/rest_framework/tests/urlpatterns.py @@ -0,0 +1,78 @@ +from collections import namedtuple + +from django.core import urlresolvers + +from django.test import TestCase +from django.test.client import RequestFactory + +from rest_framework.compat import patterns, url, include +from rest_framework.urlpatterns import format_suffix_patterns + + +# A container class for test paths for the test case +URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs']) + + +def dummy_view(request, *args, **kwargs): +    pass + + +class FormatSuffixTests(TestCase): +    """ +    Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. +    """ +    def _resolve_urlpatterns(self, urlpatterns, test_paths): +        factory = RequestFactory() +        try: +            urlpatterns = format_suffix_patterns(urlpatterns) +        except: +            self.fail("Failed to apply `format_suffix_patterns` on  the supplied urlpatterns") +        resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) +        for test_path in test_paths: +            request = factory.get(test_path.path) +            try: +                callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) +            except: +                self.fail("Failed to resolve URL: %s" % request.path_info) +            self.assertEquals(callback_args, test_path.args) +            self.assertEquals(callback_kwargs, test_path.kwargs) + +    def test_format_suffix(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view), +        ) +        test_paths = [ +            URLTestPath('/test', (), {}), +            URLTestPath('/test.api', (), {'format': 'api'}), +            URLTestPath('/test.asdf', (), {'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_default_args(self): +        urlpatterns = patterns( +            '', +            url(r'^test$', dummy_view, {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test', (), {'foo': 'bar', }), +            URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) + +    def test_included_urls(self): +        nested_patterns = patterns( +            '', +            url(r'^path$', dummy_view) +        ) +        urlpatterns = patterns( +            '', +            url(r'^test/', include(nested_patterns), {'foo': 'bar'}), +        ) +        test_paths = [ +            URLTestPath('/test/path', (), {'foo': 'bar', }), +            URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), +            URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), +        ] +        self._resolve_urlpatterns(urlpatterns, test_paths) diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 143928c9..47789026 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,5 +1,35 @@ -from rest_framework.compat import url +from rest_framework.compat import url, include  from rest_framework.settings import api_settings +from django.core.urlresolvers import RegexURLResolver + + +def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): +    ret = [] +    for urlpattern in urlpatterns: +        if isinstance(urlpattern, RegexURLResolver): +            # Set of included URL patterns +            regex = urlpattern.regex.pattern +            namespace = urlpattern.namespace +            app_name = urlpattern.app_name +            kwargs = urlpattern.default_kwargs +            # Add in the included patterns, after applying the suffixes +            patterns = apply_suffix_patterns(urlpattern.url_patterns, +                                             suffix_pattern, +                                             suffix_required) +            ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) + +        else: +            # Regular URL pattern +            regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern +            view = urlpattern._callback or urlpattern._callback_str +            kwargs = urlpattern.default_args +            name = urlpattern.name +            # Add in both the existing and the new urlpattern +            if not suffix_required: +                ret.append(urlpattern) +            ret.append(url(regex, view, kwargs, name)) + +    return ret  def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): @@ -28,15 +58,4 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):      else:          suffix_pattern = r'\.(?P<%s>[a-z]+)$' % suffix_kwarg -    ret = [] -    for urlpattern in urlpatterns: -        # Form our complementing '.format' urlpattern -        regex = urlpattern.regex.pattern.rstrip('$') + suffix_pattern -        view = urlpattern._callback or urlpattern._callback_str -        kwargs = urlpattern.default_args -        name = urlpattern.name -        # Add in both the existing and the new urlpattern -        if not suffix_required: -            ret.append(urlpattern) -        ret.append(url(regex, view, kwargs, name)) -    return ret +    return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index c70b24dd..7afe100a 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -12,7 +12,7 @@ from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata  class JSONEncoder(json.JSONEncoder):      """ -    JSONEncoder subclass that knows how to encode date/time, +    JSONEncoder subclass that knows how to encode date/time/timedelta,      decimal types, and generators.      """      def default(self, o): @@ -34,6 +34,8 @@ class JSONEncoder(json.JSONEncoder):              if o.microsecond:                  r = r[:12]              return r +        elif isinstance(o, datetime.timedelta): +            return str(o.total_seconds())          elif isinstance(o, decimal.Decimal):              return str(o)          elif hasattr(o, '__iter__'):  | 
