aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2013-04-10 22:38:02 +0100
committerTom Christie2013-04-10 22:38:02 +0100
commit76e039d70e8fc7f1d5c65180cb544abab81e600e (patch)
treed5ffa15006f7935ca7a960353ef78cbf444eecac
parent3f91379e4eaf07418a99fda1932af91511c55e7b (diff)
downloaddjango-rest-framework-76e039d70e8fc7f1d5c65180cb544abab81e600e.tar.bz2
First pass on automatically including reverse relationship
-rw-r--r--rest_framework/serializers.py43
-rw-r--r--rest_framework/tests/serializer.py37
2 files changed, 74 insertions, 6 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index e28bbe81..eac909c7 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -598,6 +598,24 @@ class ModelSerializer(Serializer):
if field:
ret[model_field.name] = field
+ # Reverse relationships are only included if they are explicitly
+ # present in `Meta.fields`.
+ if self.opts.fields:
+ reverse = opts.get_all_related_objects()
+ reverse += opts.get_all_related_many_to_many_objects()
+ for rel in reverse:
+ name = rel.get_accessor_name()
+ if name not in self.opts.fields:
+ continue
+
+ if nested:
+ field = self.get_nested_field(None, rel)
+ else:
+ field = self.get_related_field(None, rel, to_many=True)
+
+ if field:
+ ret[name] = field
+
for field_name in self.opts.read_only_fields:
assert field_name in ret, \
"read_only_fields on '%s' included invalid item '%s'" % \
@@ -612,24 +630,36 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field):
+ def get_nested_field(self, model_field, rel=None):
"""
Creates a default instance of a nested relational field.
"""
+ if rel:
+ model_class = rel.model
+ else:
+ model_class = model_field.rel.to
+
class NestedModelSerializer(ModelSerializer):
class Meta:
- model = model_field.rel.to
+ model = model_class
return NestedModelSerializer()
- def get_related_field(self, model_field, to_many=False):
+ def get_related_field(self, model_field, rel=None, to_many=False):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
+ if rel:
+ model_class = rel.model
+ required = True
+ else:
+ model_class = model_field.rel.to
+ required = not(model_field.null or model_field.blank)
+
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': model_field.rel.to._default_manager,
+ 'required': required,
+ 'queryset': model_class._default_manager,
'many': to_many
}
@@ -797,7 +827,8 @@ class HyperlinkedModelSerializer(ModelSerializer):
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
- return None
+ if self.opts.fields and model_field.name in self.opts.fields:
+ return self.get_field(model_field)
def get_related_field(self, model_field, to_many):
"""
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 05217f35..3a94fad5 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")