From 3f79a9a3d3e7692d90476f8a6907957b47aab821 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 22 Mar 2013 22:39:45 +0000 Subject: one-one writable nested modelserializers --- rest_framework/serializers.py | 11 ++- rest_framework/tests/relations_nested.py | 154 ++++++++++++++++--------------- 2 files changed, 92 insertions(+), 73 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6aca2f57..26c34044 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -753,7 +753,16 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - setattr(obj, accessor_name, related) + if related is None: + previous = getattr(obj, accessor_name, related) + if previous: + previous.delete() + elif isinstance(related, models.Model): + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) + else: + setattr(obj, accessor_name, related) del(obj._related_data) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..4592e559 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,115 +1,125 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - depth = 1 - model = ForeignKeySource - +class OneToOneTarget(models.Model): + name = models.CharField(max_length=100) -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource +class OneToOneTargetSource(models.Model): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, null=True, blank=True, + related_name='target_source') -class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer(many=True) - class Meta: - model = ForeignKeyTarget +class OneToOneSource(models.Model): + name = models.CharField(max_length=100) + target_source = models.OneToOneField(OneToOneTargetSource, related_name='source') -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +class OneToOneSourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 - model = NullableForeignKeySource + model = OneToOneSource + exclude = ('target_source', ) -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): +class OneToOneTargetSourceSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() + class Meta: - model = NullableOneToOneSource + model = OneToOneTargetSource + exclude = ('target', ) -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = NullableOneToOneSourceSerializer() +class OneToOneTargetSerializer(serializers.ModelSerializer): + target_source = OneToOneTargetSourceSerializer() class Meta: model = OneToOneTarget -class ReverseForeignKeyTests(TestCase): +class NestedOneToOneTests(TestCase): def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) + target = OneToOneTarget(name='target-%d' % idx) + target.save() + target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target) + target_source.save() + source = OneToOneSource(name='source-%d' % idx, target_source=target_source) source.save() - def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_retrieve(self): + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}} ] self.assertEqual(serializer.data, expected) - def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} + serializer = OneToOneTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'sources': [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1}, - ]}, - {'id': 2, 'name': 'target-2', 'sources': [ - ]} + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}, + {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} ] self.assertEqual(serializer.data, expected) - -class NestedNullableForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - for idx in range(1, 4): - if idx == 3: - target = None - source = NullableForeignKeySource(name='source-%d' % idx, target=target) - source.save() - - def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}} + serializer = OneToOneTargetSerializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + instance = OneToOneTarget.objects.get(pk=3) + serializer = OneToOneTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} ] self.assertEqual(serializer.data, expected) + def test_one_to_one_delete(self): + data = {'id': 3, 'name': 'target-3', 'target_source': None} + instance = OneToOneTarget.objects.get(pk=3) + serializer = OneToOneTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() -class NestedNullableOneToOneTests(TestCase): - def setUp(self): - target = OneToOneTarget(name='target-1') - target.save() - new_target = OneToOneTarget(name='target-2') - new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) - source.save() - - def test_reverse_foreign_key_retrieve_with_null(self): + # Ensure (target_source 3, source 3) are deleted, + # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset, many=True) + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, - {'id': 2, 'name': 'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': None} ] self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From d97e72cdb2f4fcc5aa2c19527a2b2ff11cf784bb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 25 Mar 2013 17:28:23 +0000 Subject: Cleanup one-one nested tests and implementation --- rest_framework/serializers.py | 37 +++++- rest_framework/tests/relations_nested.py | 186 +++++++++++++++++++++---------- 2 files changed, 157 insertions(+), 66 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 26c34044..668bcc49 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -667,9 +667,12 @@ class ModelSerializer(Serializer): cls = self.opts.model opts = get_concrete_model(cls)._meta exclusions = [field.name for field in opts.fields + opts.many_to_many] + for field_name, field in self.fields.items(): field_name = field.source or field_name - if field_name in exclusions and not field.read_only: + if field_name in exclusions \ + and not field.read_only \ + and not isinstance(field, Serializer): exclusions.remove(field_name) return exclusions @@ -695,6 +698,7 @@ class ModelSerializer(Serializer): """ m2m_data = {} related_data = {} + nested_forward_relations = {} meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -714,6 +718,12 @@ class ModelSerializer(Serializer): if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) + # Nested forward relations - These need to be marked so we can save + # them before saving the parent model instance. + for field_name in attrs.keys(): + if isinstance(self.fields.get(field_name, None), Serializer): + nested_forward_relations[field_name] = attrs[field_name] + # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -729,6 +739,7 @@ class ModelSerializer(Serializer): # at the point of save. instance._related_data = related_data instance._m2m_data = m2m_data + instance._nested_forward_relations = nested_forward_relations return instance @@ -744,6 +755,13 @@ class ModelSerializer(Serializer): """ Save the deserialized object and return it. """ + if getattr(obj, '_nested_forward_relations', None): + # Nested relationships need to be saved before we can save the + # parent instance. + for field_name, sub_object in obj._nested_forward_relations.items(): + self.save_object(sub_object) + setattr(obj, field_name, sub_object) + obj.save(**kwargs) if getattr(obj, '_m2m_data', None): @@ -753,15 +771,22 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - if related is None: - previous = getattr(obj, accessor_name, related) - if previous: - previous.delete() - elif isinstance(related, models.Model): + field = self.fields.get(accessor_name, None) + if isinstance(field, Serializer): + # TODO: Following will be needed for reverse FK + # if field.many: + # # Nested reverse fk relationship + # for related_item in related: + # fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + # setattr(related_item, fk_field, obj) + # self.save_object(related_item) + # else: + # Nested reverse one-one relationship fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related, fk_field, obj) self.save_object(related) else: + # Reverse FK or reverse one-one setattr(obj, accessor_name, related) del(obj._related_data) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 4592e559..e7af6565 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -8,61 +8,46 @@ class OneToOneTarget(models.Model): name = models.CharField(max_length=100) -class OneToOneTargetSource(models.Model): - name = models.CharField(max_length=100) - target = models.OneToOneField(OneToOneTarget, null=True, blank=True, - related_name='target_source') - - class OneToOneSource(models.Model): name = models.CharField(max_length=100) - target_source = models.OneToOneField(OneToOneTargetSource, related_name='source') - - -class OneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneSource - exclude = ('target_source', ) - + target = models.OneToOneField(OneToOneTarget, related_name='source') -class OneToOneTargetSourceSerializer(serializers.ModelSerializer): - source = OneToOneSourceSerializer() - - class Meta: - model = OneToOneTargetSource - exclude = ('target', ) +class ReverseNestedOneToOneTests(TestCase): + def setUp(self): + class OneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneSource + fields = ('id', 'name') -class OneToOneTargetSerializer(serializers.ModelSerializer): - target_source = OneToOneTargetSourceSerializer() + class OneToOneTargetSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() - class Meta: - model = OneToOneTarget + class Meta: + model = OneToOneTarget + fields = ('id', 'name', 'source') + self.Serializer = OneToOneTargetSerializer -class NestedOneToOneTests(TestCase): - def setUp(self): for idx in range(1, 4): target = OneToOneTarget(name='target-%d' % idx) target.save() - target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target) - target_source.save() - source = OneToOneSource(name='source-%d' % idx, target_source=target_source) + source = OneToOneSource(name='source-%d' % idx, target=target) source.save() def test_one_to_one_retrieve(self): queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} ] self.assertEqual(serializer.data, expected) def test_one_to_one_create(self): - data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} - serializer = OneToOneTargetSerializer(data=data) + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} + serializer = self.Serializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -71,25 +56,25 @@ class NestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}, - {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, + {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} ] self.assertEqual(serializer.data, expected) def test_one_to_one_create_with_invalid_data(self): - data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}} - serializer = OneToOneTargetSerializer(data=data) + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} + serializer = self.Serializer(data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]}) + self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) def test_one_to_one_update(self): - data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} instance = OneToOneTarget.objects.get(pk=3) - serializer = OneToOneTargetSerializer(instance, data=data) + serializer = self.Serializer(instance, data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -98,28 +83,109 @@ class NestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} ] self.assertEqual(serializer.data, expected) - def test_one_to_one_delete(self): - data = {'id': 3, 'name': 'target-3', 'target_source': None} - instance = OneToOneTarget.objects.get(pk=3) - serializer = OneToOneTargetSerializer(instance, data=data) + +class ForwardNestedOneToOneTests(TestCase): + def setUp(self): + class OneToOneTargetSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('id', 'name') + + class OneToOneSourceSerializer(serializers.ModelSerializer): + target = OneToOneTargetSerializer() + + class Meta: + model = OneToOneSource + fields = ('id', 'name', 'target') + + self.Serializer = OneToOneSourceSerializer + + for idx in range(1, 4): + target = OneToOneTarget(name='target-%d' % idx) + target.save() + source = OneToOneSource(name='source-%d' % idx, target=target) + source.save() + + def test_one_to_one_retrieve(self): + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + serializer = self.Serializer(data=data) self.assertTrue(serializer.is_valid()) - serializer.save() + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, + {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + ] + self.assertEqual(serializer.data, expected) - # Ensure (target_source 3, source 3) are deleted, + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. - queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3', 'target_source': None} + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} ] self.assertEqual(serializer.data, expected) + + + # TODO: Nullable 1-1 tests + # def test_one_to_one_delete(self): + # data = {'id': 3, 'name': 'target-3', 'target_source': None} + # instance = OneToOneTarget.objects.get(pk=3) + # serializer = self.Serializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # serializer.save() + + # # Ensure (target_source 3, source 3) are deleted, + # # and everything else is as expected. + # queryset = OneToOneTarget.objects.all() + # serializer = self.Serializer(queryset) + # expected = [ + # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + # {'id': 3, 'name': 'target-3', 'source': None} + # ] + # self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 73efa96de983fc644328d2fc498651aa917a2272 Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Sat, 6 Apr 2013 08:43:21 -0700 Subject: one-many writable nested modelserializer support --- rest_framework/serializers.py | 56 ++++++++++----- rest_framework/tests/relations_nested.py | 98 ++++++++++++++++++++++++++ rest_framework/tests/serializer_bulk_update.py | 6 +- 3 files changed, 139 insertions(+), 21 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 668bcc49..73cad00f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -130,14 +130,14 @@ class BaseSerializer(WritableField): def __init__(self, instance=None, data=None, files=None, context=None, partial=False, many=None, - allow_delete=False, **kwargs): + allow_add_remove=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial self.many = many - self.allow_delete = allow_delete + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -154,8 +154,8 @@ class BaseSerializer(WritableField): if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') - if allow_delete and not many: - raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True') + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -288,8 +288,15 @@ class BaseSerializer(WritableField): You should override this method to control how deserialized objects are instantiated. """ + removed_relations = [] + + # Deleted related objects + if self._deleted: + removed_relations = list(self._deleted) + if instance is not None: instance.update(attrs) + instance._removed_relations = removed_relations return instance return attrs @@ -377,6 +384,7 @@ class BaseSerializer(WritableField): # Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None + obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj if value in (None, ''): into[(self.source or field_name)] = None @@ -386,7 +394,8 @@ class BaseSerializer(WritableField): 'data': value, 'context': self.context, 'partial': self.partial, - 'many': self.many + 'many': self.many, + 'allow_add_remove': self.allow_add_remove } serializer = self.__class__(**kwargs) @@ -496,6 +505,9 @@ class BaseSerializer(WritableField): def save_object(self, obj, **kwargs): obj.save(**kwargs) + if self.allow_add_remove and hasattr(obj, '_removed_relations'): + [self.delete_object(item) for item in obj._removed_relations] + def delete_object(self, obj): obj.delete() @@ -508,7 +520,7 @@ class BaseSerializer(WritableField): else: self.save_object(self.object, **kwargs) - if self.allow_delete and self._deleted: + if self.allow_add_remove and self._deleted: [self.delete_object(item) for item in self._deleted] return self.object @@ -699,6 +711,7 @@ class ModelSerializer(Serializer): m2m_data = {} related_data = {} nested_forward_relations = {} + removed_relations = [] meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -724,6 +737,10 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] + # Deleted related objects + if self._deleted: + removed_relations = list(self._deleted) + # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -740,6 +757,7 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations + instance._removed_relations = removed_relations return instance @@ -764,6 +782,9 @@ class ModelSerializer(Serializer): obj.save(**kwargs) + if self.allow_add_remove and hasattr(obj, '_removed_relations'): + [self.delete_object(item) for item in obj._removed_relations] + if getattr(obj, '_m2m_data', None): for accessor_name, object_list in obj._m2m_data.items(): setattr(obj, accessor_name, object_list) @@ -773,18 +794,17 @@ class ModelSerializer(Serializer): for accessor_name, related in obj._related_data.items(): field = self.fields.get(accessor_name, None) if isinstance(field, Serializer): - # TODO: Following will be needed for reverse FK - # if field.many: - # # Nested reverse fk relationship - # for related_item in related: - # fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - # setattr(related_item, fk_field, obj) - # self.save_object(related_item) - # else: - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) + if field.many: + # Nested reverse fk relationship + for related_item in related: + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related_item, fk_field, obj) + self.save_object(related_item) + else: + # Nested reverse one-one relationship + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) else: # Reverse FK or reverse one-one setattr(obj, accessor_name, related) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index e7af6565..20683d4a 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -13,6 +13,15 @@ class OneToOneSource(models.Model): target = models.OneToOneField(OneToOneTarget, related_name='source') +class OneToManyTarget(models.Model): + name = models.CharField(max_length=100) + + +class OneToManySource(models.Model): + name = models.CharField(max_length=100) + target = models.ForeignKey(OneToManyTarget, related_name='sources') + + class ReverseNestedOneToOneTests(TestCase): def setUp(self): class OneToOneSourceSerializer(serializers.ModelSerializer): @@ -189,3 +198,92 @@ class ForwardNestedOneToOneTests(TestCase): # {'id': 3, 'name': 'target-3', 'source': None} # ] # self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase): + def setUp(self): + class OneToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToManySource + fields = ('id', 'name') + + class OneToManyTargetSerializer(serializers.ModelSerializer): + sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + + class Meta: + model = OneToManyTarget + fields = ('id', 'name', 'sources') + + self.Serializer = OneToManyTargetSerializer + + target = OneToManyTarget(name='target-1') + target.save() + for idx in range(1, 4): + source = OneToManySource(name='source-%d' % idx, target=target) + source.save() + + def test_one_to_many_retrieve(self): + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]}, + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1') + + # Ensure source 4 is added, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create_with_invalid_data(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4}]} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + + def test_one_to_many_update(self): + data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1-updated') + + # Ensure (target 1, source 1) are updated, + # and everything else is as expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py index afc1a1a9..5328e733 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/serializer_bulk_update.py @@ -201,7 +201,7 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() @@ -223,7 +223,7 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() @@ -249,6 +249,6 @@ class BulkUpdateSerializerTests(TestCase): {}, {'id': ['Enter a whole number.']} ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) -- cgit v1.2.3 From fdc5cc3d81679d30cd20acf063dc7dc74ad17d7a Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Thu, 18 Apr 2013 10:28:20 -0700 Subject: Fix model serializer nestesd delete behavior --- rest_framework/serializers.py | 40 +++++++++++--------------------- rest_framework/tests/relations_nested.py | 19 +++++++++++++++ 2 files changed, 33 insertions(+), 26 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0f0f11a4..78c45548 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,9 @@ from rest_framework.relations import * from rest_framework.fields import * +class RelationsList(list): + _deleted = [] + class NestedValidationError(ValidationError): """ The default ValidationError behavior is to stringify each item in the list @@ -149,7 +152,6 @@ class BaseSerializer(WritableField): self._data = None self._files = None self._errors = None - self._deleted = None if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') @@ -288,15 +290,8 @@ class BaseSerializer(WritableField): You should override this method to control how deserialized objects are instantiated. """ - removed_relations = [] - - # Deleted related objects - if self._deleted: - removed_relations = list(self._deleted) - if instance is not None: instance.update(attrs) - instance._removed_relations = removed_relations return instance return attrs @@ -438,7 +433,7 @@ class BaseSerializer(WritableField): PendingDeprecationWarning, stacklevel=3) if many: - ret = [] + ret = RelationsList() errors = [] update = self.object is not None @@ -466,7 +461,7 @@ class BaseSerializer(WritableField): errors.append(self._errors) if update: - self._deleted = identity_to_objects.values() + ret._deleted = identity_to_objects.values() self._errors = any(errors) and errors or [] else: @@ -509,9 +504,6 @@ class BaseSerializer(WritableField): def save_object(self, obj, **kwargs): obj.save(**kwargs) - if self.allow_add_remove and hasattr(obj, '_removed_relations'): - [self.delete_object(item) for item in obj._removed_relations] - def delete_object(self, obj): obj.delete() @@ -521,11 +513,11 @@ class BaseSerializer(WritableField): """ if isinstance(self.object, list): [self.save_object(item, **kwargs) for item in self.object] - else: - self.save_object(self.object, **kwargs) - if self.allow_add_remove and self._deleted: - [self.delete_object(item) for item in self._deleted] + if self.allow_add_remove and self.object._deleted: + [self.delete_object(item) for item in self.object._deleted] + else: + self.save_object(self.object, **kwargs) return self.object @@ -715,7 +707,6 @@ class ModelSerializer(Serializer): m2m_data = {} related_data = {} nested_forward_relations = {} - removed_relations = [] meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -741,10 +732,6 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] - # Deleted related objects - if self._deleted: - removed_relations = list(self._deleted) - # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -761,7 +748,6 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations - instance._removed_relations = removed_relations return instance @@ -786,9 +772,6 @@ class ModelSerializer(Serializer): obj.save(**kwargs) - if self.allow_add_remove and hasattr(obj, '_removed_relations'): - [self.delete_object(item) for item in obj._removed_relations] - if getattr(obj, '_m2m_data', None): for accessor_name, object_list in obj._m2m_data.items(): setattr(obj, accessor_name, object_list) @@ -804,6 +787,11 @@ class ModelSerializer(Serializer): fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related_item, fk_field, obj) self.save_object(related_item) + + # Delete any removed objects + if field.allow_add_remove and related._deleted: + [self.delete_object(item) for item in related._deleted] + else: # Nested reverse one-one relationship fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 20683d4a..22c98e7f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -287,3 +287,22 @@ class ReverseNestedOneToManyTests(TestCase): ] self.assertEqual(serializer.data, expected) + + def test_one_to_many_delete(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + + # Ensure source 2 is deleted, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 7e0a93f0eefead25f0e9b6615675f394af3a4ba0 Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Fri, 19 Apr 2013 10:46:57 -0700 Subject: Don't use field when saving related data --- rest_framework/serializers.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 78c45548..b39cb810 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -460,7 +460,7 @@ class BaseSerializer(WritableField): ret.append(self.from_native(item, None)) errors.append(self._errors) - if update: + if update and self.allow_add_remove: ret._deleted = identity_to_objects.values() self._errors = any(errors) and errors or [] @@ -514,7 +514,7 @@ class BaseSerializer(WritableField): if isinstance(self.object, list): [self.save_object(item, **kwargs) for item in self.object] - if self.allow_add_remove and self.object._deleted: + if self.object._deleted: [self.delete_object(item) for item in self.object._deleted] else: self.save_object(self.object, **kwargs) @@ -779,24 +779,22 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - field = self.fields.get(accessor_name, None) - if isinstance(field, Serializer): - if field.many: - # Nested reverse fk relationship - for related_item in related: - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if field.allow_add_remove and related._deleted: - [self.delete_object(item) for item in related._deleted] - - else: - # Nested reverse one-one relationship + if isinstance(related, RelationsList): + # Nested reverse fk relationship + for related_item in related: fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) + setattr(related_item, fk_field, obj) + self.save_object(related_item) + + # Delete any removed objects + if related._deleted: + [self.delete_object(item) for item in related._deleted] + + elif isinstance(related, models.Model): + # Nested reverse one-one relationship + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) else: # Reverse FK or reverse one-one setattr(obj, accessor_name, related) -- cgit v1.2.3 From 14482a966168a98d43099d00c163d1c8c3b6471b Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Wed, 8 May 2013 22:44:23 -0700 Subject: Fix deprecation warnings in relations_nested tests --- rest_framework/tests/relations_nested.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 22c98e7f..8325580f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -46,7 +46,7 @@ class ReverseNestedOneToOneTests(TestCase): def test_one_to_one_retrieve(self): queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -65,7 +65,7 @@ class ReverseNestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -92,7 +92,7 @@ class ReverseNestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -125,7 +125,7 @@ class ForwardNestedOneToOneTests(TestCase): def test_one_to_one_retrieve(self): queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -144,7 +144,7 @@ class ForwardNestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -171,7 +171,7 @@ class ForwardNestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -224,7 +224,7 @@ class ReverseNestedOneToManyTests(TestCase): def test_one_to_many_retrieve(self): queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 2, 'name': 'source-2'}, @@ -247,7 +247,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure source 4 is added, and everything else is as # expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 2, 'name': 'source-2'}, @@ -279,7 +279,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure (target 1, source 1) are updated, # and everything else is as expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, {'id': 2, 'name': 'source-2'}, @@ -299,7 +299,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure source 2 is deleted, and everything else is as # expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 3, 'name': 'source-3'}]} -- cgit v1.2.3 From db9672d3048eebb3d3c3fb2b4a345e17b5aa23cc Mon Sep 17 00:00:00 2001 From: Alex Burgel Date: Wed, 24 Jul 2013 17:24:29 -0400 Subject: Add support for removing field files by sending an empty string --- rest_framework/fields.py | 5 ++++- rest_framework/tests/test_files.py | 28 ++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f9931887..9ba5c0eb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -307,7 +307,10 @@ class WritableField(Field): try: if self.use_files: files = files or {} - native = files[field_name] + try: + native = files[field_name] + except KeyError: + native = data[field_name] else: native = data[field_name] except KeyError: diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 487046ac..495c2a7f 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -7,13 +7,13 @@ import datetime class UploadedFile(object): - def __init__(self, file, created=None): + def __init__(self, file=None, created=None): self.file = file self.created = created or datetime.datetime.now() class UploadedFileSerializer(serializers.Serializer): - file = serializers.FileField() + file = serializers.FileField(required=False) created = serializers.DateTimeField() def restore_object(self, attrs, instance=None): @@ -47,5 +47,25 @@ class FileSerializerTests(TestCase): now = datetime.datetime.now() serializer = UploadedFileSerializer(data={'created': now}) - self.assertFalse(serializer.is_valid()) - self.assertIn('file', serializer.errors) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, now) + self.assertIsNone(serializer.object.file) + + def test_remove_with_empty_string(self): + """ + Passing empty string as data should cause file to be removed + + Test for: + https://github.com/tomchristie/django-rest-framework/issues/937 + """ + now = datetime.datetime.now() + file = BytesIO(six.b('stuff')) + file.name = 'stuff.txt' + file.size = len(file.getvalue()) + + uploaded_file = UploadedFile(file=file, created=now) + + serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, uploaded_file.created) + self.assertIsNone(serializer.object.file) -- cgit v1.2.3 From abe655e061871a568cccf473414e350f3eb61d8b Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 21:01:37 +0300 Subject: Make OneToOneSource.target nullable --- rest_framework/tests/test_relations_nested.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index 8325580f..30229687 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -10,7 +10,8 @@ class OneToOneTarget(models.Model): class OneToOneSource(models.Model): name = models.CharField(max_length=100) - target = models.OneToOneField(OneToOneTarget, related_name='source') + target = models.OneToOneField(OneToOneTarget, related_name='source', + null=True, blank=True) class OneToManyTarget(models.Model): @@ -21,7 +22,7 @@ class OneToManySource(models.Model): name = models.CharField(max_length=100) target = models.ForeignKey(OneToManyTarget, related_name='sources') - + class ReverseNestedOneToOneTests(TestCase): def setUp(self): class OneToOneSourceSerializer(serializers.ModelSerializer): -- cgit v1.2.3 From 901d2b0eb8270befa051510e190f3d5679086c7f Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 21:02:59 +0300 Subject: Failing test case for nullifying nested object --- rest_framework/tests/test_relations_nested.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index 30229687..d393b0c3 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -180,6 +180,25 @@ class ForwardNestedOneToOneTests(TestCase): ] self.assertEqual(serializer.data, expected) + def test_one_to_one_update_to_null(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': None} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + self.assertEqual(obj.target, None) + + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': None} + ] + self.assertEqual(serializer.data, expected) # TODO: Nullable 1-1 tests # def test_one_to_one_delete(self): -- cgit v1.2.3 From ff1efcf60f0a9b66cdb736f8c0b2cfe2fc84cdf5 Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 18:08:23 +0300 Subject: If null or blank - don't save the nested object --- rest_framework/serializers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d8f9145e..2b260c25 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -522,7 +522,7 @@ class BaseSerializer(WritableField): if self.object._deleted: [self.delete_object(item) for item in self.object._deleted] else: - self.save_object(self.object, **kwargs) + self.save_object(self.object, **kwargs) return self.object @@ -891,7 +891,8 @@ class ModelSerializer(Serializer): # Nested relationships need to be saved before we can save the # parent instance. for field_name, sub_object in obj._nested_forward_relations.items(): - self.save_object(sub_object) + if sub_object: + self.save_object(sub_object) setattr(obj, field_name, sub_object) obj.save(**kwargs) -- cgit v1.2.3 From e677f3ee5c9435594ce58a3256a119c08bdc1e42 Mon Sep 17 00:00:00 2001 From: Krzysztof Jurewicz Date: Tue, 13 Aug 2013 13:26:30 +0200 Subject: PATCH requests should not be able to create objects. --- rest_framework/mixins.py | 13 ++++++++----- rest_framework/tests/test_generics.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index f11def6d..59d64469 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -142,11 +142,14 @@ class UpdateModelMixin(object): try: return self.get_object() except Http404: - # If this is a PUT-as-create operation, we need to ensure that - # we have relevant permissions, as if this was a POST request. - # This will either raise a PermissionDenied exception, - # or simply return None - self.check_permissions(clone_request(self.request, 'POST')) + if self.request.method == 'PUT': + # For PUT-as-create operation, we need to ensure that we have + # relevant permissions, as if this was a POST request. This + # will either raise a PermissionDenied exception, or simply + # return None. + self.check_permissions(clone_request(self.request, 'POST')) + else: + raise def pre_save(self, obj): """ diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 1550880b..7a87d389 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -338,6 +338,17 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') + def test_patch_cannot_create_an_object(self): + """ + PATCH requests should not be able to create objects. + """ + data = {'text': 'foobar'} + request = factory.patch('/999', data, format='json') + with self.assertNumQueries(1): + response = self.view(request, pk=999).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertFalse(self.objects.filter(id=999).exists()) + class TestOverriddenGetObject(TestCase): """ -- cgit v1.2.3 From 19a774f97292444a48c5b7521e1b0c0ea48b6502 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 11:21:45 +0100 Subject: force_authenticate(None) also clears session info. Closes #1055. --- rest_framework/test.py | 2 ++ rest_framework/tests/test_testing.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index a18f5a29..234d10a4 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient): """ self.handler._force_user = user self.handler._force_token = token + if user is None: + self.logout() # Also clear any possible session info if required def request(self, **kwargs): # Ensure that any credentials set get added to every request. diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 49d45fc2..48b8956b 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -17,8 +17,18 @@ def view(request): }) +@api_view(['GET', 'POST']) +def session_view(request): + active_session = request.session.get('active_session', False) + request.session['active_session'] = True + return Response({ + 'active_session': active_session + }) + + urlpatterns = patterns('', url(r'^view/$', view), + url(r'^session-view/$', session_view), ) @@ -46,6 +56,26 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') + def test_force_authenticate_with_sessions(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + + # First request does not yet have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + + # Subsequant requests have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], True) + + # Force authenticating as `None` should also logout the user session. + self.client.force_authenticate(None) + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + def test_csrf_exempt_by_default(self): """ By default, the test client is CSRF exempt. -- cgit v1.2.3 From 95b2bf50fbb9b95facebb23812bbbb2e27a76035 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 12:03:54 +0100 Subject: Add validation error test when passing non-file to FileField --- rest_framework/tests/test_files.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 495c2a7f..c13c38b8 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -69,3 +69,14 @@ class FileSerializerTests(TestCase): self.assertTrue(serializer.is_valid()) self.assertEqual(serializer.object.created, uploaded_file.created) self.assertIsNone(serializer.object.file) + + def test_validation_error_with_non_file(self): + """ + Passing non-files should raise a validation error. + """ + now = datetime.datetime.now() + errmsg = 'No file was submitted. Check the encoding type on the form.' + + serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'file': [errmsg]}) -- cgit v1.2.3 From e7927e9bca5bc0d0ac3b528e68244c713c5df97f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 13:35:50 +0100 Subject: Extra docs on PATCH with no object. --- rest_framework/mixins.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 59d64469..426865ff 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -149,6 +149,8 @@ class UpdateModelMixin(object): # return None. self.check_permissions(clone_request(self.request, 'POST')) else: + # PATCH requests where the object does not exist should still + # return a 404 response. raise def pre_save(self, obj): -- cgit v1.2.3 From e03854ba6a74428675c40d469a7768cc5131035f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 14:06:14 +0100 Subject: Tweaks to display nested data in empty serializers --- rest_framework/relations.py | 9 +++++++-- rest_framework/serializers.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index edaf76d6..7408758e 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -134,9 +134,9 @@ class RelatedField(WritableField): value = obj for component in source.split('.'): - value = get_component(value, component) if value is None: break + value = get_component(value, component) except ObjectDoesNotExist: return None @@ -567,8 +567,13 @@ class HyperlinkedIdentityField(Field): May raise a `NoReverseMatch` if the `view_name` and `lookup_field` attributes are not configured to correctly match the URL conf. """ - lookup_field = getattr(obj, self.lookup_field) + lookup_field = getattr(obj, self.lookup_field, None) kwargs = {self.lookup_field: lookup_field} + + # Handle unsaved object case + if lookup_field is None: + return None + try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2b260c25..22525964 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -338,9 +338,9 @@ class BaseSerializer(WritableField): value = obj for component in source.split('.'): - value = get_component(value, component) if value is None: - break + return self.to_native(None) + value = get_component(value, component) except ObjectDoesNotExist: return None -- cgit v1.2.3 From 0966a2680ba02e6a4586bd2777ed593fcc66a453 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 14:38:31 +0100 Subject: First pass at HTMLFormRenderer --- rest_framework/renderers.py | 97 +++++++++++++---------- rest_framework/templates/rest_framework/base.html | 10 +-- 2 files changed, 58 insertions(+), 49 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1006e26c..a73b2d73 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -316,6 +316,59 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): return data +class HTMLFormRenderer(BaseRenderer): + template = 'rest_framework/form.html' + + def serializer_to_form_fields(self, serializer): + fields = {} + for k, v in serializer.get_fields().items(): + if getattr(v, 'read_only', True): + continue + + kwargs = {} + kwargs['required'] = v.required + + #if getattr(v, 'queryset', None): + # kwargs['queryset'] = v.queryset + + if getattr(v, 'choices', None) is not None: + kwargs['choices'] = v.choices + + if getattr(v, 'regex', None) is not None: + kwargs['regex'] = v.regex + + if getattr(v, 'widget', None): + widget = copy.deepcopy(v.widget) + kwargs['widget'] = widget + + if getattr(v, 'default', None) is not None: + kwargs['initial'] = v.default + + if getattr(v, 'label', None) is not None: + kwargs['label'] = v.label + + if getattr(v, 'help_text', None) is not None: + kwargs['help_text'] = v.help_text + + fields[k] = v.form_field_class(**kwargs) + + return fields + + def render(self, serializer, obj, request): + fields = self.serializer_to_form_fields(serializer) + + # Creating an on the fly form see: + # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python + OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) + data = (obj is not None) and serializer.data or None + form_instance = OnTheFlyForm(data) + + template = loader.get_template(self.template) + context = RequestContext(request, {'form': form_instance}) + + return template.render(context) + + class BrowsableAPIRenderer(BaseRenderer): """ HTML renderer used to self-document the API. @@ -371,41 +424,6 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def serializer_to_form_fields(self, serializer): - fields = {} - for k, v in serializer.get_fields().items(): - if getattr(v, 'read_only', True): - continue - - kwargs = {} - kwargs['required'] = v.required - - #if getattr(v, 'queryset', None): - # kwargs['queryset'] = v.queryset - - if getattr(v, 'choices', None) is not None: - kwargs['choices'] = v.choices - - if getattr(v, 'regex', None) is not None: - kwargs['regex'] = v.regex - - if getattr(v, 'widget', None): - widget = copy.deepcopy(v.widget) - kwargs['widget'] = widget - - if getattr(v, 'default', None) is not None: - kwargs['initial'] = v.default - - if getattr(v, 'label', None) is not None: - kwargs['label'] = v.label - - if getattr(v, 'help_text', None) is not None: - kwargs['help_text'] = v.help_text - - fields[k] = v.form_field_class(**kwargs) - - return fields - def _get_form(self, view, method, request): # We need to impersonate a request with the correct method, # so that eg. any dynamic get_serializer_class methods return the @@ -447,14 +465,7 @@ class BrowsableAPIRenderer(BaseRenderer): return serializer = view.get_serializer(instance=obj) - fields = self.serializer_to_form_fields(serializer) - - # Creating an on the fly form see: - # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python - OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) - data = (obj is not None) and serializer.data or None - form_instance = OnTheFlyForm(data) - return form_instance + return HTMLFormRenderer().render(serializer, obj, request) def get_raw_data_form(self, view, method, request, media_types): """ diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 51f9c291..6ae47563 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -136,9 +136,9 @@ {% if post_form %}