diff options
| author | Tom Christie | 2014-12-10 22:19:46 +0000 | 
|---|---|---|
| committer | Tom Christie | 2014-12-10 22:19:46 +0000 | 
| commit | 313c36faca20b872f649f2fcea6c783f2d20fe72 (patch) | |
| tree | b777c6b2444c4e8154cd27e3e38eacaebd336f20 | |
| parent | 8ad0b83148fbb64c6ae91ba970ccd6e26a483a1e (diff) | |
| parent | 1e336ef30da9f7366fc68185016c4011c25e1199 (diff) | |
| download | django-rest-framework-313c36faca20b872f649f2fcea6c783f2d20fe72.tar.bz2 | |
Merge pull request #2242 from tomchristie/hyperlinked-pk-optimization
Hyperlinked PK optimization.
| -rw-r--r-- | rest_framework/relations.py | 57 | ||||
| -rw-r--r-- | tests/test_relations_hyperlink.py | 18 | ||||
| -rw-r--r-- | tests/test_relations_pk.py | 12 | ||||
| -rw-r--r-- | tests/test_relations_slug.py | 15 | 
4 files changed, 67 insertions, 35 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 75d68204..892ce6c1 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -84,9 +84,20 @@ class RelatedField(Field):              queryset = queryset.all()          return queryset -    def get_iterable(self, instance, source_attrs): -        relationship = get_attribute(instance, source_attrs) -        return relationship.all() if (hasattr(relationship, 'all')) else relationship +    def use_pk_only_optimization(self): +        return False + +    def get_attribute(self, instance): +        if self.use_pk_only_optimization() and self.source_attrs: +            # Optimized case, return a mock object only containing the pk attribute. +            try: +                instance = get_attribute(instance, self.source_attrs[:-1]) +                return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1])) +            except AttributeError: +                pass + +        # Standard case, return the object instance. +        return get_attribute(instance, self.source_attrs)      @property      def choices(self): @@ -120,6 +131,9 @@ class PrimaryKeyRelatedField(RelatedField):          'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),      } +    def use_pk_only_optimization(self): +        return True +      def to_internal_value(self, data):          try:              return self.get_queryset().get(pk=data) @@ -128,32 +142,6 @@ class PrimaryKeyRelatedField(RelatedField):          except (TypeError, ValueError):              self.fail('incorrect_type', data_type=type(data).__name__) -    def get_attribute(self, instance): -        # We customize `get_attribute` here for performance reasons. -        # For relationships the instance will already have the pk of -        # the related object. We return this directly instead of returning the -        # object itself, which would require a database lookup. -        try: -            instance = get_attribute(instance, self.source_attrs[:-1]) -            return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1])) -        except AttributeError: -            return get_attribute(instance, self.source_attrs) - -    def get_iterable(self, instance, source_attrs): -        # For consistency with `get_attribute` we're using `serializable_value()` -        # here. Typically there won't be any difference, but some custom field -        # types might return a non-primitive value for the pk otherwise. -        # -        # We could try to get smart with `values_list('pk', flat=True)`, which -        # would be better in some case, but would actually end up with *more* -        # queries if the developer is using `prefetch_related` across the -        # relationship. -        relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs) -        return [ -            PKOnlyObject(pk=item.serializable_value('pk')) -            for item in relationship -        ] -      def to_representation(self, value):          return value.pk @@ -184,6 +172,9 @@ class HyperlinkedRelatedField(RelatedField):          super(HyperlinkedRelatedField, self).__init__(**kwargs) +    def use_pk_only_optimization(self): +        return self.lookup_field == 'pk' +      def get_object(self, view_name, view_args, view_kwargs):          """          Return the object corresponding to a matched URL. @@ -285,6 +276,11 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):          kwargs['source'] = '*'          super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) +    def use_pk_only_optimization(self): +        # We have the complete object instance already. We don't need +        # to run the 'only get the pk for this relationship' code. +        return False +  class SlugRelatedField(RelatedField):      """ @@ -349,7 +345,8 @@ class ManyRelatedField(Field):          ]      def get_attribute(self, instance): -        return self.child_relation.get_iterable(instance, self.source_attrs) +        relationship = get_attribute(instance, self.source_attrs) +        return relationship.all() if (hasattr(relationship, 'all')) else relationship      def to_representation(self, iterable):          return [ diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py index b938e385..f1b882ed 100644 --- a/tests/test_relations_hyperlink.py +++ b/tests/test_relations_hyperlink.py @@ -89,7 +89,14 @@ class HyperlinkedManyToManyTests(TestCase):              {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},              {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_many_to_many_retrieve_prefetch_related(self): +        queryset = ManyToManySource.objects.all().prefetch_related('targets') +        serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) +        with self.assertNumQueries(2): +            serializer.data      def test_reverse_many_to_many_retrieve(self):          queryset = ManyToManyTarget.objects.all() @@ -99,7 +106,8 @@ class HyperlinkedManyToManyTests(TestCase):              {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},              {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected)      def test_many_to_many_update(self):          data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} @@ -197,7 +205,8 @@ class HyperlinkedForeignKeyTests(TestCase):              {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},              {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(1): +            self.assertEqual(serializer.data, expected)      def test_reverse_foreign_key_retrieve(self):          queryset = ForeignKeyTarget.objects.all() @@ -206,7 +215,8 @@ class HyperlinkedForeignKeyTests(TestCase):              {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},              {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(3): +            self.assertEqual(serializer.data, expected)      def test_foreign_key_update(self):          data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index e95a877e..f872a8dc 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -71,6 +71,12 @@ class PKManyToManyTests(TestCase):          with self.assertNumQueries(4):              self.assertEqual(serializer.data, expected) +    def test_many_to_many_retrieve_prefetch_related(self): +        queryset = ManyToManySource.objects.all().prefetch_related('targets') +        serializer = ManyToManySourceSerializer(queryset, many=True) +        with self.assertNumQueries(2): +            serializer.data +      def test_reverse_many_to_many_retrieve(self):          queryset = ManyToManyTarget.objects.all()          serializer = ManyToManyTargetSerializer(queryset, many=True) @@ -188,6 +194,12 @@ class PKForeignKeyTests(TestCase):          with self.assertNumQueries(3):              self.assertEqual(serializer.data, expected) +    def test_reverse_foreign_key_retrieve_prefetch_related(self): +        queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        with self.assertNumQueries(2): +            serializer.data +      def test_foreign_key_update(self):          data = {'id': 1, 'name': 'source-1', 'target': 2}          instance = ForeignKeySource.objects.get(pk=1) diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py index 7bac9046..cd2cb1ed 100644 --- a/tests/test_relations_slug.py +++ b/tests/test_relations_slug.py @@ -54,7 +54,14 @@ class SlugForeignKeyTests(TestCase):              {'id': 2, 'name': 'source-2', 'target': 'target-1'},              {'id': 3, 'name': 'source-3', 'target': 'target-1'}          ] -        self.assertEqual(serializer.data, expected) +        with self.assertNumQueries(4): +            self.assertEqual(serializer.data, expected) + +    def test_foreign_key_retrieve_select_related(self): +        queryset = ForeignKeySource.objects.all().select_related('target') +        serializer = ForeignKeySourceSerializer(queryset, many=True) +        with self.assertNumQueries(1): +            serializer.data      def test_reverse_foreign_key_retrieve(self):          queryset = ForeignKeyTarget.objects.all() @@ -65,6 +72,12 @@ class SlugForeignKeyTests(TestCase):          ]          self.assertEqual(serializer.data, expected) +    def test_reverse_foreign_key_retrieve_prefetch_related(self): +        queryset = ForeignKeyTarget.objects.all().prefetch_related('sources') +        serializer = ForeignKeyTargetSerializer(queryset, many=True) +        with self.assertNumQueries(2): +            serializer.data +      def test_foreign_key_update(self):          data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}          instance = ForeignKeySource.objects.get(pk=1)  | 
