aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-04-08 14:48:45 -0700
committerTom Christie2013-04-08 14:48:45 -0700
commitce8ffd390a61bbeab0cbff1b52070a1076c94988 (patch)
tree0c9db3ae55236e282c45e8f1cddcbb6c58ca30a2 /rest_framework
parentd97e72cdb2f4fcc5aa2c19527a2b2ff11cf784bb (diff)
parent73efa96de983fc644328d2fc498651aa917a2272 (diff)
downloaddjango-rest-framework-ce8ffd390a61bbeab0cbff1b52070a1076c94988.tar.bz2
Merge pull request #753 from maspwr/writable-nested-modelserializer
one-many writable nested modelserializer
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/serializers.py56
-rw-r--r--rest_framework/tests/relations_nested.py98
-rw-r--r--rest_framework/tests/serializer_bulk_update.py6
3 files changed, 139 insertions, 21 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 668bcc49..73cad00f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -130,14 +130,14 @@ class BaseSerializer(WritableField):
def __init__(self, instance=None, data=None, files=None,
context=None, partial=False, many=None,
- allow_delete=False, **kwargs):
+ allow_add_remove=False, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
- self.allow_delete = allow_delete
+ self.allow_add_remove = allow_add_remove
self.context = context or {}
@@ -154,8 +154,8 @@ class BaseSerializer(WritableField):
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
- if allow_delete and not many:
- raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True')
+ if allow_add_remove and not many:
+ raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
@@ -288,8 +288,15 @@ class BaseSerializer(WritableField):
You should override this method to control how deserialized objects
are instantiated.
"""
+ removed_relations = []
+
+ # Deleted related objects
+ if self._deleted:
+ removed_relations = list(self._deleted)
+
if instance is not None:
instance.update(attrs)
+ instance._removed_relations = removed_relations
return instance
return attrs
@@ -377,6 +384,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None
+ obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if value in (None, ''):
into[(self.source or field_name)] = None
@@ -386,7 +394,8 @@ class BaseSerializer(WritableField):
'data': value,
'context': self.context,
'partial': self.partial,
- 'many': self.many
+ 'many': self.many,
+ 'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
@@ -496,6 +505,9 @@ class BaseSerializer(WritableField):
def save_object(self, obj, **kwargs):
obj.save(**kwargs)
+ if self.allow_add_remove and hasattr(obj, '_removed_relations'):
+ [self.delete_object(item) for item in obj._removed_relations]
+
def delete_object(self, obj):
obj.delete()
@@ -508,7 +520,7 @@ class BaseSerializer(WritableField):
else:
self.save_object(self.object, **kwargs)
- if self.allow_delete and self._deleted:
+ if self.allow_add_remove and self._deleted:
[self.delete_object(item) for item in self._deleted]
return self.object
@@ -699,6 +711,7 @@ class ModelSerializer(Serializer):
m2m_data = {}
related_data = {}
nested_forward_relations = {}
+ removed_relations = []
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
@@ -724,6 +737,10 @@ class ModelSerializer(Serializer):
if isinstance(self.fields.get(field_name, None), Serializer):
nested_forward_relations[field_name] = attrs[field_name]
+ # Deleted related objects
+ if self._deleted:
+ removed_relations = list(self._deleted)
+
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
@@ -740,6 +757,7 @@ class ModelSerializer(Serializer):
instance._related_data = related_data
instance._m2m_data = m2m_data
instance._nested_forward_relations = nested_forward_relations
+ instance._removed_relations = removed_relations
return instance
@@ -764,6 +782,9 @@ class ModelSerializer(Serializer):
obj.save(**kwargs)
+ if self.allow_add_remove and hasattr(obj, '_removed_relations'):
+ [self.delete_object(item) for item in obj._removed_relations]
+
if getattr(obj, '_m2m_data', None):
for accessor_name, object_list in obj._m2m_data.items():
setattr(obj, accessor_name, object_list)
@@ -773,18 +794,17 @@ class ModelSerializer(Serializer):
for accessor_name, related in obj._related_data.items():
field = self.fields.get(accessor_name, None)
if isinstance(field, Serializer):
- # TODO: Following will be needed for reverse FK
- # if field.many:
- # # Nested reverse fk relationship
- # for related_item in related:
- # fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
- # setattr(related_item, fk_field, obj)
- # self.save_object(related_item)
- # else:
- # Nested reverse one-one relationship
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
- setattr(related, fk_field, obj)
- self.save_object(related)
+ if field.many:
+ # Nested reverse fk relationship
+ for related_item in related:
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related_item, fk_field, obj)
+ self.save_object(related_item)
+ else:
+ # Nested reverse one-one relationship
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related, fk_field, obj)
+ self.save_object(related)
else:
# Reverse FK or reverse one-one
setattr(obj, accessor_name, related)
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index e7af6565..20683d4a 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -13,6 +13,15 @@ class OneToOneSource(models.Model):
target = models.OneToOneField(OneToOneTarget, related_name='source')
+class OneToManyTarget(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class OneToManySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(OneToManyTarget, related_name='sources')
+
+
class ReverseNestedOneToOneTests(TestCase):
def setUp(self):
class OneToOneSourceSerializer(serializers.ModelSerializer):
@@ -189,3 +198,92 @@ class ForwardNestedOneToOneTests(TestCase):
# {'id': 3, 'name': 'target-3', 'source': None}
# ]
# self.assertEqual(serializer.data, expected)
+
+
+class ReverseNestedOneToManyTests(TestCase):
+ def setUp(self):
+ class OneToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToManySource
+ fields = ('id', 'name')
+
+ class OneToManyTargetSerializer(serializers.ModelSerializer):
+ sources = OneToManySourceSerializer(many=True, allow_add_remove=True)
+
+ class Meta:
+ model = OneToManyTarget
+ fields = ('id', 'name', 'sources')
+
+ self.Serializer = OneToManyTargetSerializer
+
+ target = OneToManyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ source = OneToManySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_one_to_many_retrieve(self):
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1')
+
+ # Ensure source 4 is added, and everything else is as
+ # expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4, 'name': 'source-4'}]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_one_to_many_create_with_invalid_data(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'},
+ {'id': 4}]}
+ serializer = self.Serializer(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]})
+
+ def test_one_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+ instance = OneToManyTarget.objects.get(pk=1)
+ serializer = self.Serializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-1-updated')
+
+ # Ensure (target 1, source 1) are updated,
+ # and everything else is as expected.
+ queryset = OneToManyTarget.objects.all()
+ serializer = self.Serializer(queryset)
+ expected = [
+ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'},
+ {'id': 2, 'name': 'source-2'},
+ {'id': 3, 'name': 'source-3'}]}
+
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py
index afc1a1a9..5328e733 100644
--- a/rest_framework/tests/serializer_bulk_update.py
+++ b/rest_framework/tests/serializer_bulk_update.py
@@ -201,7 +201,7 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami'
}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
@@ -223,7 +223,7 @@ class BulkUpdateSerializerTests(TestCase):
'author': 'Haruki Murakami'
}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), True)
self.assertEqual(serializer.data, data)
serializer.save()
@@ -249,6 +249,6 @@ class BulkUpdateSerializerTests(TestCase):
{},
{'id': ['Enter a whole number.']}
]
- serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True)
+ serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True)
self.assertEqual(serializer.is_valid(), False)
self.assertEqual(serializer.errors, expected_errors)