diff options
| author | Mark Aaron Shirley | 2013-03-12 20:59:25 -0700 | 
|---|---|---|
| committer | Mark Aaron Shirley | 2013-03-14 15:17:13 -0700 | 
| commit | 3006e3825f29e920f881b816fd71566bf0e8d341 (patch) | |
| tree | 729a93c40c06769470ca57c770fcc89f8ed8b05e /rest_framework | |
| parent | b6b686d285e376dbf4f2d2f15bd0e3ef0f1c3a37 (diff) | |
| download | django-rest-framework-3006e3825f29e920f881b816fd71566bf0e8d341.tar.bz2 | |
One-to-one writable, nested serializer support
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/serializers.py | 44 | ||||
| -rw-r--r-- | rest_framework/tests/nesting.py | 125 | ||||
| -rw-r--r-- | rest_framework/tests/serializer_nested.py | 4 | 
3 files changed, 160 insertions, 13 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index f83451d3..893db2ec 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -26,13 +26,17 @@ class NestedValidationError(ValidationError):      if the messages are a list of error messages.      In the case of nested serializers, where the parent has many children, -    then the child's `serializer.errors` will be a list of dicts. +    then the child's `serializer.errors` will be a list of dicts.  In the case +    of a single child, the `serializer.errors` will be a dict.      We need to override the default behavior to get properly nested error dicts.      """      def __init__(self, message): -        self.messages = message +        if isinstance(message, dict): +            self.messages = [message] +        else: +            self.messages = message  class DictWithMetadata(dict): @@ -143,6 +147,7 @@ class BaseSerializer(WritableField):          self._data = None          self._files = None          self._errors = None +        self._delete = False      #####      # Methods to determine which fields to use when (de)serializing objects. @@ -354,15 +359,19 @@ class BaseSerializer(WritableField):                  raise ValidationError(self.error_messages['required'])              return -        if self.parent.object: -            # Set the serializer object if it exists -            obj = getattr(self.parent.object, field_name) -            self.object = obj +        # Set the serializer object if it exists +        obj = getattr(self.parent.object, field_name) if self.parent.object else None          if value in (None, ''): -            into[(self.source or field_name)] = None +            if isinstance(self, ModelSerializer): +                self._delete = True +                self.object = obj +                into[(self.source or field_name)] = self +            else: +                into[(self.source or field_name)] = None          else:              kwargs = { +                'instance': obj,                  'data': value,                  'context': self.context,                  'partial': self.partial, @@ -371,8 +380,10 @@ class BaseSerializer(WritableField):              serializer = self.__class__(**kwargs)              if serializer.is_valid(): -                self.object = serializer.object -                into[self.source or field_name] = serializer.object +                if isinstance(serializer, ModelSerializer): +                    into[self.source or field_name] = serializer +                else: +                    into[self.source or field_name] = serializer.object              else:                  # Propagate errors up to our parent                  raise NestedValidationError(serializer.errors) @@ -664,10 +675,17 @@ class ModelSerializer(Serializer):          if instance:              return self.full_clean(instance) -    def save_object(self, obj): +    def save_object(self, obj, parent=None, fk_field=None):          """          Save the deserialized object and return it.          """ +        if self._delete: +            obj.delete() +            return + +        if parent and fk_field: +            setattr(self.object, fk_field, parent) +          obj.save()          if getattr(self, 'm2m_data', None): @@ -677,7 +695,11 @@ class ModelSerializer(Serializer):          if getattr(self, 'related_data', None):              for accessor_name, object_list in self.related_data.items(): -                setattr(self.object, accessor_name, object_list) +                if isinstance(object_list, ModelSerializer): +                    fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name +                    object_list.save_object(object_list.object, parent=self.object, fk_field=fk_field) +                else: +                    setattr(self.object, accessor_name, object_list)              self.related_data = {} diff --git a/rest_framework/tests/nesting.py b/rest_framework/tests/nesting.py new file mode 100644 index 00000000..35b7a365 --- /dev/null +++ b/rest_framework/tests/nesting.py @@ -0,0 +1,125 @@ +from __future__ import unicode_literals +from django.db import models +from django.test import TestCase +from rest_framework import serializers + + +class OneToOneTarget(models.Model): +    name = models.CharField(max_length=100) + + +class OneToOneTargetSource(models.Model): +    name = models.CharField(max_length=100) +    target = models.OneToOneField(OneToOneTarget, null=True, blank=True, +                                  related_name='target_source') + + +class OneToOneSource(models.Model): +    name = models.CharField(max_length=100) +    target_source = models.OneToOneField(OneToOneTargetSource, related_name='source') + + +class OneToOneSourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = OneToOneSource +        exclude = ('target_source', ) + + +class OneToOneTargetSourceSerializer(serializers.ModelSerializer): +    source = OneToOneSourceSerializer() + +    class Meta: +        model = OneToOneTargetSource +        exclude = ('target', ) + +class OneToOneTargetSerializer(serializers.ModelSerializer): +    target_source = OneToOneTargetSourceSerializer() + +    class Meta: +        model = OneToOneTarget + + +class NestedOneToOneTests(TestCase): +    def setUp(self): +        for idx in range(1, 4): +            target = OneToOneTarget(name='target-%d' % idx) +            target.save() +            target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target) +            target_source.save() +            source = OneToOneSource(name='source-%d' % idx, target_source=target_source) +            source.save() + +    def test_one_to_one_retrieve(self): +        queryset = OneToOneTarget.objects.all() +        serializer = OneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, +            {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, +            {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}             +        ] +        self.assertEqual(serializer.data, expected) +         + +    def test_one_to_one_create(self): +        data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} +        serializer = OneToOneTargetSerializer(data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-4') + +        # Ensure (target 4, target_source 4, source 4) are added, and +        # everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = OneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, +            {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, +            {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}, +            {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_create_with_invalid_data(self): +        data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}} +        serializer = OneToOneTargetSerializer(data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]}) + +    def test_one_to_one_update(self): +        data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} +        instance = OneToOneTarget.objects.get(pk=3) +        serializer = OneToOneTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() +        self.assertEqual(serializer.data, data) +        self.assertEqual(obj.name, 'target-3-updated') + +        # Ensure (target 3, target_source 3, source 3) are updated, +        # and everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = OneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, +            {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, +            {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} +        ] +        self.assertEqual(serializer.data, expected) + +    def test_one_to_one_delete(self): +        data = {'id': 3, 'name': 'target-3', 'target_source': None} +        instance = OneToOneTarget.objects.get(pk=3) +        serializer = OneToOneTargetSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        obj = serializer.save() + +        # Ensure (target_source 3, source 3) are deleted, +        # and everything else is as expected. +        queryset = OneToOneTarget.objects.all() +        serializer = OneToOneTargetSerializer(queryset) +        expected = [ +            {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, +            {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, +            {'id': 3, 'name': 'target-3', 'target_source': None} +        ] +        self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/serializer_nested.py b/rest_framework/tests/serializer_nested.py index fcf644c7..299c3bc5 100644 --- a/rest_framework/tests/serializer_nested.py +++ b/rest_framework/tests/serializer_nested.py @@ -124,7 +124,7 @@ class WritableNestedSerializerObjectTests(TestCase):              def __init__(self, order, title, duration):                  self.order, self.title, self.duration = order, title, duration -            def __cmp__(self, other): +            def __eq__(self, other):                  return (                      self.order == other.order and                      self.title == other.title and @@ -135,7 +135,7 @@ class WritableNestedSerializerObjectTests(TestCase):              def __init__(self, album_name, artist, tracks):                  self.album_name, self.artist, self.tracks = album_name, artist, tracks -            def __cmp__(self, other): +            def __eq__(self, other):                  return (                      self.album_name == other.album_name and                      self.artist == other.artist and  | 
