aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-03-13 20:40:39 +0000
committerTom Christie2013-03-13 20:40:39 +0000
commitacc8c1faa4f85dda00723d755e56bb3c980dbc75 (patch)
tree963418ba768dadd1ff2b5912fe819f05c2288ddb /rest_framework
parenta53596ce28359e24313a5fb9bd8f3564eb12678e (diff)
downloaddjango-rest-framework-acc8c1faa4f85dda00723d755e56bb3c980dbc75.tar.bz2
force_insert, force_update arguments. Closes #484.
Confirmed by `assertNumQueries(…)` in tests.
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/mixins.py6
-rw-r--r--rest_framework/serializers.py14
-rw-r--r--rest_framework/tests/generics.py10
3 files changed, 16 insertions, 14 deletions
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 8e401204..7d9a6e65 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -44,7 +44,7 @@ class CreateModelMixin(object):
if serializer.is_valid():
self.pre_save(serializer.object)
- self.object = serializer.save()
+ self.object = serializer.save(force_insert=True)
self.post_save(self.object, created=True)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED,
@@ -119,9 +119,11 @@ class UpdateModelMixin(object):
# we have relevant permissions, as if this was a POST request.
self.check_permissions(clone_request(request, 'POST'))
created = True
+ save_kwargs = {'force_insert': True}
success_status_code = status.HTTP_201_CREATED
else:
created = False
+ save_kwargs = {'force_update': True}
success_status_code = status.HTTP_200_OK
serializer = self.get_serializer(self.object, data=request.DATA,
@@ -129,7 +131,7 @@ class UpdateModelMixin(object):
if serializer.is_valid():
self.pre_save(serializer.object)
- self.object = serializer.save()
+ self.object = serializer.save(**save_kwargs)
self.post_save(self.object, created=created)
return Response(serializer.data, status=success_status_code)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index cd2bb8f1..4fe857a6 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -391,17 +391,17 @@ class BaseSerializer(Field):
return self._data
- def save_object(self, obj):
- obj.save()
+ def save_object(self, obj, **kwargs):
+ obj.save(**kwargs)
- def save(self):
+ def save(self, **kwargs):
"""
Save the deserialized object and return it.
"""
if isinstance(self.object, list):
- [self.save_object(item) for item in self.object]
+ [self.save_object(item, **kwargs) for item in self.object]
else:
- self.save_object(self.object)
+ self.save_object(self.object, **kwargs)
return self.object
@@ -621,11 +621,11 @@ class ModelSerializer(Serializer):
if instance:
return self.full_clean(instance)
- def save_object(self, obj):
+ def save_object(self, obj, **kwargs):
"""
Save the deserialized object and return it.
"""
- obj.save()
+ obj.save(**kwargs)
if getattr(self, 'm2m_data', None):
for accessor_name, object_list in self.m2m_data.items():
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index 1837898b..f564890c 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -184,7 +184,7 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- with self.assertNumQueries(3):
+ with self.assertNumQueries(2):
response = self.view(request, pk='1').render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@@ -199,7 +199,7 @@ class TestInstanceView(TestCase):
request = factory.patch('/1', json.dumps(content),
content_type='application/json')
- with self.assertNumQueries(3):
+ with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@@ -248,7 +248,7 @@ class TestInstanceView(TestCase):
content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- with self.assertNumQueries(3):
+ with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@@ -264,7 +264,7 @@ class TestInstanceView(TestCase):
content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content),
content_type='application/json')
- with self.assertNumQueries(4):
+ with self.assertNumQueries(3):
response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
@@ -280,7 +280,7 @@ class TestInstanceView(TestCase):
# pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', json.dumps(content),
content_type='application/json')
- with self.assertNumQueries(4):
+ with self.assertNumQueries(3):
response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)