diff options
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 77 |
1 files changed, 63 insertions, 14 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9e3881a2..9c27717f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -6,8 +6,8 @@ form encoded input. Serialization in REST framework is a two-phase process: 1. Serializers marshal between complex types like model instances, and -python primatives. -2. The process of marshalling between python primatives and request and +python primitives. +2. The process of marshalling between python primitives and request and response content is handled by parsers and renderers. """ from __future__ import unicode_literals @@ -31,9 +31,17 @@ from rest_framework.relations import * from rest_framework.fields import * +def pretty_name(name): + """Converts 'first_name' to 'First name'""" + if not name: + return '' + return name.replace('_', ' ').capitalize() + + class RelationsList(list): _deleted = [] + class NestedValidationError(ValidationError): """ The default ValidationError behavior is to stringify each item in the list @@ -48,9 +56,13 @@ class NestedValidationError(ValidationError): def __init__(self, message): if isinstance(message, dict): - self.messages = [message] + self._messages = [message] else: - self.messages = message + self._messages = message + + @property + def messages(self): + return self._messages class DictWithMetadata(dict): @@ -254,10 +266,13 @@ class BaseSerializer(WritableField): for field_name, field in self.fields.items(): if field_name in self._errors: continue + + source = field.source or field_name + if self.partial and source not in attrs: + continue try: validate_method = getattr(self, 'validate_%s' % field_name, None) if validate_method: - source = field.source or field_name attrs = validate_method(attrs, source) except ValidationError as err: self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) @@ -300,14 +315,19 @@ class BaseSerializer(WritableField): """ ret = self._dict_class() ret.fields = self._dict_class() - ret.empty = obj is None for field_name, field in self.fields.items(): + if field.read_only and obj is None: + continue field.initialize(parent=self, field_name=field_name) key = self.get_field_key(field_name) value = field.field_to_native(obj, field_name) + method = getattr(self, 'transform_%s' % field_name, None) + if callable(method): + value = method(obj, value) ret[key] = value - ret.fields[key] = field + ret.fields[key] = self.augment_field(field, field_name, key, value) + return ret def from_native(self, data, files): @@ -315,6 +335,7 @@ class BaseSerializer(WritableField): Deserialize primitives -> objects. """ self._errors = {} + if data is not None or files is not None: attrs = self.restore_fields(data, files) if attrs is not None: @@ -325,6 +346,15 @@ class BaseSerializer(WritableField): if not self._errors: return self.restore_object(attrs, instance=getattr(self, 'object', None)) + def augment_field(self, field, field_name, key, value): + # This horrible stuff is to manage serializers rendering to HTML + field._errors = self._errors.get(key) if self._errors else None + field._name = field_name + field._value = self.init_data.get(key) if self._errors and self.init_data else value + if not field.label: + field.label = pretty_name(key) + return field + def field_to_native(self, obj, field_name): """ Override default so that the serializer can be used as a nested field @@ -375,8 +405,14 @@ class BaseSerializer(WritableField): return # Set the serializer object if it exists - obj = getattr(self.parent.object, field_name) if self.parent.object else None - obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj + obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None + + # If we have a model manager or similar object then we need + # to iterate through each instance. + if (self.many and + not hasattr(obj, '__iter__') and + is_simple_callable(getattr(obj, 'all', None))): + obj = obj.all() if self.source == '*': if value: @@ -503,6 +539,9 @@ class BaseSerializer(WritableField): """ Save the deserialized object and return it. """ + # Clear cached _data, which may be invalidated by `save()` + self._data = None + if isinstance(self.object, list): [self.save_object(item, **kwargs) for item in self.object] @@ -751,6 +790,8 @@ class ModelSerializer(Serializer): # TODO: TypedChoiceField? if model_field.flatchoices: # This ModelField contains choices kwargs['choices'] = model_field.flatchoices + if model_field.null: + kwargs['empty'] = None return ChoiceField(**kwargs) # put this below the ChoiceField because min_value isn't a valid initializer @@ -822,13 +863,13 @@ class ModelSerializer(Serializer): # Reverse fk or one-to-one relations for (obj, model) in meta.get_all_related_objects_with_model(): - field_name = obj.field.related_query_name() + field_name = obj.get_accessor_name() if field_name in attrs: related_data[field_name] = attrs.pop(field_name) # Reverse m2m relations for (obj, model) in meta.get_all_related_m2m_objects_with_model(): - field_name = obj.field.related_query_name() + field_name = obj.get_accessor_name() if field_name in attrs: m2m_data[field_name] = attrs.pop(field_name) @@ -846,7 +887,10 @@ class ModelSerializer(Serializer): # Update an existing instance... if instance is not None: for key, val in attrs.items(): - setattr(instance, key, val) + try: + setattr(instance, key, val) + except ValueError: + self._errors[key] = self.error_messages['required'] # ...or create a new instance else: @@ -872,7 +916,7 @@ class ModelSerializer(Serializer): def save_object(self, obj, **kwargs): """ - Save the deserialized object and return it. + Save the deserialized object. """ if getattr(obj, '_nested_forward_relations', None): # Nested relationships need to be saved before we can save the @@ -890,11 +934,16 @@ class ModelSerializer(Serializer): del(obj._m2m_data) if getattr(obj, '_related_data', None): + related_fields = dict([ + (field.get_accessor_name(), field) + for field, model + in obj._meta.get_all_related_objects_with_model() + ]) for accessor_name, related in obj._related_data.items(): if isinstance(related, RelationsList): # Nested reverse fk relationship for related_item in related: - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + fk_field = related_fields[accessor_name].field.name setattr(related_item, fk_field, obj) self.save_object(related_item) |
