aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/serializers.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/serializers.py')
-rw-r--r--rest_framework/serializers.py77
1 files changed, 63 insertions, 14 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 9e3881a2..9c27717f 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -6,8 +6,8 @@ form encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
-python primatives.
-2. The process of marshalling between python primatives and request and
+python primitives.
+2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
@@ -31,9 +31,17 @@ from rest_framework.relations import *
from rest_framework.fields import *
+def pretty_name(name):
+ """Converts 'first_name' to 'First name'"""
+ if not name:
+ return ''
+ return name.replace('_', ' ').capitalize()
+
+
class RelationsList(list):
_deleted = []
+
class NestedValidationError(ValidationError):
"""
The default ValidationError behavior is to stringify each item in the list
@@ -48,9 +56,13 @@ class NestedValidationError(ValidationError):
def __init__(self, message):
if isinstance(message, dict):
- self.messages = [message]
+ self._messages = [message]
else:
- self.messages = message
+ self._messages = message
+
+ @property
+ def messages(self):
+ return self._messages
class DictWithMetadata(dict):
@@ -254,10 +266,13 @@ class BaseSerializer(WritableField):
for field_name, field in self.fields.items():
if field_name in self._errors:
continue
+
+ source = field.source or field_name
+ if self.partial and source not in attrs:
+ continue
try:
validate_method = getattr(self, 'validate_%s' % field_name, None)
if validate_method:
- source = field.source or field_name
attrs = validate_method(attrs, source)
except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
@@ -300,14 +315,19 @@ class BaseSerializer(WritableField):
"""
ret = self._dict_class()
ret.fields = self._dict_class()
- ret.empty = obj is None
for field_name, field in self.fields.items():
+ if field.read_only and obj is None:
+ continue
field.initialize(parent=self, field_name=field_name)
key = self.get_field_key(field_name)
value = field.field_to_native(obj, field_name)
+ method = getattr(self, 'transform_%s' % field_name, None)
+ if callable(method):
+ value = method(obj, value)
ret[key] = value
- ret.fields[key] = field
+ ret.fields[key] = self.augment_field(field, field_name, key, value)
+
return ret
def from_native(self, data, files):
@@ -315,6 +335,7 @@ class BaseSerializer(WritableField):
Deserialize primitives -> objects.
"""
self._errors = {}
+
if data is not None or files is not None:
attrs = self.restore_fields(data, files)
if attrs is not None:
@@ -325,6 +346,15 @@ class BaseSerializer(WritableField):
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'object', None))
+ def augment_field(self, field, field_name, key, value):
+ # This horrible stuff is to manage serializers rendering to HTML
+ field._errors = self._errors.get(key) if self._errors else None
+ field._name = field_name
+ field._value = self.init_data.get(key) if self._errors and self.init_data else value
+ if not field.label:
+ field.label = pretty_name(key)
+ return field
+
def field_to_native(self, obj, field_name):
"""
Override default so that the serializer can be used as a nested field
@@ -375,8 +405,14 @@ class BaseSerializer(WritableField):
return
# Set the serializer object if it exists
- obj = getattr(self.parent.object, field_name) if self.parent.object else None
- obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj
+ obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
+
+ # If we have a model manager or similar object then we need
+ # to iterate through each instance.
+ if (self.many and
+ not hasattr(obj, '__iter__') and
+ is_simple_callable(getattr(obj, 'all', None))):
+ obj = obj.all()
if self.source == '*':
if value:
@@ -503,6 +539,9 @@ class BaseSerializer(WritableField):
"""
Save the deserialized object and return it.
"""
+ # Clear cached _data, which may be invalidated by `save()`
+ self._data = None
+
if isinstance(self.object, list):
[self.save_object(item, **kwargs) for item in self.object]
@@ -751,6 +790,8 @@ class ModelSerializer(Serializer):
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
+ if model_field.null:
+ kwargs['empty'] = None
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
@@ -822,13 +863,13 @@ class ModelSerializer(Serializer):
# Reverse fk or one-to-one relations
for (obj, model) in meta.get_all_related_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
related_data[field_name] = attrs.pop(field_name)
# Reverse m2m relations
for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.field.related_query_name()
+ field_name = obj.get_accessor_name()
if field_name in attrs:
m2m_data[field_name] = attrs.pop(field_name)
@@ -846,7 +887,10 @@ class ModelSerializer(Serializer):
# Update an existing instance...
if instance is not None:
for key, val in attrs.items():
- setattr(instance, key, val)
+ try:
+ setattr(instance, key, val)
+ except ValueError:
+ self._errors[key] = self.error_messages['required']
# ...or create a new instance
else:
@@ -872,7 +916,7 @@ class ModelSerializer(Serializer):
def save_object(self, obj, **kwargs):
"""
- Save the deserialized object and return it.
+ Save the deserialized object.
"""
if getattr(obj, '_nested_forward_relations', None):
# Nested relationships need to be saved before we can save the
@@ -890,11 +934,16 @@ class ModelSerializer(Serializer):
del(obj._m2m_data)
if getattr(obj, '_related_data', None):
+ related_fields = dict([
+ (field.get_accessor_name(), field)
+ for field, model
+ in obj._meta.get_all_related_objects_with_model()
+ ])
for accessor_name, related in obj._related_data.items():
if isinstance(related, RelationsList):
# Nested reverse fk relationship
for related_item in related:
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
+ fk_field = related_fields[accessor_name].field.name
setattr(related_item, fk_field, obj)
self.save_object(related_item)