diff options
| author | Tom Christie | 2014-10-08 12:17:30 +0100 |
|---|---|---|
| committer | Tom Christie | 2014-10-08 12:17:30 +0100 |
| commit | 0cbb57b40fdb073c7ca09c9d1078926260c646db (patch) | |
| tree | ecb869965a0fb31428382d0f2ab60855e8f97268 /rest_framework | |
| parent | af0f01c5b6597fe2f146268f7632f7e3954d17c2 (diff) | |
| download | django-rest-framework-0cbb57b40fdb073c7ca09c9d1078926260c646db.tar.bz2 | |
Tweak pre/post save hooks. Return instance in .update().
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/mixins.py | 13 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 24 |
2 files changed, 22 insertions, 15 deletions
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 03ebb034..4c62debb 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -20,11 +20,11 @@ class CreateModelMixin(object): def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) - self.create_valid(serializer) + self.perform_create(serializer) headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) - def create_valid(self, serializer): + def perform_create(self, serializer): serializer.save() def get_success_headers(self, data): @@ -67,10 +67,10 @@ class UpdateModelMixin(object): instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) - self.update_valid(serializer) + self.preform_update(serializer) return Response(serializer.data) - def update_valid(self, serializer): + def preform_update(self, serializer): serializer.save() def partial_update(self, request, *args, **kwargs): @@ -84,9 +84,12 @@ class DestroyModelMixin(object): """ def destroy(self, request, *args, **kwargs): instance = self.get_object() - instance.delete() + self.perform_destroy(instance) return Response(status=status.HTTP_204_NO_CONTENT) + def perform_destroy(self, instance): + instance.delete() + # The AllowPUTAsCreateMixin was previously the default behaviour # for PUT requests. This has now been removed and must be *explicitly* diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 3d868a9e..e7cd50d6 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -83,7 +83,10 @@ class BaseSerializer(Field): ) if self.instance is not None: - self.update(self.instance, validated_data) + self.instance = self.update(self.instance, validated_data) + assert self.instance is not None, ( + '`update()` did not return an object instance.' + ) else: self.instance = self.create(validated_data) assert self.instance is not None, ( @@ -444,19 +447,19 @@ class ModelSerializer(Serializer): self.validators.extend(validators) self._kwargs['validators'] = validators - def create(self, attrs): + def create(self, validated_attrs): ModelClass = self.Meta.model - # Remove many-to-many relationships from attrs. + # Remove many-to-many relationships from validated_attrs. # They are not valid arguments to the default `.create()` method, # as they require that the instance has already been saved. info = model_meta.get_field_info(ModelClass) many_to_many = {} for field_name, relation_info in info.relations.items(): - if relation_info.to_many and (field_name in attrs): - many_to_many[field_name] = attrs.pop(field_name) + if relation_info.to_many and (field_name in validated_attrs): + many_to_many[field_name] = validated_attrs.pop(field_name) - instance = ModelClass.objects.create(**attrs) + instance = ModelClass.objects.create(**validated_attrs) # Save many-to-many relationships after the instance is created. if many_to_many: @@ -465,10 +468,11 @@ class ModelSerializer(Serializer): return instance - def update(self, obj, attrs): - for attr, value in attrs.items(): - setattr(obj, attr, value) - obj.save() + def update(self, instance, validated_attrs): + for attr, value in validated_attrs.items(): + setattr(instance, attr, value) + instance.save() + return instance def get_unique_together_validators(self): field_names = set([ |
