From 72c04d570d167209f3f34d6d78492426f206b245 Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Fri, 4 Jan 2013 12:50:01 +0100 Subject: Add nested create for 1to1 reverse relationships --- rest_framework/serializers.py | 46 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) (limited to 'rest_framework/serializers.py') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 27458f96..a43a81d7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -93,7 +93,7 @@ class SerializerOptions(object): self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField): class Meta(object): pass @@ -218,7 +218,10 @@ class BaseSerializer(Field): try: field.field_from_native(data, files, field_name, reverted_data) except ValidationError as err: - self._errors[field_name] = list(err.messages) + if hasattr(err, 'message_dict'): + self._errors[field_name] = [err.message_dict] + else: + self._errors[field_name] = list(err.messages) return reverted_data @@ -369,6 +372,25 @@ class ModelSerializer(Serializer): """ _options_class = ModelSerializerOptions + def field_from_native(self, data, files, field_name, into): + if self.read_only: + return + + try: + native = data[field_name] + except KeyError: + if self.required: + raise ValidationError(self.error_messages['required']) + return + + obj = self.from_native(native, files) + if not self._errors: + self.object = obj + into[self.source or field_name] = self + else: + # Propagate errors up to our parent + raise ValidationError(self._errors) + def get_default_fields(self): """ Return all the fields that should be serialized for the model. @@ -542,10 +564,9 @@ class ModelSerializer(Serializer): return instance - def save(self): - """ - Save the deserialized object and return it. - """ + def _save(self, parent=None, fk_field=None): + if parent and fk_field: + setattr(self.object, fk_field, parent) self.object.save() if getattr(self, 'm2m_data', None): @@ -555,9 +576,18 @@ class ModelSerializer(Serializer): if getattr(self, 'related_data', None): for accessor_name, object_list in self.related_data.items(): - setattr(self.object, accessor_name, object_list) + if isinstance(object_list, ModelSerializer): + fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name + object_list._save(parent=self.object, fk_field=fk_field) + else: + setattr(self.object, accessor_name, object_list) self.related_data = {} - + + def save(self): + """ + Save the deserialized object and return it. + """ + self._save() return self.object -- cgit v1.2.3