diff options
| author | Tom Christie | 2013-04-17 09:26:34 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-04-17 09:26:34 +0100 | 
| commit | bcf4cb2b4e2fdf10b0df01ece1aa8ce9dc97285a (patch) | |
| tree | 2a27a61091e6be47d9596bdced5976cfbd130576 /rest_framework | |
| parent | ea55143a2308b396c8df6f59a0f6d663c1067163 (diff) | |
| parent | 76e039d70e8fc7f1d5c65180cb544abab81e600e (diff) | |
| download | django-rest-framework-bcf4cb2b4e2fdf10b0df01ece1aa8ce9dc97285a.tar.bz2 | |
Merge branch 'include_reverse_relations' of https://github.com/tomchristie/django-rest-framework into include_reverse_relations
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/serializers.py | 43 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 37 | 
2 files changed, 74 insertions, 6 deletions
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81..eac909c7 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -598,6 +598,24 @@ class ModelSerializer(Serializer):              if field:                  ret[model_field.name] = field +        # Reverse relationships are only included if they are explicitly +        # present in `Meta.fields`. +        if self.opts.fields: +            reverse = opts.get_all_related_objects() +            reverse += opts.get_all_related_many_to_many_objects() +            for rel in reverse: +                name = rel.get_accessor_name() +                if name not in self.opts.fields: +                    continue + +                if nested: +                    field = self.get_nested_field(None, rel) +                else: +                    field = self.get_related_field(None, rel, to_many=True) + +                if field: +                    ret[name] = field +          for field_name in self.opts.read_only_fields:              assert field_name in ret, \                  "read_only_fields on '%s' included invalid item '%s'" % \ @@ -612,24 +630,36 @@ class ModelSerializer(Serializer):          """          return self.get_field(model_field) -    def get_nested_field(self, model_field): +    def get_nested_field(self, model_field, rel=None):          """          Creates a default instance of a nested relational field.          """ +        if rel: +            model_class = rel.model +        else: +            model_class = model_field.rel.to +          class NestedModelSerializer(ModelSerializer):              class Meta: -                model = model_field.rel.to +                model = model_class          return NestedModelSerializer() -    def get_related_field(self, model_field, to_many=False): +    def get_related_field(self, model_field, rel=None, to_many=False):          """          Creates a default instance of a flat relational field.          """          # TODO: filter queryset using:          # .using(db).complex_filter(self.rel.limit_choices_to) +        if rel: +            model_class = rel.model +            required = True +        else: +            model_class = model_field.rel.to +            required = not(model_field.null or model_field.blank) +          kwargs = { -            'required': not(model_field.null or model_field.blank), -            'queryset': model_field.rel.to._default_manager, +            'required': required, +            'queryset': model_class._default_manager,              'many': to_many          } @@ -797,7 +827,8 @@ class HyperlinkedModelSerializer(ModelSerializer):          return self._default_view_name % format_kwargs      def get_pk_field(self, model_field): -        return None +        if self.opts.fields and model_field.name in self.opts.fields: +            return self.get_field(model_field)      def get_related_field(self, model_field, to_many):          """ diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35..3a94fad5 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase):          self.assertEqual(serializer.data, expected) +    def test_include_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPost +                fields = ('id', 'title', 'blogpostcomment_set') + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] +        } +        self.assertEqual(serializer.data, expected) + +    def test_depth_include_reverse_relations(self): +        post = BlogPost.objects.create(title="Test blog post") +        post.blogpostcomment_set.create(text="I hate this blog post") +        post.blogpostcomment_set.create(text="I love this blog post") + +        class BlogPostSerializer(serializers.ModelSerializer): +            class Meta: +                model = BlogPost +                fields = ('id', 'title', 'blogpostcomment_set') +                depth = 1 + +        serializer = BlogPostSerializer(instance=post) +        expected = { +            'id': 1, 'title': 'Test blog post', +            'blogpostcomment_set': [ +                {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, +                {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} +            ] +        } +        self.assertEqual(serializer.data, expected) +      def test_callable_source(self):          post = BlogPost.objects.create(title="Test blog post")          post.blogpostcomment_set.create(text="I love this blog post")  | 
