diff options
| author | Mark Aaron Shirley | 2013-01-04 12:50:01 +0100 | 
|---|---|---|
| committer | Mark Aaron Shirley | 2013-01-16 16:04:19 -0800 | 
| commit | 72c04d570d167209f3f34d6d78492426f206b245 (patch) | |
| tree | 20802be8afbc8846940cd91a647d7ae5fb5b39fc /rest_framework/serializers.py | |
| parent | 0f0a07b732a4bd90957c08b01d51e70c7e739d5d (diff) | |
| download | django-rest-framework-72c04d570d167209f3f34d6d78492426f206b245.tar.bz2 | |
Add nested create for 1to1 reverse relationships
Diffstat (limited to 'rest_framework/serializers.py')
| -rw-r--r-- | rest_framework/serializers.py | 46 | 
1 files changed, 38 insertions, 8 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 27458f96..a43a81d7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -93,7 +93,7 @@ class SerializerOptions(object):          self.exclude = getattr(meta, 'exclude', ()) -class BaseSerializer(Field): +class BaseSerializer(WritableField):      class Meta(object):          pass @@ -218,7 +218,10 @@ class BaseSerializer(Field):              try:                  field.field_from_native(data, files, field_name, reverted_data)              except ValidationError as err: -                self._errors[field_name] = list(err.messages) +                if hasattr(err, 'message_dict'): +                    self._errors[field_name] = [err.message_dict] +                else: +                    self._errors[field_name] = list(err.messages)          return reverted_data @@ -369,6 +372,25 @@ class ModelSerializer(Serializer):      """      _options_class = ModelSerializerOptions +    def field_from_native(self, data, files, field_name, into): +        if self.read_only: +            return + +        try: +            native = data[field_name] +        except KeyError: +            if self.required: +                raise ValidationError(self.error_messages['required']) +            return + +        obj = self.from_native(native, files) +        if not self._errors: +            self.object = obj +            into[self.source or field_name] = self +        else: +            # Propagate errors up to our parent +            raise ValidationError(self._errors) +      def get_default_fields(self):          """          Return all the fields that should be serialized for the model. @@ -542,10 +564,9 @@ class ModelSerializer(Serializer):          return instance -    def save(self): -        """ -        Save the deserialized object and return it. -        """ +    def _save(self, parent=None, fk_field=None): +        if parent and fk_field: +            setattr(self.object, fk_field, parent)          self.object.save()          if getattr(self, 'm2m_data', None): @@ -555,9 +576,18 @@ class ModelSerializer(Serializer):          if getattr(self, 'related_data', None):              for accessor_name, object_list in self.related_data.items(): -                setattr(self.object, accessor_name, object_list) +                if isinstance(object_list, ModelSerializer): +                    fk_field = self.object._meta.get_field_by_name(accessor_name)[0].field.name +                    object_list._save(parent=self.object, fk_field=fk_field) +                else: +                    setattr(self.object, accessor_name, object_list)              self.related_data = {} - +             +    def save(self): +        """ +        Save the deserialized object and return it. +        """ +        self._save()          return self.object  | 
