aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2012-12-07 21:32:39 +0000
committerTom Christie2012-12-07 21:32:45 +0000
commit303bc7cf95033d2560668bf6f4d97f05f1268967 (patch)
tree7bfd4b12df5f8e1dd9109bbfffbc0ce3c604b041
parenta5178e9a363d00f3eef8d86da2d0ec687518f288 (diff)
downloaddjango-rest-framework-303bc7cf95033d2560668bf6f4d97f05f1268967.tar.bz2
Support nullable FKs, with blank=True
-rw-r--r--rest_framework/fields.py8
-rw-r--r--rest_framework/serializers.py14
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py25
-rw-r--r--rest_framework/tests/pk_relations.py53
4 files changed, 85 insertions, 15 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index c28a9695..bffc0fb0 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -350,7 +350,13 @@ class RelatedField(WritableField):
return
value = data.get(field_name)
- into[(self.source or field_name)] = self.from_native(value)
+
+ if value is None and not self.blank:
+ raise ValidationError('Value may not be null')
+ elif value is None and self.blank:
+ into[(self.source or field_name)] = None
+ else:
+ into[(self.source or field_name)] = self.from_native(value)
class ManyRelatedMixin(object):
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 5edd46f5..13c41a4b 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -431,10 +431,14 @@ class ModelSerializer(Serializer):
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- queryset = model_field.rel.to._default_manager
+ kwargs = {
+ 'blank': model_field.blank,
+ 'queryset': model_field.rel.to._default_manager
+ }
+
if to_many:
- return ManyPrimaryKeyRelatedField(queryset=queryset)
- return PrimaryKeyRelatedField(queryset=queryset)
+ return ManyPrimaryKeyRelatedField(**kwargs)
+ return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
@@ -572,9 +576,9 @@ class HyperlinkedModelSerializer(ModelSerializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
rel = model_field.rel.to
- queryset = rel._default_manager
kwargs = {
- 'queryset': queryset,
+ 'blank': model_field.blank,
+ 'queryset': rel._default_manager,
'view_name': self._get_default_view_name(rel)
}
if to_many:
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index d7effce7..24bf61bf 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -1,6 +1,7 @@
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
+from django.utils import simplejson as json
from rest_framework import generics, status, serializers
from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel
@@ -54,10 +55,12 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView):
model = BlogPostComment
serializer_class = BlogPostCommentSerializer
+
class BlogPostCommentDetail(generics.RetrieveAPIView):
model = BlogPostComment
serializer_class = BlogPostCommentSerializer
+
class BlogPostDetail(generics.RetrieveAPIView):
model = BlogPost
@@ -71,7 +74,7 @@ class AlbumDetail(generics.RetrieveAPIView):
model = Album
-class OptionalRelationDetail(generics.RetrieveAPIView):
+class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
model = OptionalRelationModel
model_serializer_class = serializers.HyperlinkedModelSerializer
@@ -162,7 +165,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/')
- response = self.list_view(request).render()
+ response = self.list_view(request)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
@@ -171,7 +174,7 @@ class TestManyToManyHyperlinkedView(TestCase):
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/manytomany/1/')
- response = self.detail_view(request, pk=1).render()
+ response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
@@ -194,7 +197,7 @@ class TestCreateWithForeignKeys(TestCase):
}
request = factory.post('/comments/', data=data)
- response = self.create_view(request).render()
+ response = self.create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1)
@@ -219,7 +222,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
}
request = factory.post('/photos/', data=data)
- response = self.list_create_view(request).render()
+ response = self.list_create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1)
@@ -244,6 +247,16 @@ class TestOptionalRelationHyperlinkedView(TestCase):
for non existing relations.
"""
request = factory.get('/optionalrelationmodel-detail/1')
- response = self.detail_view(request, pk=1).render()
+ response = self.detail_view(request, pk=1)
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
+
+ def test_put_detail_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView with optional relations
+ should accept None for non existing relations.
+ """
+ response = self.client.put('/optionalrelation/1/',
+ data=json.dumps(self.data),
+ content_type='application/json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/pk_relations.py
index 3dcc76f9..53245d94 100644
--- a/rest_framework/tests/pk_relations.py
+++ b/rest_framework/tests/pk_relations.py
@@ -49,9 +49,22 @@ class ForeignKeySourceSerializer(serializers.ModelSerializer):
model = ForeignKeySource
+# Nullable ForeignKey
+
+class NullableForeignKeySource(models.Model):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources')
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+
+
# TODO: Add test that .data cannot be accessed prior to .is_valid
-class PrimaryKeyManyToManyTests(TestCase):
+class PKManyToManyTests(TestCase):
def setUp(self):
for idx in range(1, 4):
target = ManyToManyTarget(name='target-%d' % idx)
@@ -137,7 +150,7 @@ class PrimaryKeyManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected)
-class PrimaryKeyForeignKeyTests(TestCase):
+class PKForeignKeyTests(TestCase):
def setUp(self):
target = ForeignKeyTarget(name='target-1')
target.save()
@@ -174,7 +187,7 @@ class PrimaryKeyForeignKeyTests(TestCase):
self.assertEquals(serializer.data, data)
serializer.save()
- # # Ensure source 1 is updated, and everything else is as expected
+ # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset)
expected = [
@@ -184,6 +197,40 @@ class PrimaryKeyForeignKeyTests(TestCase):
]
self.assertEquals(serializer.data, expected)
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'target': [u'Value may not be null']})
+
+
+class PKNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': u'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.data, data)
+ serializer.save()
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset)
+ expected = [
+ {'id': 1, 'name': u'source-1', 'target': None},
+ {'id': 2, 'name': u'source-2', 'target': 1},
+ {'id': 3, 'name': u'source-3', 'target': 1}
+ ]
+ self.assertEquals(serializer.data, expected)
+
# reverse foreign keys MUST be read_only
# In the general case they do not provide .remove() or .clear()
# and cannot be arbitrarily set.