aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2014-10-08 12:17:30 +0100
committerTom Christie2014-10-08 12:17:30 +0100
commit0cbb57b40fdb073c7ca09c9d1078926260c646db (patch)
treeecb869965a0fb31428382d0f2ab60855e8f97268 /rest_framework
parentaf0f01c5b6597fe2f146268f7632f7e3954d17c2 (diff)
downloaddjango-rest-framework-0cbb57b40fdb073c7ca09c9d1078926260c646db.tar.bz2
Tweak pre/post save hooks. Return instance in .update().
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/mixins.py13
-rw-r--r--rest_framework/serializers.py24
2 files changed, 22 insertions, 15 deletions
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 03ebb034..4c62debb 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -20,11 +20,11 @@ class CreateModelMixin(object):
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
- self.create_valid(serializer)
+ self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
- def create_valid(self, serializer):
+ def perform_create(self, serializer):
serializer.save()
def get_success_headers(self, data):
@@ -67,10 +67,10 @@ class UpdateModelMixin(object):
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
- self.update_valid(serializer)
+ self.preform_update(serializer)
return Response(serializer.data)
- def update_valid(self, serializer):
+ def preform_update(self, serializer):
serializer.save()
def partial_update(self, request, *args, **kwargs):
@@ -84,9 +84,12 @@ class DestroyModelMixin(object):
"""
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
- instance.delete()
+ self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT)
+ def perform_destroy(self, instance):
+ instance.delete()
+
# The AllowPUTAsCreateMixin was previously the default behaviour
# for PUT requests. This has now been removed and must be *explicitly*
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 3d868a9e..e7cd50d6 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -83,7 +83,10 @@ class BaseSerializer(Field):
)
if self.instance is not None:
- self.update(self.instance, validated_data)
+ self.instance = self.update(self.instance, validated_data)
+ assert self.instance is not None, (
+ '`update()` did not return an object instance.'
+ )
else:
self.instance = self.create(validated_data)
assert self.instance is not None, (
@@ -444,19 +447,19 @@ class ModelSerializer(Serializer):
self.validators.extend(validators)
self._kwargs['validators'] = validators
- def create(self, attrs):
+ def create(self, validated_attrs):
ModelClass = self.Meta.model
- # Remove many-to-many relationships from attrs.
+ # Remove many-to-many relationships from validated_attrs.
# They are not valid arguments to the default `.create()` method,
# as they require that the instance has already been saved.
info = model_meta.get_field_info(ModelClass)
many_to_many = {}
for field_name, relation_info in info.relations.items():
- if relation_info.to_many and (field_name in attrs):
- many_to_many[field_name] = attrs.pop(field_name)
+ if relation_info.to_many and (field_name in validated_attrs):
+ many_to_many[field_name] = validated_attrs.pop(field_name)
- instance = ModelClass.objects.create(**attrs)
+ instance = ModelClass.objects.create(**validated_attrs)
# Save many-to-many relationships after the instance is created.
if many_to_many:
@@ -465,10 +468,11 @@ class ModelSerializer(Serializer):
return instance
- def update(self, obj, attrs):
- for attr, value in attrs.items():
- setattr(obj, attr, value)
- obj.save()
+ def update(self, instance, validated_attrs):
+ for attr, value in validated_attrs.items():
+ setattr(instance, attr, value)
+ instance.save()
+ return instance
def get_unique_together_validators(self):
field_names = set([