diff options
| author | Tom Christie | 2012-12-07 21:32:39 +0000 | 
|---|---|---|
| committer | Tom Christie | 2012-12-07 21:32:45 +0000 | 
| commit | 303bc7cf95033d2560668bf6f4d97f05f1268967 (patch) | |
| tree | 7bfd4b12df5f8e1dd9109bbfffbc0ce3c604b041 | |
| parent | a5178e9a363d00f3eef8d86da2d0ec687518f288 (diff) | |
| download | django-rest-framework-303bc7cf95033d2560668bf6f4d97f05f1268967.tar.bz2 | |
Support nullable FKs, with blank=True
| -rw-r--r-- | rest_framework/fields.py | 8 | ||||
| -rw-r--r-- | rest_framework/serializers.py | 14 | ||||
| -rw-r--r-- | rest_framework/tests/hyperlinkedserializers.py | 25 | ||||
| -rw-r--r-- | rest_framework/tests/pk_relations.py | 53 | 
4 files changed, 85 insertions, 15 deletions
| diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c28a9695..bffc0fb0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -350,7 +350,13 @@ class RelatedField(WritableField):              return          value = data.get(field_name) -        into[(self.source or field_name)] = self.from_native(value) + +        if value is None and not self.blank: +            raise ValidationError('Value may not be null') +        elif value is None and self.blank: +            into[(self.source or field_name)] = None +        else: +            into[(self.source or field_name)] = self.from_native(value)  class ManyRelatedMixin(object): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 5edd46f5..13c41a4b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -431,10 +431,14 @@ class ModelSerializer(Serializer):          """          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to) -        queryset = model_field.rel.to._default_manager +        kwargs = { +            'blank': model_field.blank, +            'queryset': model_field.rel.to._default_manager +        } +          if to_many: -            return ManyPrimaryKeyRelatedField(queryset=queryset) -        return PrimaryKeyRelatedField(queryset=queryset) +            return ManyPrimaryKeyRelatedField(**kwargs) +        return PrimaryKeyRelatedField(**kwargs)      def get_field(self, model_field):          """ @@ -572,9 +576,9 @@ class HyperlinkedModelSerializer(ModelSerializer):          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to)          rel = model_field.rel.to -        queryset = rel._default_manager          kwargs = { -            'queryset': queryset, +            'blank': model_field.blank, +            'queryset': rel._default_manager,              'view_name': self._get_default_view_name(rel)          }          if to_many: diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index d7effce7..24bf61bf 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -1,6 +1,7 @@  from django.conf.urls.defaults import patterns, url  from django.test import TestCase  from django.test.client import RequestFactory +from django.utils import simplejson as json  from rest_framework import generics, status, serializers  from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel @@ -54,10 +55,12 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView):      model = BlogPostComment      serializer_class = BlogPostCommentSerializer +  class BlogPostCommentDetail(generics.RetrieveAPIView):      model = BlogPostComment      serializer_class = BlogPostCommentSerializer +  class BlogPostDetail(generics.RetrieveAPIView):      model = BlogPost @@ -71,7 +74,7 @@ class AlbumDetail(generics.RetrieveAPIView):      model = Album -class OptionalRelationDetail(generics.RetrieveAPIView): +class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):      model = OptionalRelationModel      model_serializer_class = serializers.HyperlinkedModelSerializer @@ -162,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase):          GET requests to ListCreateAPIView should return list of objects.          """          request = factory.get('/manytomany/') -        response = self.list_view(request).render() +        response = self.list_view(request)          self.assertEquals(response.status_code, status.HTTP_200_OK)          self.assertEquals(response.data, self.data) @@ -171,7 +174,7 @@ class TestManyToManyHyperlinkedView(TestCase):          GET requests to ListCreateAPIView should return list of objects.          """          request = factory.get('/manytomany/1/') -        response = self.detail_view(request, pk=1).render() +        response = self.detail_view(request, pk=1)          self.assertEquals(response.status_code, status.HTTP_200_OK)          self.assertEquals(response.data, self.data[0]) @@ -194,7 +197,7 @@ class TestCreateWithForeignKeys(TestCase):          }          request = factory.post('/comments/', data=data) -        response = self.create_view(request).render() +        response = self.create_view(request)          self.assertEqual(response.status_code, status.HTTP_201_CREATED)          self.assertEqual(response['Location'], 'http://testserver/comments/1/')          self.assertEqual(self.post.blogpostcomment_set.count(), 1) @@ -219,7 +222,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):          }          request = factory.post('/photos/', data=data) -        response = self.list_create_view(request).render() +        response = self.list_create_view(request)          self.assertEqual(response.status_code, status.HTTP_201_CREATED)          self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')          self.assertEqual(self.post.photo_set.count(), 1) @@ -244,6 +247,16 @@ class TestOptionalRelationHyperlinkedView(TestCase):          for non existing relations.          """          request = factory.get('/optionalrelationmodel-detail/1') -        response = self.detail_view(request, pk=1).render() +        response = self.detail_view(request, pk=1)          self.assertEquals(response.status_code, status.HTTP_200_OK)          self.assertEquals(response.data, self.data) + +    def test_put_detail_view(self): +        """ +        PUT requests to RetrieveUpdateDestroyAPIView with optional relations +        should accept None for non existing relations. +        """ +        response = self.client.put('/optionalrelation/1/', +                                   data=json.dumps(self.data), +                                   content_type='application/json') +        self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/pk_relations.py index 3dcc76f9..53245d94 100644 --- a/rest_framework/tests/pk_relations.py +++ b/rest_framework/tests/pk_relations.py @@ -49,9 +49,22 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):          model = ForeignKeySource +# Nullable ForeignKey + +class NullableForeignKeySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, +                               related_name='nullable_sources') + + +class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = NullableForeignKeySource + +  # TODO: Add test that .data cannot be accessed prior to .is_valid -class PrimaryKeyManyToManyTests(TestCase): +class PKManyToManyTests(TestCase):      def setUp(self):          for idx in range(1, 4):              target = ManyToManyTarget(name='target-%d' % idx) @@ -137,7 +150,7 @@ class PrimaryKeyManyToManyTests(TestCase):          self.assertEquals(serializer.data, expected) -class PrimaryKeyForeignKeyTests(TestCase): +class PKForeignKeyTests(TestCase):      def setUp(self):          target = ForeignKeyTarget(name='target-1')          target.save() @@ -174,7 +187,7 @@ class PrimaryKeyForeignKeyTests(TestCase):          self.assertEquals(serializer.data, data)          serializer.save() -        # # Ensure source 1 is updated, and everything else is as expected +        # Ensure source 1 is updated, and everything else is as expected          queryset = ForeignKeySource.objects.all()          serializer = ForeignKeySourceSerializer(queryset)          expected = [ @@ -184,6 +197,40 @@ class PrimaryKeyForeignKeyTests(TestCase):          ]          self.assertEquals(serializer.data, expected) +    def test_foreign_key_update_with_invalid_null(self): +        data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(instance, data=data) +        self.assertFalse(serializer.is_valid()) +        self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) + + +class PKNullableForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        for idx in range(1, 4): +            source = NullableForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_update_with_valid_null(self): +        data = {'id': 1, 'name': u'source-1', 'target': None} +        instance = NullableForeignKeySource.objects.get(pk=1) +        serializer = NullableForeignKeySourceSerializer(instance, data=data) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # Ensure source 1 is updated, and everything else is as expected +        queryset = NullableForeignKeySource.objects.all() +        serializer = NullableForeignKeySourceSerializer(queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': None}, +            {'id': 2, 'name': u'source-2', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': 1} +        ] +        self.assertEquals(serializer.data, expected) +      # reverse foreign keys MUST be read_only      # In the general case they do not provide .remove() or .clear()      # and cannot be arbitrarily set. | 
