diff options
| author | Tom Christie | 2012-11-02 20:53:33 +0000 | 
|---|---|---|
| committer | Tom Christie | 2012-11-02 20:53:33 +0000 | 
| commit | 6eaec7a0eccabb3e1b010d07633632e8a3ecd86f (patch) | |
| tree | 57ea5143abd2d0ce5cc5ab12f00ab7accbfc980a | |
| parent | e84ce60a0da77c0e07e0d6e5f627694ae3f4422f (diff) | |
| download | django-rest-framework-6eaec7a0eccabb3e1b010d07633632e8a3ecd86f.tar.bz2 | |
foreign key tests
| -rw-r--r-- | rest_framework/fields.py | 13 | ||||
| -rw-r--r-- | rest_framework/tests/pk_relations.py | 139 | 
2 files changed, 132 insertions, 20 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 375d7a46..965e22c4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -383,7 +383,8 @@ class PrimaryKeyRelatedField(RelatedField):          try:              return self.queryset.get(pk=data)          except ObjectDoesNotExist: -            raise ValidationError('Invalid hyperlink - object does not exist.') +            msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) +            raise ValidationError(msg)      def field_to_native(self, obj, field_name):          try: @@ -430,6 +431,16 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):          # Forward relationship          return [self.to_native(item.pk) for item in queryset.all()] +    def from_native(self, data): +        if self.queryset is None: +            raise Exception('Writable related fields must include a `queryset` argument') + +        try: +            return self.queryset.get(pk=data) +        except ObjectDoesNotExist: +            msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) +            raise ValidationError(msg) +  ### Slug relationships diff --git a/rest_framework/tests/pk_relations.py b/rest_framework/tests/pk_relations.py index eca534c2..a0934344 100644 --- a/rest_framework/tests/pk_relations.py +++ b/rest_framework/tests/pk_relations.py @@ -3,26 +3,50 @@ from django.test import TestCase  from rest_framework import serializers -class Target(models.Model): +# ManyToMany + +class ManyToManyTarget(models.Model):      name = models.CharField(max_length=100) -class Source(models.Model): +class ManyToManySource(models.Model):      name = models.CharField(max_length=100) -    targets = models.ManyToManyField(Target, related_name='sources') +    targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') -class TargetSerializer(serializers.ModelSerializer): -    sources = serializers.ManyPrimaryKeyRelatedField() +class ManyToManyTargetSerializer(serializers.ModelSerializer): +    sources = serializers.ManyPrimaryKeyRelatedField(queryset=ManyToManySource.objects.all())      class Meta: -        fields = ('id', 'name', 'sources') -        model = Target +        model = ManyToManyTarget -class SourceSerializer(serializers.ModelSerializer): +class ManyToManySourceSerializer(serializers.ModelSerializer):      class Meta: -        model = Source +        model = ManyToManySource + + +# ForeignKey + +class ForeignKeyTarget(models.Model): +    name = models.CharField(max_length=100) + + +class ForeignKeySource(models.Model): +    name = models.CharField(max_length=100) +    target = models.ForeignKey(ForeignKeyTarget, related_name='sources') + + +class ForeignKeyTargetSerializer(serializers.ModelSerializer): +    sources = serializers.ManyPrimaryKeyRelatedField(queryset=ForeignKeySource.objects.all()) + +    class Meta: +        model = ForeignKeyTarget + + +class ForeignKeySourceSerializer(serializers.ModelSerializer): +    class Meta: +        model = ForeignKeySource  # TODO: Add test that .data cannot be accessed prior to .is_valid @@ -30,15 +54,16 @@ class SourceSerializer(serializers.ModelSerializer):  class PrimaryKeyManyToManyTests(TestCase):      def setUp(self):          for idx in range(1, 4): -            target = Target(name='target-%d' % idx) +            target = ManyToManyTarget(name='target-%d' % idx)              target.save() -            source = Source(name='source-%d' % idx) +            source = ManyToManySource(name='source-%d' % idx)              source.save() -            for target in Target.objects.all(): +            for target in ManyToManyTarget.objects.all():                  source.targets.add(target)      def test_many_to_many_retrieve(self): -        serializer = SourceSerializer(instance=Source.objects.all()) +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(instance=queryset)          expected = [                  {'id': 1, 'name': u'source-1', 'targets': [1]},                  {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, @@ -47,7 +72,8 @@ class PrimaryKeyManyToManyTests(TestCase):          self.assertEquals(serializer.data, expected)      def test_reverse_many_to_many_retrieve(self): -        serializer = TargetSerializer(instance=Target.objects.all()) +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(instance=queryset)          expected = [              {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]},              {'id': 2, 'name': u'target-2', 'sources': [2, 3]}, @@ -57,12 +83,15 @@ class PrimaryKeyManyToManyTests(TestCase):      def test_many_to_many_update(self):          data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} -        serializer = SourceSerializer(data, instance=Source.objects.get(pk=1)) +        instance = ManyToManySource.objects.get(pk=1) +        serializer = ManyToManySourceSerializer(data, instance=instance)          self.assertTrue(serializer.is_valid())          self.assertEquals(serializer.data, data) +        serializer.save()          # Ensure source 1 is updated, and everything else is as expected -        serializer = SourceSerializer(instance=Source.objects.all()) +        queryset = ManyToManySource.objects.all() +        serializer = ManyToManySourceSerializer(instance=queryset)          expected = [                  {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]},                  {'id': 2, 'name': u'source-2', 'targets': [1, 2]}, @@ -71,16 +100,88 @@ class PrimaryKeyManyToManyTests(TestCase):          self.assertEquals(serializer.data, expected)      def test_reverse_many_to_many_update(self): -        data = {'id': 1, 'name': u'target-0', 'sources': [1]} -        serializer = TargetSerializer(data, instance=Target.objects.get(pk=1)) +        data = {'id': 1, 'name': u'target-1', 'sources': [1]} +        instance = ManyToManyTarget.objects.get(pk=1) +        serializer = ManyToManyTargetSerializer(data, instance=instance)          self.assertTrue(serializer.is_valid())          self.assertEquals(serializer.data, data) +        serializer.save()          # Ensure target 1 is updated, and everything else is as expected -        serializer = TargetSerializer(instance=Target.objects.all()) +        queryset = ManyToManyTarget.objects.all() +        serializer = ManyToManyTargetSerializer(instance=queryset)          expected = [              {'id': 1, 'name': u'target-1', 'sources': [1]},              {'id': 2, 'name': u'target-2', 'sources': [2, 3]},              {'id': 3, 'name': u'target-3', 'sources': [3]}          ]          self.assertEquals(serializer.data, expected) + + +class PrimaryKeyForeignKeyTests(TestCase): +    def setUp(self): +        target = ForeignKeyTarget(name='target-1') +        target.save() +        new_target = ForeignKeyTarget(name='target-2') +        new_target.save() +        for idx in range(1, 4): +            source = ForeignKeySource(name='source-%d' % idx, target=target) +            source.save() + +    def test_foreign_key_retrieve(self): +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(instance=queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 1}, +            {'id': 2, 'name': u'source-2', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': 1} +        ] +        self.assertEquals(serializer.data, expected) + +    def test_reverse_foreign_key_retrieve(self): +        queryset = ForeignKeyTarget.objects.all() +        serializer = ForeignKeyTargetSerializer(instance=queryset) +        expected = [ +            {'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, +            {'id': 2, 'name': u'target-2', 'sources': []}, +        ] +        self.assertEquals(serializer.data, expected) + +    def test_foreign_key_update(self): +        data = {'id': 1, 'name': u'source-1', 'target': 2} +        instance = ForeignKeySource.objects.get(pk=1) +        serializer = ForeignKeySourceSerializer(data, instance=instance) +        self.assertTrue(serializer.is_valid()) +        self.assertEquals(serializer.data, data) +        serializer.save() + +        # # Ensure source 1 is updated, and everything else is as expected +        queryset = ForeignKeySource.objects.all() +        serializer = ForeignKeySourceSerializer(instance=queryset) +        expected = [ +            {'id': 1, 'name': u'source-1', 'target': 2}, +            {'id': 2, 'name': u'source-2', 'target': 1}, +            {'id': 3, 'name': u'source-3', 'target': 1} +        ] +        self.assertEquals(serializer.data, expected) + +    # reverse foreign keys MUST be read_only +    # In the general case they do not provide .remove() or .clear() +    # and cannot be arbitrarily set. + +    # def test_reverse_foreign_key_update(self): +    #     data = {'id': 1, 'name': u'target-1', 'sources': [1]} +    #     instance = ForeignKeyTarget.objects.get(pk=1) +    #     serializer = ForeignKeyTargetSerializer(data, instance=instance) +    #     self.assertTrue(serializer.is_valid()) +    #     self.assertEquals(serializer.data, data) +    #     serializer.save() + +    #     # Ensure target 1 is updated, and everything else is as expected +    #     queryset = ForeignKeyTarget.objects.all() +    #     serializer = ForeignKeyTargetSerializer(instance=queryset) +    #     expected = [ +    #         {'id': 1, 'name': u'target-1', 'sources': [1]}, +    #         {'id': 2, 'name': u'target-2', 'sources': []}, +    #     ] +    #     self.assertEquals(serializer.data, expected)  | 
