diff options
| -rw-r--r-- | rest_framework/relations.py | 42 | ||||
| -rw-r--r-- | tests/test_relations_pk.py | 12 | 
2 files changed, 47 insertions, 7 deletions
| diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 268b95cf..1665dd35 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,5 +1,5 @@  from rest_framework.compat import smart_text, urlparse -from rest_framework.fields import empty, Field +from rest_framework.fields import get_attribute, empty, Field  from rest_framework.reverse import reverse  from rest_framework.utils import html  from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured @@ -9,6 +9,11 @@ from django.utils import six  from django.utils.translation import ugettext_lazy as _ +class PKOnlyObject(object): +    def __init__(self, pk): +        self.pk = pk + +  class RelatedField(Field):      def __init__(self, **kwargs):          self.queryset = kwargs.pop('queryset', None) @@ -45,6 +50,10 @@ class RelatedField(Field):              queryset = queryset.all()          return queryset +    def get_iterable(self, instance, source): +        relationship = get_attribute(instance, [source]) +        return relationship.all() if (hasattr(relationship, 'all')) else relationship +      @property      def choices(self):          return dict([ @@ -85,6 +94,31 @@ class PrimaryKeyRelatedField(RelatedField):          except (TypeError, ValueError):              self.fail('incorrect_type', data_type=type(data).__name__) +    def get_attribute(self, instance): +        # We customize `get_attribute` here for performance reasons. +        # For relationships the instance will already have the pk of +        # the related object. We return this directly instead of returning the +        # object itself, which would require a database lookup. +        try: +            return PKOnlyObject(pk=instance.serializable_value(self.source)) +        except AttributeError: +            return get_attribute(instance, [self.source]) + +    def get_iterable(self, instance, source): +        # For consistency with `get_attribute` we're using `serializable_value()` +        # here. Typically there won't be any difference, but some custom field +        # types might return a non-primative value for the pk otherwise. +        # +        # We could try to get smart with `values_list('pk', flat=True)`, which +        # would be better in some case, but would actually end up with *more* +        # queries if the developer is using `prefetch_related` across the +        # relationship. +        relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source) +        return [ +            PKOnlyObject(pk=item.serializable_value('pk')) +            for item in relationship +        ] +      def to_representation(self, value):          return value.pk @@ -277,8 +311,10 @@ class ManyRelation(Field):              for item in data          ] -    def to_representation(self, obj): -        iterable = obj.all() if (hasattr(obj, 'all')) else obj +    def get_attribute(self, instance): +        return self.child_relation.get_iterable(instance, self.source) + +    def to_representation(self, iterable):          return [              self.child_relation.to_representation(value)              for value in iterable diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index da3c5786..ba5f6c17 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -68,7 +68,8 @@ class PKManyToManyTests(TestCase):              {'id': 2, 'name': 'source-2', 'targets': [1, 2]},              {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected)      def test_reverse_many_to_many_retrieve(self):          queryset = ManyToManyTarget.objects.all() @@ -78,7 +79,8 @@ class PKManyToManyTests(TestCase):              {'id': 2, 'name': 'target-2', 'sources': [2, 3]},              {'id': 3, 'name': 'target-3', 'sources': [3]}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected)      def test_many_to_many_update(self):          data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]} @@ -173,7 +175,8 @@ class PKForeignKeyTests(TestCase):              {'id': 2, 'name': 'source-2', 'target': 1},              {'id': 3, 'name': 'source-3', 'target': 1}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(1): +            self.assertEqual(serializer.data, expected)      def test_reverse_foreign_key_retrieve(self):          queryset = ForeignKeyTarget.objects.all() @@ -182,7 +185,8 @@ class PKForeignKeyTests(TestCase):              {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},              {'id': 2, 'name': 'target-2', 'sources': []},          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(3): +            self.assertEqual(serializer.data, expected)      def test_foreign_key_update(self):          data = {'id': 1, 'name': 'source-1', 'target': 2} | 
