diff options
| author | Tom Christie | 2014-03-04 15:26:34 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-03-04 15:26:34 +0000 | 
| commit | 4edd39b2e4a6490e7eac17dd989418f745d8efc3 (patch) | |
| tree | ed2fb9171bd9283d3314364caa0d005b08f6f4e2 /rest_framework | |
| parent | 24a688223240eb1e71db3c0f00cd621e80cb9fb2 (diff) | |
| parent | dea2766abac5ef55fa226f413711cfd49af2a745 (diff) | |
| download | django-rest-framework-4edd39b2e4a6490e7eac17dd989418f745d8efc3.tar.bz2 | |
Merge pull request #1442 from Anton-Shutik/master
RelatedField default value handling fixed
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/fields.py | 10 | ||||
| -rw-r--r-- | rest_framework/relations.py | 10 | ||||
| -rw-r--r-- | rest_framework/tests/test_serializer.py | 52 | 
3 files changed, 67 insertions, 5 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 05daaab7..68b95682 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -301,6 +301,11 @@ class WritableField(Field):          result.validators = self.validators[:]          return result +    def get_default_value(self): +        if is_simple_callable(self.default): +            return self.default() +        return self.default +      def validate(self, value):          if value in validators.EMPTY_VALUES and self.required:              raise ValidationError(self.error_messages['required']) @@ -349,10 +354,7 @@ class WritableField(Field):          except KeyError:              if self.default is not None and not self.partial:                  # Note: partial updates shouldn't set defaults -                if is_simple_callable(self.default): -                    native = self.default() -                else: -                    native = self.default +                native = self.get_default_value()              else:                  if self.required:                      raise ValidationError(self.error_messages['required']) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 163a8984..308545ce 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -119,6 +119,14 @@ class RelatedField(WritableField):      choices = property(_get_choices, _set_choices) +    ### Default value handling + +    def get_default_value(self): +        default = super(RelatedField, self).get_default_value() +        if self.many and default is None: +            return [] +        return default +      ### Regular serializer stuff...      def field_to_native(self, obj, field_name): @@ -167,7 +175,7 @@ class RelatedField(WritableField):          except KeyError:              if self.partial:                  return -            value = [] if self.many else None +            value = self.get_default_value()          if value in self.null_values:              if self.required: diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 47082190..198c269f 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -900,6 +900,58 @@ class DefaultValueTests(TestCase):          self.assertEqual(instance.text, 'overridden') +class WritableFieldDefaultValueTests(TestCase): + +    def setUp(self): +        self.expected = {'default': 'value'} +        self.create_field = fields.WritableField + +    def test_get_default_value_with_noncallable(self): +        field = self.create_field(default=self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_with_callable(self): +        field = self.create_field(default=lambda : self.expected) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_when_not_required(self): +        field = self.create_field(default=self.expected, required=False) +        got = field.get_default_value() +        self.assertEqual(got, self.expected) + +    def test_get_default_value_returns_None(self): +        field = self.create_field() +        got = field.get_default_value() +        self.assertIsNone(got) + +    def test_get_default_value_returns_non_True_values(self): +        values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause +        for expected in values: +            field = self.create_field(default=expected) +            got = field.get_default_value() +            self.assertEqual(got, expected) + + +class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests): + +    def setUp(self): +        self.expected = {'foo': 'bar'} +        self.create_field = relations.RelatedField + +    def test_get_default_value_returns_empty_list(self): +        field = self.create_field(many=True) +        got = field.get_default_value() +        self.assertListEqual(got, []) + +    def test_get_default_value_returns_expected(self): +        expected = [1, 2, 3] +        field = self.create_field(many=True, default=expected) +        got = field.get_default_value() +        self.assertListEqual(got, expected) + +  class CallableDefaultValueTests(TestCase):      def setUp(self):          class CallableDefaultValueSerializer(serializers.ModelSerializer):  | 
