aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r--rest_framework/serializers.py217
1 files changed, 175 insertions, 42 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 4fe857a6..1b2b0821 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -20,6 +20,25 @@ from rest_framework.relations import *
from rest_framework.fields import *
+class NestedValidationError(ValidationError):
+ """
+ The default ValidationError behavior is to stringify each item in the list
+ if the messages are a list of error messages.
+
+ In the case of nested serializers, where the parent has many children,
+ then the child's `serializer.errors` will be a list of dicts. In the case
+ of a single child, the `serializer.errors` will be a dict.
+
+ We need to override the default behavior to get properly nested error dicts.
+ """
+
+ def __init__(self, message):
+ if isinstance(message, dict):
+ self.messages = [message]
+ else:
+ self.messages = message
+
+
class DictWithMetadata(dict):
"""
A dict-like object, that can have additional properties attached.
@@ -98,7 +117,7 @@ class SerializerOptions(object):
self.exclude = getattr(meta, 'exclude', ())
-class BaseSerializer(Field):
+class BaseSerializer(WritableField):
"""
This is the Serializer implementation.
We need to implement it as `BaseSerializer` due to metaclass magicks.
@@ -110,13 +129,15 @@ class BaseSerializer(Field):
_dict_class = SortedDictWithMetadata
def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=None, source=None):
- super(BaseSerializer, self).__init__(source=source)
+ context=None, partial=False, many=None,
+ allow_add_remove=False, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.partial = partial
self.many = many
+ self.allow_add_remove = allow_add_remove
self.context = context or {}
@@ -128,6 +149,13 @@ class BaseSerializer(Field):
self._data = None
self._files = None
self._errors = None
+ self._deleted = None
+
+ if many and instance is not None and not hasattr(instance, '__iter__'):
+ raise ValueError('instance should be a queryset or other iterable with many=True')
+
+ if allow_add_remove and not many:
+ raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
#####
# Methods to determine which fields to use when (de)serializing objects.
@@ -296,40 +324,91 @@ class BaseSerializer(Field):
def field_to_native(self, obj, field_name):
"""
- Override default so that we can apply ModelSerializer as a nested
- field to relationships.
+ Override default so that the serializer can be used as a nested field
+ across relationships.
"""
if self.source == '*':
return self.to_native(obj)
try:
- if self.source:
- for component in self.source.split('.'):
- obj = getattr(obj, component)
- if is_simple_callable(obj):
- obj = obj()
- else:
- obj = getattr(obj, field_name)
- if is_simple_callable(obj):
- obj = obj()
+ source = self.source or field_name
+ value = obj
+
+ for component in source.split('.'):
+ value = get_component(value, component)
+ if value is None:
+ break
except ObjectDoesNotExist:
return None
- # If the object has an "all" method, assume it's a relationship
- if is_simple_callable(getattr(obj, 'all', None)):
- return [self.to_native(item) for item in obj.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
- if obj is None:
+ if value is None:
return None
if self.many is not None:
many = self.many
else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict, six.text_type))
+ many = hasattr(value, '__iter__') and not isinstance(value, (Page, dict, six.text_type))
if many:
- return [self.to_native(item) for item in obj]
- return self.to_native(obj)
+ return [self.to_native(item) for item in value]
+ return self.to_native(value)
+
+ def field_from_native(self, data, files, field_name, into):
+ """
+ Override default so that the serializer can be used as a writable
+ nested field across relationships.
+ """
+ if self.read_only:
+ return
+
+ try:
+ value = data[field_name]
+ except KeyError:
+ if self.default is not None and not self.partial:
+ # Note: partial updates shouldn't set defaults
+ value = copy.deepcopy(self.default)
+ else:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
+
+ # Set the serializer object if it exists
+ obj = getattr(self.parent.object, field_name) if self.parent.object else None
+
+ if value in (None, ''):
+ into[(self.source or field_name)] = None
+ else:
+ kwargs = {
+ 'instance': obj,
+ 'data': value,
+ 'context': self.context,
+ 'partial': self.partial,
+ 'many': self.many
+ }
+ serializer = self.__class__(**kwargs)
+
+ if serializer.is_valid():
+ into[self.source or field_name] = serializer.object
+ else:
+ # Propagate errors up to our parent
+ raise NestedValidationError(serializer.errors)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ It is used to determine the canonical identity of a given object.
+
+ Note that the data has not been validated at this point, so we need
+ to make sure that we catch any cases of incorrect datatypes being
+ passed to this method.
+ """
+ try:
+ return data.get('id', None)
+ except AttributeError:
+ return None
@property
def errors(self):
@@ -352,10 +431,37 @@ class BaseSerializer(Field):
if many:
ret = []
errors = []
- for item in data:
- ret.append(self.from_native(item, None))
- errors.append(self._errors)
- self._errors = any(errors) and errors or []
+ update = self.object is not None
+
+ if update:
+ # If this is a bulk update we need to map all the objects
+ # to a canonical identity so we can determine which
+ # individual object is being updated for each item in the
+ # incoming data
+ objects = self.object
+ identities = [self.get_identity(self.to_native(obj)) for obj in objects]
+ identity_to_objects = dict(zip(identities, objects))
+
+ if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
+ for item in data:
+ if update:
+ # Determine which object we're updating
+ identity = self.get_identity(item)
+ self.object = identity_to_objects.pop(identity, None)
+ if self.object is None and not self.allow_add_remove:
+ ret.append(None)
+ errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
+ continue
+
+ ret.append(self.from_native(item, None))
+ errors.append(self._errors)
+
+ if update:
+ self._deleted = identity_to_objects.values()
+
+ self._errors = any(errors) and errors or []
+ else:
+ self._errors = {'non_field_errors': ['Expected a list of items.']}
else:
ret = self.from_native(data, files)
@@ -394,6 +500,9 @@ class BaseSerializer(Field):
def save_object(self, obj, **kwargs):
obj.save(**kwargs)
+ def delete_object(self, obj):
+ obj.delete()
+
def save(self, **kwargs):
"""
Save the deserialized object and return it.
@@ -402,6 +511,10 @@ class BaseSerializer(Field):
[self.save_object(item, **kwargs) for item in self.object]
else:
self.save_object(self.object, **kwargs)
+
+ if self.allow_add_remove and self._deleted:
+ [self.delete_object(item) for item in self._deleted]
+
return self.object
@@ -584,33 +697,43 @@ class ModelSerializer(Serializer):
"""
Restore the model instance.
"""
- self.m2m_data = {}
- self.related_data = {}
+ m2m_data = {}
+ related_data = {}
+ meta = self.opts.model._meta
- # Reverse fk relations
- for (obj, model) in self.opts.model._meta.get_all_related_objects_with_model():
+ # Reverse fk or one-to-one relations
+ for (obj, model) in meta.get_all_related_objects_with_model():
field_name = obj.field.related_query_name()
if field_name in attrs:
- self.related_data[field_name] = attrs.pop(field_name)
+ related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
- for (obj, model) in self.opts.model._meta.get_all_related_m2m_objects_with_model():
+ for (obj, model) in meta.get_all_related_m2m_objects_with_model():
field_name = obj.field.related_query_name()
if field_name in attrs:
- self.m2m_data[field_name] = attrs.pop(field_name)
+ m2m_data[field_name] = attrs.pop(field_name)
# Forward m2m relations
- for field in self.opts.model._meta.many_to_many:
+ for field in meta.many_to_many:
if field.name in attrs:
- self.m2m_data[field.name] = attrs.pop(field.name)
+ m2m_data[field.name] = attrs.pop(field.name)
+ # Update an existing instance...
if instance is not None:
for key, val in attrs.items():
setattr(instance, key, val)
+ # ...or create a new instance
else:
instance = self.opts.model(**attrs)
+ # Any relations that cannot be set until we've
+ # saved the model get hidden away on these
+ # private attributes, so we can deal with them
+ # at the point of save.
+ instance._related_data = related_data
+ instance._m2m_data = m2m_data
+
return instance
def from_native(self, data, files):
@@ -627,15 +750,15 @@ class ModelSerializer(Serializer):
"""
obj.save(**kwargs)
- if getattr(self, 'm2m_data', None):
- for accessor_name, object_list in self.m2m_data.items():
- setattr(self.object, accessor_name, object_list)
- self.m2m_data = {}
+ if getattr(obj, '_m2m_data', None):
+ for accessor_name, object_list in obj._m2m_data.items():
+ setattr(obj, accessor_name, object_list)
+ del(obj._m2m_data)
- if getattr(self, 'related_data', None):
- for accessor_name, object_list in self.related_data.items():
- setattr(self.object, accessor_name, object_list)
- self.related_data = {}
+ if getattr(obj, '_related_data', None):
+ for accessor_name, related in obj._related_data.items():
+ setattr(obj, accessor_name, related)
+ del(obj._related_data)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
@@ -690,3 +813,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
'many': to_many
}
return HyperlinkedRelatedField(**kwargs)
+
+ def get_identity(self, data):
+ """
+ This hook is required for bulk update.
+ We need to override the default, to use the url as the identity.
+ """
+ try:
+ return data.get('url', None)
+ except AttributeError:
+ return None