diff options
| author | Tom Christie | 2014-09-18 14:58:08 +0100 | 
|---|---|---|
| committer | Tom Christie | 2014-09-18 14:58:08 +0100 | 
| commit | 106362b437f45e04faaea759df57a66a8a2d7cfd (patch) | |
| tree | 22410f7d4d0e8ccec5c4891505a1b4b6974d68af | |
| parent | 9fdb2280d11db126771686d626aa8a0247b8a46c (diff) | |
| download | django-rest-framework-106362b437f45e04faaea759df57a66a8a2d7cfd.tar.bz2 | |
ModelSerializer.create() to handle many to many by default
| -rw-r--r-- | rest_framework/relations.py | 5 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 20 | ||||
| -rw-r--r-- | tests/test_model_serializer.py | 45 | 
3 files changed, 67 insertions, 3 deletions
| diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 474d3e75..5aa1f8bd 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -24,7 +24,10 @@ class RelatedField(Field):          # We override this method in order to automagically create          # `ManyRelation` classes instead when `many=True` is set.          if kwargs.pop('many', False): -            return ManyRelation(child_relation=cls(*args, **kwargs)) +            return ManyRelation( +                child_relation=cls(*args, **kwargs), +                read_only=kwargs.get('read_only', False) +            )          return super(RelatedField, cls).__new__(cls, *args, **kwargs)      def get_queryset(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9f3e53fd..03e20df8 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -344,7 +344,25 @@ class ModelSerializer(Serializer):      def create(self, attrs):          ModelClass = self.Meta.model -        return ModelClass.objects.create(**attrs) + +        # Remove many-to-many relationships from attrs. +        # They are not valid arguments to the default `.create()` method, +        # as they require that the instance has already been saved. +        info = model_meta.get_field_info(ModelClass) +        many_to_many = {} +        for key, relation_info in info.relations.items(): +            if relation_info.to_many and (key in attrs): +                many_to_many[key] = attrs.pop(key) + +        instance = ModelClass.objects.create(**attrs) + +        # Save many to many relationships after the instance is created. +        if many_to_many: +            for key, value in many_to_many.items(): +                setattr(instance, key, value) +            instance.save() + +        return instance      def update(self, obj, attrs):          for attr, value in attrs.items(): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index b3dae713..6f207e02 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -360,7 +360,7 @@ class TestIntegration(TestCase):          self.serializer_cls = TestSerializer -    def test_pk_relationship_representations(self): +    def test_pk_retrival(self):          serializer = self.serializer_cls(self.instance)          expected = {              'id': self.instance.pk, @@ -370,3 +370,46 @@ class TestIntegration(TestCase):              'through': []          }          self.assertEqual(serializer.data, expected) + +    def test_pk_create(self): +        new_foreign_key = ForeignKeyTargetModel.objects.create( +            name='foreign_key' +        ) +        new_one_to_one = OneToOneTargetModel.objects.create( +            name='one_to_one' +        ) +        new_many_to_many = [ +            ManyToManyTargetModel.objects.create( +                name='new many_to_many (%d)' % idx +            ) for idx in range(3) +        ] +        data = { +            'foreign_key': new_foreign_key.pk, +            'one_to_one': new_one_to_one.pk, +            'many_to_many': [item.pk for item in new_many_to_many], +        } + +        # Serializer should validate okay. +        serializer = self.serializer_cls(data=data) +        assert serializer.is_valid() + +        # Creating the instance, relationship attributes should be set. +        instance = serializer.save() +        assert instance.foreign_key.pk == new_foreign_key.pk +        assert instance.one_to_one.pk == new_one_to_one.pk +        assert [ +            item.pk for item in instance.many_to_many.all() +        ] == [ +            item.pk for item in new_many_to_many +        ] +        assert list(instance.through.all()) == [] + +        # Representation should be correct. +        expected = { +            'id': instance.pk, +            'foreign_key': new_foreign_key.pk, +            'one_to_one': new_one_to_one.pk, +            'many_to_many': [item.pk for item in new_many_to_many], +            'through': [] +        } +        self.assertEqual(serializer.data, expected) | 
