diff options
| author | Tom Christie | 2012-10-08 12:52:56 +0100 | 
|---|---|---|
| committer | Tom Christie | 2012-10-08 12:52:56 +0100 | 
| commit | 52ba2e333375c6829fb89b6b43e4d19b2f2a86a4 (patch) | |
| tree | d45e35751961ae6cac9813a2af073098b32f7e7e | |
| parent | 4fd8ab17a3e935d72bb4ec25ed8f16a21ec2c0ef (diff) | |
| download | django-rest-framework-52ba2e333375c6829fb89b6b43e4d19b2f2a86a4.tar.bz2 | |
Fix #285
| -rw-r--r-- | rest_framework/mixins.py | 11 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 23 | ||||
| -rw-r--r-- | rest_framework/tests/generics.py | 37 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 8 | 
4 files changed, 61 insertions, 18 deletions
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 46821f64..7cfbe030 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -77,16 +77,15 @@ class UpdateModelMixin(object):              self.object = None          serializer = self.get_serializer(data=request.DATA, instance=self.object) +          if serializer.is_valid():              if self.object is None: -                obj = serializer.object -                # TODO: Make ModelSerializers return regular instances, -                # not DeserializedObject -                if hasattr(obj, 'object'): -                    obj = obj.object -                self.update_urlconf_attributes(serializer.object.object) +                # If PUT occurs to a non existant object, we need to set any +                # attributes on the object that are implicit in the URL. +                self.update_urlconf_attributes(serializer.object)              self.object = serializer.save()              return Response(serializer.data) +          return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)      def update_urlconf_attributes(self, obj): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index ba8bf8ad..1770c4ce 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -2,7 +2,6 @@ import copy  import datetime  import types  from decimal import Decimal -from django.core.serializers.base import DeserializedObject  from django.db import models  from django.utils.datastructures import SortedDict  from rest_framework.compat import get_concrete_model @@ -224,9 +223,6 @@ class BaseSerializer(Field):          """          Serialize objects -> primatives.          """ -        if isinstance(obj, DeserializedObject): -            obj = obj.object -          if isinstance(obj, dict):              return dict([(key, self.to_native(val))                           for (key, val) in obj.items()]) @@ -383,23 +379,30 @@ class ModelSerializer(Serializer):          """          Restore the model instance.          """ +        self.m2m_data = {} +          if instance:              for key, val in attrs.items():                  setattr(instance, key, val) -            return DeserializedObject(instance) +            return instance -        m2m_data = {}          for field in self.opts.model._meta.many_to_many:              if field.name in attrs: -                m2m_data[field.name] = attrs.pop(field.name) -        return DeserializedObject(self.opts.model(**attrs), m2m_data) +                self.m2m_data[field.name] = attrs.pop(field.name) +        return self.opts.model(**attrs) -    def save(self): +    def save(self, save_m2m=True):          """          Save the deserialized object and return it.          """          self.object.save() -        return self.object.object + +        if self.m2m_data and save_m2m: +            for accessor_name, object_list in self.m2m_data.items(): +                setattr(self.object, accessor_name, object_list) +            self.m2m_data = {} + +        return self.object  class HyperlinkedModelSerializerOptions(ModelSerializerOptions): diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index 2a6a0744..f4263478 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -1,8 +1,8 @@  from django.test import TestCase  from django.test.client import RequestFactory  from django.utils import simplejson as json -from rest_framework import generics, status -from rest_framework.tests.models import BasicModel +from rest_framework import generics, serializers, status +from rest_framework.tests.models import BasicModel, Comment  factory = RequestFactory() @@ -223,3 +223,36 @@ class TestInstanceView(TestCase):          self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})          updated = self.objects.get(id=1)          self.assertEquals(updated.text, 'foobar') + + +# Regression test for #285 + +class CommentSerializer(serializers.ModelSerializer): +    class Meta: +        model = Comment +        exclude = ('created',) + + +class CommentView(generics.ListCreateAPIView): +    serializer_class = CommentSerializer +    model = Comment + + +class TestCreateModelWithAutoNowAddField(TestCase): +    def setUp(self): +        self.objects = Comment.objects +        self.view = CommentView.as_view() + +    def test_create_model_with_auto_now_add_field(self): +        """ +        Regression test for #285 + +        https://github.com/tomchristie/django-rest-framework/issues/285 +        """ +        content = {'email': 'foobar@example.com', 'content': 'foobar'} +        request = factory.post('/', json.dumps(content), +                               content_type='application/json') +        response = self.view(request).render() +        self.assertEquals(response.status_code, status.HTTP_201_CREATED) +        created = self.objects.get(id=1) +        self.assertEquals(created.content, 'foobar') diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 7c7f485b..6a758f0c 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -83,3 +83,11 @@ class TaggedItem(RESTFrameworkModel):  class Bookmark(RESTFrameworkModel):      url = models.URLField()      tags = GenericRelation(TaggedItem) + + +# Model for regression test for #285 + +class Comment(RESTFrameworkModel): +    email = models.EmailField() +    content = models.CharField(max_length=200) +    created = models.DateTimeField(auto_now_add=True)  | 
