diff options
| author | Tom Christie | 2015-01-28 09:26:49 +0000 | 
|---|---|---|
| committer | Tom Christie | 2015-01-28 09:26:49 +0000 | 
| commit | 6d89430dd268e01812214d1819337e1498d6068a (patch) | |
| tree | 9a19c7a1af371d6a3ccfde9ac2e45926366add6d | |
| parent | 81c2562ec4a1871a6f5f471ad37799ede3dbc166 (diff) | |
| parent | e7da266a866adddd5c37453fab33812ee412752b (diff) | |
| download | django-rest-framework-6d89430dd268e01812214d1819337e1498d6068a.tar.bz2 | |
Merge pull request #2475 from sdreher/master
 ManyRelatedField.get_value clearing field on partial update
| -rw-r--r-- | rest_framework/relations.py | 5 | ||||
| -rw-r--r-- | tests/test_relations.py | 33 | 
2 files changed, 38 insertions, 0 deletions
| diff --git a/rest_framework/relations.py b/rest_framework/relations.py index aa0c2def..13793f37 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -338,7 +338,12 @@ class ManyRelatedField(Field):          # We override the default field access in order to support          # lists in HTML forms.          if html.is_html_input(dictionary): +            # Don't return [] if the update is partial +            if self.field_name not in dictionary: +                if getattr(self.root, 'partial', False): +                    return empty              return dictionary.getlist(self.field_name) +          return dictionary.get(self.field_name, empty)      def to_internal_value(self, data): diff --git a/tests/test_relations.py b/tests/test_relations.py index 62353dc2..d478d855 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -1,6 +1,8 @@  from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset  from django.core.exceptions import ImproperlyConfigured +from django.utils.datastructures import MultiValueDict  from rest_framework import serializers +from rest_framework.fields import empty  from rest_framework.test import APISimpleTestCase  import pytest @@ -134,3 +136,34 @@ class TestSlugRelatedField(APISimpleTestCase):      def test_representation(self):          representation = self.field.to_representation(self.instance)          assert representation == self.instance.name + + +class TestManyRelatedField(APISimpleTestCase): +    def setUp(self): +        self.instance = MockObject(pk=1, name='foo') +        self.field = serializers.StringRelatedField(many=True) +        self.field.field_name = 'foo' + +    def test_get_value_regular_dictionary_full(self): +        assert 'bar' == self.field.get_value({'foo': 'bar'}) +        assert empty == self.field.get_value({'baz': 'bar'}) + +    def test_get_value_regular_dictionary_partial(self): +        setattr(self.field.root, 'partial', True) +        assert 'bar' == self.field.get_value({'foo': 'bar'}) +        assert empty == self.field.get_value({'baz': 'bar'}) + +    def test_get_value_multi_dictionary_full(self): +        mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) +        assert ['bar1', 'bar2'] == self.field.get_value(mvd) + +        mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) +        assert [] == self.field.get_value(mvd) + +    def test_get_value_multi_dictionary_partial(self): +        setattr(self.field.root, 'partial', True) +        mvd = MultiValueDict({'foo': ['bar1', 'bar2']}) +        assert ['bar1', 'bar2'] == self.field.get_value(mvd) + +        mvd = MultiValueDict({'baz': ['bar1', 'bar2']}) +        assert empty == self.field.get_value(mvd) | 
