aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r--rest_framework/serializers.py65
1 files changed, 53 insertions, 12 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 31cfa344..fde06d83 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -32,6 +32,9 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class RelationsList(list):
+ _deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -161,7 +164,6 @@ class BaseSerializer(WritableField):
self._data = None
self._files = None
self._errors = None
- self._deleted = None
if many and instance is not None and not hasattr(instance, '__iter__'):
raise ValueError('instance should be a queryset or other iterable with many=True')
@@ -336,9 +338,9 @@ class BaseSerializer(WritableField):
value = obj
for component in source.split('.'):
- value = get_component(value, component)
if value is None:
- break
+ return self.to_native(None)
+ value = get_component(value, component)
except ObjectDoesNotExist:
return None
@@ -378,6 +380,7 @@ class BaseSerializer(WritableField):
# Set the serializer object if it exists
obj = getattr(self.parent.object, field_name) if self.parent.object else None
+ obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
if self.source == '*':
if value:
@@ -391,7 +394,8 @@ class BaseSerializer(WritableField):
'data': value,
'context': self.context,
'partial': self.partial,
- 'many': self.many
+ 'many': self.many,
+ 'allow_add_remove': self.allow_add_remove
}
serializer = self.__class__(**kwargs)
@@ -434,7 +438,7 @@ class BaseSerializer(WritableField):
DeprecationWarning, stacklevel=3)
if many:
- ret = []
+ ret = RelationsList()
errors = []
update = self.object is not None
@@ -461,8 +465,8 @@ class BaseSerializer(WritableField):
ret.append(self.from_native(item, None))
errors.append(self._errors)
- if update:
- self._deleted = identity_to_objects.values()
+ if update and self.allow_add_remove:
+ ret._deleted = identity_to_objects.values()
self._errors = any(errors) and errors or []
else:
@@ -514,12 +518,12 @@ class BaseSerializer(WritableField):
"""
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
+
+ if self.object._deleted:
+ [self.delete_object(item) for item in self.object._deleted]
else:
self.save_object(self.object, **kwargs)
- if self.allow_add_remove and self._deleted:
- [self.delete_object(item) for item in self._deleted]
-
return self.object
def metadata(self):
@@ -795,9 +799,12 @@ class ModelSerializer(Serializer):
cls = self.opts.model
opts = get_concrete_model(cls)._meta
exclusions = [field.name for field in opts.fields + opts.many_to_many]
+
for field_name, field in self.fields.items():
field_name = field.source or field_name
- if field_name in exclusions and not field.read_only:
+ if field_name in exclusions \
+ and not field.read_only \
+ and not isinstance(field, Serializer):
exclusions.remove(field_name)
return exclusions
@@ -823,6 +830,7 @@ class ModelSerializer(Serializer):
"""
m2m_data = {}
related_data = {}
+ nested_forward_relations = {}
meta = self.opts.model._meta
# Reverse fk or one-to-one relations
@@ -842,6 +850,12 @@ class ModelSerializer(Serializer):
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
+ # Nested forward relations - These need to be marked so we can save
+ # them before saving the parent model instance.
+ for field_name in attrs.keys():
+ if isinstance(self.fields.get(field_name, None), Serializer):
+ nested_forward_relations[field_name] = attrs[field_name]
+
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
@@ -857,6 +871,7 @@ class ModelSerializer(Serializer):
# at the point of save.
instance._related_data = related_data
instance._m2m_data = m2m_data
+ instance._nested_forward_relations = nested_forward_relations
return instance
@@ -872,6 +887,14 @@ class ModelSerializer(Serializer):
"""
Save the deserialized object and return it.
"""
+ if getattr(obj, '_nested_forward_relations', None):
+ # Nested relationships need to be saved before we can save the
+ # parent instance.
+ for field_name, sub_object in obj._nested_forward_relations.items():
+ if sub_object:
+ self.save_object(sub_object)
+ setattr(obj, field_name, sub_object)
+
obj.save(**kwargs)
if getattr(obj, '_m2m_data', None):
@@ -881,7 +904,25 @@ class ModelSerializer(Serializer):
if getattr(obj, '_related_data', None):
for accessor_name, related in obj._related_data.items():
- setattr(obj, accessor_name, related)
+ if isinstance(related, RelationsList):
+ # Nested reverse fk relationship
+ for related_item in related:
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related_item, fk_field, obj)
+ self.save_object(related_item)
+
+ # Delete any removed objects
+ if related._deleted:
+ [self.delete_object(item) for item in related._deleted]
+
+ elif isinstance(related, models.Model):
+ # Nested reverse one-one relationship
+ fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ setattr(related, fk_field, obj)
+ self.save_object(related)
+ else:
+ # Reverse FK or reverse one-one
+ setattr(obj, accessor_name, related)
del(obj._related_data)