aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2014-03-04 15:26:34 +0000
committerTom Christie2014-03-04 15:26:34 +0000
commit4edd39b2e4a6490e7eac17dd989418f745d8efc3 (patch)
treeed2fb9171bd9283d3314364caa0d005b08f6f4e2
parent24a688223240eb1e71db3c0f00cd621e80cb9fb2 (diff)
parentdea2766abac5ef55fa226f413711cfd49af2a745 (diff)
downloaddjango-rest-framework-4edd39b2e4a6490e7eac17dd989418f745d8efc3.tar.bz2
Merge pull request #1442 from Anton-Shutik/master
RelatedField default value handling fixed
-rw-r--r--rest_framework/fields.py10
-rw-r--r--rest_framework/relations.py10
-rw-r--r--rest_framework/tests/test_serializer.py52
3 files changed, 67 insertions, 5 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 05daaab7..68b95682 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -301,6 +301,11 @@ class WritableField(Field):
result.validators = self.validators[:]
return result
+ def get_default_value(self):
+ if is_simple_callable(self.default):
+ return self.default()
+ return self.default
+
def validate(self, value):
if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required'])
@@ -349,10 +354,7 @@ class WritableField(Field):
except KeyError:
if self.default is not None and not self.partial:
# Note: partial updates shouldn't set defaults
- if is_simple_callable(self.default):
- native = self.default()
- else:
- native = self.default
+ native = self.get_default_value()
else:
if self.required:
raise ValidationError(self.error_messages['required'])
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 163a8984..308545ce 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -119,6 +119,14 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices)
+ ### Default value handling
+
+ def get_default_value(self):
+ default = super(RelatedField, self).get_default_value()
+ if self.many and default is None:
+ return []
+ return default
+
### Regular serializer stuff...
def field_to_native(self, obj, field_name):
@@ -167,7 +175,7 @@ class RelatedField(WritableField):
except KeyError:
if self.partial:
return
- value = [] if self.many else None
+ value = self.get_default_value()
if value in self.null_values:
if self.required:
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index 47082190..198c269f 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -900,6 +900,58 @@ class DefaultValueTests(TestCase):
self.assertEqual(instance.text, 'overridden')
+class WritableFieldDefaultValueTests(TestCase):
+
+ def setUp(self):
+ self.expected = {'default': 'value'}
+ self.create_field = fields.WritableField
+
+ def test_get_default_value_with_noncallable(self):
+ field = self.create_field(default=self.expected)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_with_callable(self):
+ field = self.create_field(default=lambda : self.expected)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_when_not_required(self):
+ field = self.create_field(default=self.expected, required=False)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_returns_None(self):
+ field = self.create_field()
+ got = field.get_default_value()
+ self.assertIsNone(got)
+
+ def test_get_default_value_returns_non_True_values(self):
+ values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
+ for expected in values:
+ field = self.create_field(default=expected)
+ got = field.get_default_value()
+ self.assertEqual(got, expected)
+
+
+class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests):
+
+ def setUp(self):
+ self.expected = {'foo': 'bar'}
+ self.create_field = relations.RelatedField
+
+ def test_get_default_value_returns_empty_list(self):
+ field = self.create_field(many=True)
+ got = field.get_default_value()
+ self.assertListEqual(got, [])
+
+ def test_get_default_value_returns_expected(self):
+ expected = [1, 2, 3]
+ field = self.create_field(many=True, default=expected)
+ got = field.get_default_value()
+ self.assertListEqual(got, expected)
+
+
class CallableDefaultValueTests(TestCase):
def setUp(self):
class CallableDefaultValueSerializer(serializers.ModelSerializer):