From 3f79a9a3d3e7692d90476f8a6907957b47aab821 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 22 Mar 2013 22:39:45 +0000 Subject: one-one writable nested modelserializers --- rest_framework/serializers.py | 11 ++- rest_framework/tests/relations_nested.py | 154 ++++++++++++++++--------------- 2 files changed, 92 insertions(+), 73 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6aca2f57..26c34044 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -753,7 +753,16 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - setattr(obj, accessor_name, related) + if related is None: + previous = getattr(obj, accessor_name, related) + if previous: + previous.delete() + elif isinstance(related, models.Model): + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) + else: + setattr(obj, accessor_name, related) del(obj._related_data) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index a125ba65..4592e559 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -1,115 +1,125 @@ from __future__ import unicode_literals +from django.db import models from django.test import TestCase from rest_framework import serializers -from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource -class ForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - depth = 1 - model = ForeignKeySource - +class OneToOneTarget(models.Model): + name = models.CharField(max_length=100) -class FlatForeignKeySourceSerializer(serializers.ModelSerializer): - class Meta: - model = ForeignKeySource +class OneToOneTargetSource(models.Model): + name = models.CharField(max_length=100) + target = models.OneToOneField(OneToOneTarget, null=True, blank=True, + related_name='target_source') -class ForeignKeyTargetSerializer(serializers.ModelSerializer): - sources = FlatForeignKeySourceSerializer(many=True) - class Meta: - model = ForeignKeyTarget +class OneToOneSource(models.Model): + name = models.CharField(max_length=100) + target_source = models.OneToOneField(OneToOneTargetSource, related_name='source') -class NullableForeignKeySourceSerializer(serializers.ModelSerializer): +class OneToOneSourceSerializer(serializers.ModelSerializer): class Meta: - depth = 1 - model = NullableForeignKeySource + model = OneToOneSource + exclude = ('target_source', ) -class NullableOneToOneSourceSerializer(serializers.ModelSerializer): +class OneToOneTargetSourceSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() + class Meta: - model = NullableOneToOneSource + model = OneToOneTargetSource + exclude = ('target', ) -class NullableOneToOneTargetSerializer(serializers.ModelSerializer): - nullable_source = NullableOneToOneSourceSerializer() +class OneToOneTargetSerializer(serializers.ModelSerializer): + target_source = OneToOneTargetSourceSerializer() class Meta: model = OneToOneTarget -class ReverseForeignKeyTests(TestCase): +class NestedOneToOneTests(TestCase): def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - new_target = ForeignKeyTarget(name='target-2') - new_target.save() for idx in range(1, 4): - source = ForeignKeySource(name='source-%d' % idx, target=target) + target = OneToOneTarget(name='target-%d' % idx) + target.save() + target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target) + target_source.save() + source = OneToOneSource(name='source-%d' % idx, target_source=target_source) source.save() - def test_foreign_key_retrieve(self): - queryset = ForeignKeySource.objects.all() - serializer = ForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_retrieve(self): + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}} ] self.assertEqual(serializer.data, expected) - def test_reverse_foreign_key_retrieve(self): - queryset = ForeignKeyTarget.objects.all() - serializer = ForeignKeyTargetSerializer(queryset, many=True) + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} + serializer = OneToOneTargetSerializer(data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'sources': [ - {'id': 1, 'name': 'source-1', 'target': 1}, - {'id': 2, 'name': 'source-2', 'target': 1}, - {'id': 3, 'name': 'source-3', 'target': 1}, - ]}, - {'id': 2, 'name': 'target-2', 'sources': [ - ]} + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}, + {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} ] self.assertEqual(serializer.data, expected) - -class NestedNullableForeignKeyTests(TestCase): - def setUp(self): - target = ForeignKeyTarget(name='target-1') - target.save() - for idx in range(1, 4): - if idx == 3: - target = None - source = NullableForeignKeySource(name='source-%d' % idx, target=target) - source.save() - - def test_foreign_key_retrieve_with_null(self): - queryset = NullableForeignKeySource.objects.all() - serializer = NullableForeignKeySourceSerializer(queryset, many=True) + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}} + serializer = OneToOneTargetSerializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + instance = OneToOneTarget.objects.get(pk=3) + serializer = OneToOneTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, + # and everything else is as expected. + queryset = OneToOneTarget.objects.all() + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 2, 'name': 'source-2', 'target': {'id': 1, 'name': 'target-1'}}, - {'id': 3, 'name': 'source-3', 'target': None}, + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} ] self.assertEqual(serializer.data, expected) + def test_one_to_one_delete(self): + data = {'id': 3, 'name': 'target-3', 'target_source': None} + instance = OneToOneTarget.objects.get(pk=3) + serializer = OneToOneTargetSerializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() -class NestedNullableOneToOneTests(TestCase): - def setUp(self): - target = OneToOneTarget(name='target-1') - target.save() - new_target = OneToOneTarget(name='target-2') - new_target.save() - source = NullableOneToOneSource(name='source-1', target=target) - source.save() - - def test_reverse_foreign_key_retrieve_with_null(self): + # Ensure (target_source 3, source 3) are deleted, + # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = NullableOneToOneTargetSerializer(queryset, many=True) + serializer = OneToOneTargetSerializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'nullable_source': {'id': 1, 'name': 'source-1', 'target': 1}}, - {'id': 2, 'name': 'target-2', 'nullable_source': None}, + {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, + {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, + {'id': 3, 'name': 'target-3', 'target_source': None} ] self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From d97e72cdb2f4fcc5aa2c19527a2b2ff11cf784bb Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 25 Mar 2013 17:28:23 +0000 Subject: Cleanup one-one nested tests and implementation --- rest_framework/serializers.py | 37 +++++- rest_framework/tests/relations_nested.py | 186 +++++++++++++++++++++---------- 2 files changed, 157 insertions(+), 66 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 26c34044..668bcc49 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -667,9 +667,12 @@ class ModelSerializer(Serializer): cls = self.opts.model opts = get_concrete_model(cls)._meta exclusions = [field.name for field in opts.fields + opts.many_to_many] + for field_name, field in self.fields.items(): field_name = field.source or field_name - if field_name in exclusions and not field.read_only: + if field_name in exclusions \ + and not field.read_only \ + and not isinstance(field, Serializer): exclusions.remove(field_name) return exclusions @@ -695,6 +698,7 @@ class ModelSerializer(Serializer): """ m2m_data = {} related_data = {} + nested_forward_relations = {} meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -714,6 +718,12 @@ class ModelSerializer(Serializer): if field.name in attrs: m2m_data[field.name] = attrs.pop(field.name) + # Nested forward relations - These need to be marked so we can save + # them before saving the parent model instance. + for field_name in attrs.keys(): + if isinstance(self.fields.get(field_name, None), Serializer): + nested_forward_relations[field_name] = attrs[field_name] + # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -729,6 +739,7 @@ class ModelSerializer(Serializer): # at the point of save. instance._related_data = related_data instance._m2m_data = m2m_data + instance._nested_forward_relations = nested_forward_relations return instance @@ -744,6 +755,13 @@ class ModelSerializer(Serializer): """ Save the deserialized object and return it. """ + if getattr(obj, '_nested_forward_relations', None): + # Nested relationships need to be saved before we can save the + # parent instance. + for field_name, sub_object in obj._nested_forward_relations.items(): + self.save_object(sub_object) + setattr(obj, field_name, sub_object) + obj.save(**kwargs) if getattr(obj, '_m2m_data', None): @@ -753,15 +771,22 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - if related is None: - previous = getattr(obj, accessor_name, related) - if previous: - previous.delete() - elif isinstance(related, models.Model): + field = self.fields.get(accessor_name, None) + if isinstance(field, Serializer): + # TODO: Following will be needed for reverse FK + # if field.many: + # # Nested reverse fk relationship + # for related_item in related: + # fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + # setattr(related_item, fk_field, obj) + # self.save_object(related_item) + # else: + # Nested reverse one-one relationship fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related, fk_field, obj) self.save_object(related) else: + # Reverse FK or reverse one-one setattr(obj, accessor_name, related) del(obj._related_data) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 4592e559..e7af6565 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -8,61 +8,46 @@ class OneToOneTarget(models.Model): name = models.CharField(max_length=100) -class OneToOneTargetSource(models.Model): - name = models.CharField(max_length=100) - target = models.OneToOneField(OneToOneTarget, null=True, blank=True, - related_name='target_source') - - class OneToOneSource(models.Model): name = models.CharField(max_length=100) - target_source = models.OneToOneField(OneToOneTargetSource, related_name='source') - - -class OneToOneSourceSerializer(serializers.ModelSerializer): - class Meta: - model = OneToOneSource - exclude = ('target_source', ) - + target = models.OneToOneField(OneToOneTarget, related_name='source') -class OneToOneTargetSourceSerializer(serializers.ModelSerializer): - source = OneToOneSourceSerializer() - - class Meta: - model = OneToOneTargetSource - exclude = ('target', ) +class ReverseNestedOneToOneTests(TestCase): + def setUp(self): + class OneToOneSourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneSource + fields = ('id', 'name') -class OneToOneTargetSerializer(serializers.ModelSerializer): - target_source = OneToOneTargetSourceSerializer() + class OneToOneTargetSerializer(serializers.ModelSerializer): + source = OneToOneSourceSerializer() - class Meta: - model = OneToOneTarget + class Meta: + model = OneToOneTarget + fields = ('id', 'name', 'source') + self.Serializer = OneToOneTargetSerializer -class NestedOneToOneTests(TestCase): - def setUp(self): for idx in range(1, 4): target = OneToOneTarget(name='target-%d' % idx) target.save() - target_source = OneToOneTargetSource(name='target-source-%d' % idx, target=target) - target_source.save() - source = OneToOneSource(name='source-%d' % idx, target_source=target_source) + source = OneToOneSource(name='source-%d' % idx, target=target) source.save() def test_one_to_one_retrieve(self): queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}} ] self.assertEqual(serializer.data, expected) def test_one_to_one_create(self): - data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} - serializer = OneToOneTargetSerializer(data=data) + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} + serializer = self.Serializer(data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -71,25 +56,25 @@ class NestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3', 'target_source': {'id': 3, 'name': 'target-source-3', 'source': {'id': 3, 'name': 'source-3'}}}, - {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4, 'name': 'source-4'}}} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3', 'source': {'id': 3, 'name': 'source-3'}}, + {'id': 4, 'name': 'target-4', 'source': {'id': 4, 'name': 'source-4'}} ] self.assertEqual(serializer.data, expected) def test_one_to_one_create_with_invalid_data(self): - data = {'id': 4, 'name': 'target-4', 'target_source': {'id': 4, 'name': 'target-source-4', 'source': {'id': 4}}} - serializer = OneToOneTargetSerializer(data=data) + data = {'id': 4, 'name': 'target-4', 'source': {'id': 4}} + serializer = self.Serializer(data=data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'target_source': [{'source': [{'name': ['This field is required.']}]}]}) + self.assertEqual(serializer.errors, {'source': [{'name': ['This field is required.']}]}) def test_one_to_one_update(self): - data = {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + data = {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} instance = OneToOneTarget.objects.get(pk=3) - serializer = OneToOneTargetSerializer(instance, data=data) + serializer = self.Serializer(instance, data=data) self.assertTrue(serializer.is_valid()) obj = serializer.save() self.assertEqual(serializer.data, data) @@ -98,28 +83,109 @@ class NestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3-updated', 'target_source': {'id': 3, 'name': 'target-source-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}}} + {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + {'id': 3, 'name': 'target-3-updated', 'source': {'id': 3, 'name': 'source-3-updated'}} ] self.assertEqual(serializer.data, expected) - def test_one_to_one_delete(self): - data = {'id': 3, 'name': 'target-3', 'target_source': None} - instance = OneToOneTarget.objects.get(pk=3) - serializer = OneToOneTargetSerializer(instance, data=data) + +class ForwardNestedOneToOneTests(TestCase): + def setUp(self): + class OneToOneTargetSerializer(serializers.ModelSerializer): + class Meta: + model = OneToOneTarget + fields = ('id', 'name') + + class OneToOneSourceSerializer(serializers.ModelSerializer): + target = OneToOneTargetSerializer() + + class Meta: + model = OneToOneSource + fields = ('id', 'name', 'target') + + self.Serializer = OneToOneSourceSerializer + + for idx in range(1, 4): + target = OneToOneTarget(name='target-%d' % idx) + target.save() + source = OneToOneSource(name='source-%d' % idx, target=target) + source.save() + + def test_one_to_one_retrieve(self): + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_one_create(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + serializer = self.Serializer(data=data) self.assertTrue(serializer.is_valid()) - serializer.save() + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-4') + + # Ensure (target 4, target_source 4, source 4) are added, and + # everything else is as expected. + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3', 'target': {'id': 3, 'name': 'target-3'}}, + {'id': 4, 'name': 'source-4', 'target': {'id': 4, 'name': 'target-4'}} + ] + self.assertEqual(serializer.data, expected) - # Ensure (target_source 3, source 3) are deleted, + def test_one_to_one_create_with_invalid_data(self): + data = {'id': 4, 'name': 'source-4', 'target': {'id': 4}} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'target': [{'name': ['This field is required.']}]}) + + def test_one_to_one_update(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + + # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. - queryset = OneToOneTarget.objects.all() - serializer = OneToOneTargetSerializer(queryset) + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset) expected = [ - {'id': 1, 'name': 'target-1', 'target_source': {'id': 1, 'name': 'target-source-1', 'source': {'id': 1, 'name': 'source-1'}}}, - {'id': 2, 'name': 'target-2', 'target_source': {'id': 2, 'name': 'target-source-2', 'source': {'id': 2, 'name': 'source-2'}}}, - {'id': 3, 'name': 'target-3', 'target_source': None} + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': {'id': 3, 'name': 'target-3-updated'}} ] self.assertEqual(serializer.data, expected) + + + # TODO: Nullable 1-1 tests + # def test_one_to_one_delete(self): + # data = {'id': 3, 'name': 'target-3', 'target_source': None} + # instance = OneToOneTarget.objects.get(pk=3) + # serializer = self.Serializer(instance, data=data) + # self.assertTrue(serializer.is_valid()) + # serializer.save() + + # # Ensure (target_source 3, source 3) are deleted, + # # and everything else is as expected. + # queryset = OneToOneTarget.objects.all() + # serializer = self.Serializer(queryset) + # expected = [ + # {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, + # {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, + # {'id': 3, 'name': 'target-3', 'source': None} + # ] + # self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 73efa96de983fc644328d2fc498651aa917a2272 Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Sat, 6 Apr 2013 08:43:21 -0700 Subject: one-many writable nested modelserializer support --- rest_framework/serializers.py | 56 ++++++++++----- rest_framework/tests/relations_nested.py | 98 ++++++++++++++++++++++++++ rest_framework/tests/serializer_bulk_update.py | 6 +- 3 files changed, 139 insertions(+), 21 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 668bcc49..73cad00f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -130,14 +130,14 @@ class BaseSerializer(WritableField): def __init__(self, instance=None, data=None, files=None, context=None, partial=False, many=None, - allow_delete=False, **kwargs): + allow_add_remove=False, **kwargs): super(BaseSerializer, self).__init__(**kwargs) self.opts = self._options_class(self.Meta) self.parent = None self.root = None self.partial = partial self.many = many - self.allow_delete = allow_delete + self.allow_add_remove = allow_add_remove self.context = context or {} @@ -154,8 +154,8 @@ class BaseSerializer(WritableField): if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') - if allow_delete and not many: - raise ValueError('allow_delete should only be used for bulk updates, but you have not set many=True') + if allow_add_remove and not many: + raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True') ##### # Methods to determine which fields to use when (de)serializing objects. @@ -288,8 +288,15 @@ class BaseSerializer(WritableField): You should override this method to control how deserialized objects are instantiated. """ + removed_relations = [] + + # Deleted related objects + if self._deleted: + removed_relations = list(self._deleted) + if instance is not None: instance.update(attrs) + instance._removed_relations = removed_relations return instance return attrs @@ -377,6 +384,7 @@ class BaseSerializer(WritableField): # Set the serializer object if it exists obj = getattr(self.parent.object, field_name) if self.parent.object else None + obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj if value in (None, ''): into[(self.source or field_name)] = None @@ -386,7 +394,8 @@ class BaseSerializer(WritableField): 'data': value, 'context': self.context, 'partial': self.partial, - 'many': self.many + 'many': self.many, + 'allow_add_remove': self.allow_add_remove } serializer = self.__class__(**kwargs) @@ -496,6 +505,9 @@ class BaseSerializer(WritableField): def save_object(self, obj, **kwargs): obj.save(**kwargs) + if self.allow_add_remove and hasattr(obj, '_removed_relations'): + [self.delete_object(item) for item in obj._removed_relations] + def delete_object(self, obj): obj.delete() @@ -508,7 +520,7 @@ class BaseSerializer(WritableField): else: self.save_object(self.object, **kwargs) - if self.allow_delete and self._deleted: + if self.allow_add_remove and self._deleted: [self.delete_object(item) for item in self._deleted] return self.object @@ -699,6 +711,7 @@ class ModelSerializer(Serializer): m2m_data = {} related_data = {} nested_forward_relations = {} + removed_relations = [] meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -724,6 +737,10 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] + # Deleted related objects + if self._deleted: + removed_relations = list(self._deleted) + # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -740,6 +757,7 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations + instance._removed_relations = removed_relations return instance @@ -764,6 +782,9 @@ class ModelSerializer(Serializer): obj.save(**kwargs) + if self.allow_add_remove and hasattr(obj, '_removed_relations'): + [self.delete_object(item) for item in obj._removed_relations] + if getattr(obj, '_m2m_data', None): for accessor_name, object_list in obj._m2m_data.items(): setattr(obj, accessor_name, object_list) @@ -773,18 +794,17 @@ class ModelSerializer(Serializer): for accessor_name, related in obj._related_data.items(): field = self.fields.get(accessor_name, None) if isinstance(field, Serializer): - # TODO: Following will be needed for reverse FK - # if field.many: - # # Nested reverse fk relationship - # for related_item in related: - # fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - # setattr(related_item, fk_field, obj) - # self.save_object(related_item) - # else: - # Nested reverse one-one relationship - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) + if field.many: + # Nested reverse fk relationship + for related_item in related: + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related_item, fk_field, obj) + self.save_object(related_item) + else: + # Nested reverse one-one relationship + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) else: # Reverse FK or reverse one-one setattr(obj, accessor_name, related) diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index e7af6565..20683d4a 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -13,6 +13,15 @@ class OneToOneSource(models.Model): target = models.OneToOneField(OneToOneTarget, related_name='source') +class OneToManyTarget(models.Model): + name = models.CharField(max_length=100) + + +class OneToManySource(models.Model): + name = models.CharField(max_length=100) + target = models.ForeignKey(OneToManyTarget, related_name='sources') + + class ReverseNestedOneToOneTests(TestCase): def setUp(self): class OneToOneSourceSerializer(serializers.ModelSerializer): @@ -189,3 +198,92 @@ class ForwardNestedOneToOneTests(TestCase): # {'id': 3, 'name': 'target-3', 'source': None} # ] # self.assertEqual(serializer.data, expected) + + +class ReverseNestedOneToManyTests(TestCase): + def setUp(self): + class OneToManySourceSerializer(serializers.ModelSerializer): + class Meta: + model = OneToManySource + fields = ('id', 'name') + + class OneToManyTargetSerializer(serializers.ModelSerializer): + sources = OneToManySourceSerializer(many=True, allow_add_remove=True) + + class Meta: + model = OneToManyTarget + fields = ('id', 'name', 'sources') + + self.Serializer = OneToManyTargetSerializer + + target = OneToManyTarget(name='target-1') + target.save() + for idx in range(1, 4): + source = OneToManySource(name='source-%d' % idx, target=target) + source.save() + + def test_one_to_many_retrieve(self): + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]}, + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1') + + # Ensure source 4 is added, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4, 'name': 'source-4'}]} + ] + self.assertEqual(serializer.data, expected) + + def test_one_to_many_create_with_invalid_data(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}, + {'id': 4}]} + serializer = self.Serializer(data=data) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'sources': [{}, {}, {}, {'name': ['This field is required.']}]}) + + def test_one_to_many_update(self): + data = {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'target-1-updated') + + # Ensure (target 1, source 1) are updated, + # and everything else is as expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, + {'id': 2, 'name': 'source-2'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) diff --git a/rest_framework/tests/serializer_bulk_update.py b/rest_framework/tests/serializer_bulk_update.py index afc1a1a9..5328e733 100644 --- a/rest_framework/tests/serializer_bulk_update.py +++ b/rest_framework/tests/serializer_bulk_update.py @@ -201,7 +201,7 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() @@ -223,7 +223,7 @@ class BulkUpdateSerializerTests(TestCase): 'author': 'Haruki Murakami' } ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.data, data) serializer.save() @@ -249,6 +249,6 @@ class BulkUpdateSerializerTests(TestCase): {}, {'id': ['Enter a whole number.']} ] - serializer = self.BookSerializer(self.books(), data=data, many=True, allow_delete=True) + serializer = self.BookSerializer(self.books(), data=data, many=True, allow_add_remove=True) self.assertEqual(serializer.is_valid(), False) self.assertEqual(serializer.errors, expected_errors) -- cgit v1.2.3 From fdc5cc3d81679d30cd20acf063dc7dc74ad17d7a Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Thu, 18 Apr 2013 10:28:20 -0700 Subject: Fix model serializer nestesd delete behavior --- rest_framework/serializers.py | 40 +++++++++++--------------------- rest_framework/tests/relations_nested.py | 19 +++++++++++++++ 2 files changed, 33 insertions(+), 26 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 0f0f11a4..78c45548 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -20,6 +20,9 @@ from rest_framework.relations import * from rest_framework.fields import * +class RelationsList(list): + _deleted = [] + class NestedValidationError(ValidationError): """ The default ValidationError behavior is to stringify each item in the list @@ -149,7 +152,6 @@ class BaseSerializer(WritableField): self._data = None self._files = None self._errors = None - self._deleted = None if many and instance is not None and not hasattr(instance, '__iter__'): raise ValueError('instance should be a queryset or other iterable with many=True') @@ -288,15 +290,8 @@ class BaseSerializer(WritableField): You should override this method to control how deserialized objects are instantiated. """ - removed_relations = [] - - # Deleted related objects - if self._deleted: - removed_relations = list(self._deleted) - if instance is not None: instance.update(attrs) - instance._removed_relations = removed_relations return instance return attrs @@ -438,7 +433,7 @@ class BaseSerializer(WritableField): PendingDeprecationWarning, stacklevel=3) if many: - ret = [] + ret = RelationsList() errors = [] update = self.object is not None @@ -466,7 +461,7 @@ class BaseSerializer(WritableField): errors.append(self._errors) if update: - self._deleted = identity_to_objects.values() + ret._deleted = identity_to_objects.values() self._errors = any(errors) and errors or [] else: @@ -509,9 +504,6 @@ class BaseSerializer(WritableField): def save_object(self, obj, **kwargs): obj.save(**kwargs) - if self.allow_add_remove and hasattr(obj, '_removed_relations'): - [self.delete_object(item) for item in obj._removed_relations] - def delete_object(self, obj): obj.delete() @@ -521,11 +513,11 @@ class BaseSerializer(WritableField): """ if isinstance(self.object, list): [self.save_object(item, **kwargs) for item in self.object] - else: - self.save_object(self.object, **kwargs) - if self.allow_add_remove and self._deleted: - [self.delete_object(item) for item in self._deleted] + if self.allow_add_remove and self.object._deleted: + [self.delete_object(item) for item in self.object._deleted] + else: + self.save_object(self.object, **kwargs) return self.object @@ -715,7 +707,6 @@ class ModelSerializer(Serializer): m2m_data = {} related_data = {} nested_forward_relations = {} - removed_relations = [] meta = self.opts.model._meta # Reverse fk or one-to-one relations @@ -741,10 +732,6 @@ class ModelSerializer(Serializer): if isinstance(self.fields.get(field_name, None), Serializer): nested_forward_relations[field_name] = attrs[field_name] - # Deleted related objects - if self._deleted: - removed_relations = list(self._deleted) - # Update an existing instance... if instance is not None: for key, val in attrs.items(): @@ -761,7 +748,6 @@ class ModelSerializer(Serializer): instance._related_data = related_data instance._m2m_data = m2m_data instance._nested_forward_relations = nested_forward_relations - instance._removed_relations = removed_relations return instance @@ -786,9 +772,6 @@ class ModelSerializer(Serializer): obj.save(**kwargs) - if self.allow_add_remove and hasattr(obj, '_removed_relations'): - [self.delete_object(item) for item in obj._removed_relations] - if getattr(obj, '_m2m_data', None): for accessor_name, object_list in obj._m2m_data.items(): setattr(obj, accessor_name, object_list) @@ -804,6 +787,11 @@ class ModelSerializer(Serializer): fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name setattr(related_item, fk_field, obj) self.save_object(related_item) + + # Delete any removed objects + if field.allow_add_remove and related._deleted: + [self.delete_object(item) for item in related._deleted] + else: # Nested reverse one-one relationship fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 20683d4a..22c98e7f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -287,3 +287,22 @@ class ReverseNestedOneToManyTests(TestCase): ] self.assertEqual(serializer.data, expected) + + def test_one_to_many_delete(self): + data = {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + instance = OneToManyTarget.objects.get(pk=1) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + serializer.save() + + # Ensure source 2 is deleted, and everything else is as + # expected. + queryset = OneToManyTarget.objects.all() + serializer = self.Serializer(queryset) + expected = [ + {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, + {'id': 3, 'name': 'source-3'}]} + + ] + self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 7e0a93f0eefead25f0e9b6615675f394af3a4ba0 Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Fri, 19 Apr 2013 10:46:57 -0700 Subject: Don't use field when saving related data --- rest_framework/serializers.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 78c45548..b39cb810 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -460,7 +460,7 @@ class BaseSerializer(WritableField): ret.append(self.from_native(item, None)) errors.append(self._errors) - if update: + if update and self.allow_add_remove: ret._deleted = identity_to_objects.values() self._errors = any(errors) and errors or [] @@ -514,7 +514,7 @@ class BaseSerializer(WritableField): if isinstance(self.object, list): [self.save_object(item, **kwargs) for item in self.object] - if self.allow_add_remove and self.object._deleted: + if self.object._deleted: [self.delete_object(item) for item in self.object._deleted] else: self.save_object(self.object, **kwargs) @@ -779,24 +779,22 @@ class ModelSerializer(Serializer): if getattr(obj, '_related_data', None): for accessor_name, related in obj._related_data.items(): - field = self.fields.get(accessor_name, None) - if isinstance(field, Serializer): - if field.many: - # Nested reverse fk relationship - for related_item in related: - fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related_item, fk_field, obj) - self.save_object(related_item) - - # Delete any removed objects - if field.allow_add_remove and related._deleted: - [self.delete_object(item) for item in related._deleted] - - else: - # Nested reverse one-one relationship + if isinstance(related, RelationsList): + # Nested reverse fk relationship + for related_item in related: fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name - setattr(related, fk_field, obj) - self.save_object(related) + setattr(related_item, fk_field, obj) + self.save_object(related_item) + + # Delete any removed objects + if related._deleted: + [self.delete_object(item) for item in related._deleted] + + elif isinstance(related, models.Model): + # Nested reverse one-one relationship + fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name + setattr(related, fk_field, obj) + self.save_object(related) else: # Reverse FK or reverse one-one setattr(obj, accessor_name, related) -- cgit v1.2.3 From 14482a966168a98d43099d00c163d1c8c3b6471b Mon Sep 17 00:00:00 2001 From: Mark Aaron Shirley Date: Wed, 8 May 2013 22:44:23 -0700 Subject: Fix deprecation warnings in relations_nested tests --- rest_framework/tests/relations_nested.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py index 22c98e7f..8325580f 100644 --- a/rest_framework/tests/relations_nested.py +++ b/rest_framework/tests/relations_nested.py @@ -46,7 +46,7 @@ class ReverseNestedOneToOneTests(TestCase): def test_one_to_one_retrieve(self): queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -65,7 +65,7 @@ class ReverseNestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -92,7 +92,7 @@ class ReverseNestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'source': {'id': 1, 'name': 'source-1'}}, {'id': 2, 'name': 'target-2', 'source': {'id': 2, 'name': 'source-2'}}, @@ -125,7 +125,7 @@ class ForwardNestedOneToOneTests(TestCase): def test_one_to_one_retrieve(self): queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -144,7 +144,7 @@ class ForwardNestedOneToOneTests(TestCase): # Ensure (target 4, target_source 4, source 4) are added, and # everything else is as expected. queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -171,7 +171,7 @@ class ForwardNestedOneToOneTests(TestCase): # Ensure (target 3, target_source 3, source 3) are updated, # and everything else is as expected. queryset = OneToOneSource.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, @@ -224,7 +224,7 @@ class ReverseNestedOneToManyTests(TestCase): def test_one_to_many_retrieve(self): queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 2, 'name': 'source-2'}, @@ -247,7 +247,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure source 4 is added, and everything else is as # expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 2, 'name': 'source-2'}, @@ -279,7 +279,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure (target 1, source 1) are updated, # and everything else is as expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1-updated', 'sources': [{'id': 1, 'name': 'source-1-updated'}, {'id': 2, 'name': 'source-2'}, @@ -299,7 +299,7 @@ class ReverseNestedOneToManyTests(TestCase): # Ensure source 2 is deleted, and everything else is as # expected. queryset = OneToManyTarget.objects.all() - serializer = self.Serializer(queryset) + serializer = self.Serializer(queryset, many=True) expected = [ {'id': 1, 'name': 'target-1', 'sources': [{'id': 1, 'name': 'source-1'}, {'id': 3, 'name': 'source-3'}]} -- cgit v1.2.3 From db9672d3048eebb3d3c3fb2b4a345e17b5aa23cc Mon Sep 17 00:00:00 2001 From: Alex Burgel Date: Wed, 24 Jul 2013 17:24:29 -0400 Subject: Add support for removing field files by sending an empty string --- rest_framework/fields.py | 5 ++++- rest_framework/tests/test_files.py | 28 ++++++++++++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f9931887..9ba5c0eb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -307,7 +307,10 @@ class WritableField(Field): try: if self.use_files: files = files or {} - native = files[field_name] + try: + native = files[field_name] + except KeyError: + native = data[field_name] else: native = data[field_name] except KeyError: diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 487046ac..495c2a7f 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -7,13 +7,13 @@ import datetime class UploadedFile(object): - def __init__(self, file, created=None): + def __init__(self, file=None, created=None): self.file = file self.created = created or datetime.datetime.now() class UploadedFileSerializer(serializers.Serializer): - file = serializers.FileField() + file = serializers.FileField(required=False) created = serializers.DateTimeField() def restore_object(self, attrs, instance=None): @@ -47,5 +47,25 @@ class FileSerializerTests(TestCase): now = datetime.datetime.now() serializer = UploadedFileSerializer(data={'created': now}) - self.assertFalse(serializer.is_valid()) - self.assertIn('file', serializer.errors) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, now) + self.assertIsNone(serializer.object.file) + + def test_remove_with_empty_string(self): + """ + Passing empty string as data should cause file to be removed + + Test for: + https://github.com/tomchristie/django-rest-framework/issues/937 + """ + now = datetime.datetime.now() + file = BytesIO(six.b('stuff')) + file.name = 'stuff.txt' + file.size = len(file.getvalue()) + + uploaded_file = UploadedFile(file=file, created=now) + + serializer = UploadedFileSerializer(instance=uploaded_file, data={'created': now, 'file': ''}) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.object.created, uploaded_file.created) + self.assertIsNone(serializer.object.file) -- cgit v1.2.3 From abe655e061871a568cccf473414e350f3eb61d8b Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 21:01:37 +0300 Subject: Make OneToOneSource.target nullable --- rest_framework/tests/test_relations_nested.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index 8325580f..30229687 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -10,7 +10,8 @@ class OneToOneTarget(models.Model): class OneToOneSource(models.Model): name = models.CharField(max_length=100) - target = models.OneToOneField(OneToOneTarget, related_name='source') + target = models.OneToOneField(OneToOneTarget, related_name='source', + null=True, blank=True) class OneToManyTarget(models.Model): @@ -21,7 +22,7 @@ class OneToManySource(models.Model): name = models.CharField(max_length=100) target = models.ForeignKey(OneToManyTarget, related_name='sources') - + class ReverseNestedOneToOneTests(TestCase): def setUp(self): class OneToOneSourceSerializer(serializers.ModelSerializer): -- cgit v1.2.3 From 901d2b0eb8270befa051510e190f3d5679086c7f Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 21:02:59 +0300 Subject: Failing test case for nullifying nested object --- rest_framework/tests/test_relations_nested.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_relations_nested.py b/rest_framework/tests/test_relations_nested.py index 30229687..d393b0c3 100644 --- a/rest_framework/tests/test_relations_nested.py +++ b/rest_framework/tests/test_relations_nested.py @@ -180,6 +180,25 @@ class ForwardNestedOneToOneTests(TestCase): ] self.assertEqual(serializer.data, expected) + def test_one_to_one_update_to_null(self): + data = {'id': 3, 'name': 'source-3-updated', 'target': None} + instance = OneToOneSource.objects.get(pk=3) + serializer = self.Serializer(instance, data=data) + self.assertTrue(serializer.is_valid()) + obj = serializer.save() + + self.assertEqual(serializer.data, data) + self.assertEqual(obj.name, 'source-3-updated') + self.assertEqual(obj.target, None) + + queryset = OneToOneSource.objects.all() + serializer = self.Serializer(queryset, many=True) + expected = [ + {'id': 1, 'name': 'source-1', 'target': {'id': 1, 'name': 'target-1'}}, + {'id': 2, 'name': 'source-2', 'target': {'id': 2, 'name': 'target-2'}}, + {'id': 3, 'name': 'source-3-updated', 'target': None} + ] + self.assertEqual(serializer.data, expected) # TODO: Nullable 1-1 tests # def test_one_to_one_delete(self): -- cgit v1.2.3 From ff1efcf60f0a9b66cdb736f8c0b2cfe2fc84cdf5 Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 18:08:23 +0300 Subject: If null or blank - don't save the nested object --- rest_framework/serializers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d8f9145e..2b260c25 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -522,7 +522,7 @@ class BaseSerializer(WritableField): if self.object._deleted: [self.delete_object(item) for item in self.object._deleted] else: - self.save_object(self.object, **kwargs) + self.save_object(self.object, **kwargs) return self.object @@ -891,7 +891,8 @@ class ModelSerializer(Serializer): # Nested relationships need to be saved before we can save the # parent instance. for field_name, sub_object in obj._nested_forward_relations.items(): - self.save_object(sub_object) + if sub_object: + self.save_object(sub_object) setattr(obj, field_name, sub_object) obj.save(**kwargs) -- cgit v1.2.3 From e677f3ee5c9435594ce58a3256a119c08bdc1e42 Mon Sep 17 00:00:00 2001 From: Krzysztof Jurewicz Date: Tue, 13 Aug 2013 13:26:30 +0200 Subject: PATCH requests should not be able to create objects. --- rest_framework/mixins.py | 13 ++++++++----- rest_framework/tests/test_generics.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index f11def6d..59d64469 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -142,11 +142,14 @@ class UpdateModelMixin(object): try: return self.get_object() except Http404: - # If this is a PUT-as-create operation, we need to ensure that - # we have relevant permissions, as if this was a POST request. - # This will either raise a PermissionDenied exception, - # or simply return None - self.check_permissions(clone_request(self.request, 'POST')) + if self.request.method == 'PUT': + # For PUT-as-create operation, we need to ensure that we have + # relevant permissions, as if this was a POST request. This + # will either raise a PermissionDenied exception, or simply + # return None. + self.check_permissions(clone_request(self.request, 'POST')) + else: + raise def pre_save(self, obj): """ diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 1550880b..7a87d389 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -338,6 +338,17 @@ class TestInstanceView(TestCase): new_obj = SlugBasedModel.objects.get(slug='test_slug') self.assertEqual(new_obj.text, 'foobar') + def test_patch_cannot_create_an_object(self): + """ + PATCH requests should not be able to create objects. + """ + data = {'text': 'foobar'} + request = factory.patch('/999', data, format='json') + with self.assertNumQueries(1): + response = self.view(request, pk=999).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertFalse(self.objects.filter(id=999).exists()) + class TestOverriddenGetObject(TestCase): """ -- cgit v1.2.3 From 19a774f97292444a48c5b7521e1b0c0ea48b6502 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 11:21:45 +0100 Subject: force_authenticate(None) also clears session info. Closes #1055. --- rest_framework/test.py | 2 ++ rest_framework/tests/test_testing.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index a18f5a29..234d10a4 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -134,6 +134,8 @@ class APIClient(APIRequestFactory, DjangoClient): """ self.handler._force_user = user self.handler._force_token = token + if user is None: + self.logout() # Also clear any possible session info if required def request(self, **kwargs): # Ensure that any credentials set get added to every request. diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 49d45fc2..48b8956b 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -17,8 +17,18 @@ def view(request): }) +@api_view(['GET', 'POST']) +def session_view(request): + active_session = request.session.get('active_session', False) + request.session['active_session'] = True + return Response({ + 'active_session': active_session + }) + + urlpatterns = patterns('', url(r'^view/$', view), + url(r'^session-view/$', session_view), ) @@ -46,6 +56,26 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') + def test_force_authenticate_with_sessions(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + + # First request does not yet have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + + # Subsequant requests have an active session + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], True) + + # Force authenticating as `None` should also logout the user session. + self.client.force_authenticate(None) + response = self.client.get('/session-view/') + self.assertEqual(response.data['active_session'], False) + def test_csrf_exempt_by_default(self): """ By default, the test client is CSRF exempt. -- cgit v1.2.3 From 95b2bf50fbb9b95facebb23812bbbb2e27a76035 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 12:03:54 +0100 Subject: Add validation error test when passing non-file to FileField --- rest_framework/tests/test_files.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_files.py b/rest_framework/tests/test_files.py index 495c2a7f..c13c38b8 100644 --- a/rest_framework/tests/test_files.py +++ b/rest_framework/tests/test_files.py @@ -69,3 +69,14 @@ class FileSerializerTests(TestCase): self.assertTrue(serializer.is_valid()) self.assertEqual(serializer.object.created, uploaded_file.created) self.assertIsNone(serializer.object.file) + + def test_validation_error_with_non_file(self): + """ + Passing non-files should raise a validation error. + """ + now = datetime.datetime.now() + errmsg = 'No file was submitted. Check the encoding type on the form.' + + serializer = UploadedFileSerializer(data={'created': now, 'file': 'abc'}) + self.assertFalse(serializer.is_valid()) + self.assertEqual(serializer.errors, {'file': [errmsg]}) -- cgit v1.2.3 From e7927e9bca5bc0d0ac3b528e68244c713c5df97f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 13:35:50 +0100 Subject: Extra docs on PATCH with no object. --- rest_framework/mixins.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 59d64469..426865ff 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -149,6 +149,8 @@ class UpdateModelMixin(object): # return None. self.check_permissions(clone_request(self.request, 'POST')) else: + # PATCH requests where the object does not exist should still + # return a 404 response. raise def pre_save(self, obj): -- cgit v1.2.3 From e03854ba6a74428675c40d469a7768cc5131035f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 14:06:14 +0100 Subject: Tweaks to display nested data in empty serializers --- rest_framework/relations.py | 9 +++++++-- rest_framework/serializers.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index edaf76d6..7408758e 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -134,9 +134,9 @@ class RelatedField(WritableField): value = obj for component in source.split('.'): - value = get_component(value, component) if value is None: break + value = get_component(value, component) except ObjectDoesNotExist: return None @@ -567,8 +567,13 @@ class HyperlinkedIdentityField(Field): May raise a `NoReverseMatch` if the `view_name` and `lookup_field` attributes are not configured to correctly match the URL conf. """ - lookup_field = getattr(obj, self.lookup_field) + lookup_field = getattr(obj, self.lookup_field, None) kwargs = {self.lookup_field: lookup_field} + + # Handle unsaved object case + if lookup_field is None: + return None + try: return reverse(view_name, kwargs=kwargs, request=request, format=format) except NoReverseMatch: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 2b260c25..22525964 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -338,9 +338,9 @@ class BaseSerializer(WritableField): value = obj for component in source.split('.'): - value = get_component(value, component) if value is None: - break + return self.to_native(None) + value = get_component(value, component) except ObjectDoesNotExist: return None -- cgit v1.2.3 From 0966a2680ba02e6a4586bd2777ed593fcc66a453 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 14:38:31 +0100 Subject: First pass at HTMLFormRenderer --- rest_framework/renderers.py | 97 +++++++++++++---------- rest_framework/templates/rest_framework/base.html | 10 +-- 2 files changed, 58 insertions(+), 49 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1006e26c..a73b2d73 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -316,6 +316,59 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): return data +class HTMLFormRenderer(BaseRenderer): + template = 'rest_framework/form.html' + + def serializer_to_form_fields(self, serializer): + fields = {} + for k, v in serializer.get_fields().items(): + if getattr(v, 'read_only', True): + continue + + kwargs = {} + kwargs['required'] = v.required + + #if getattr(v, 'queryset', None): + # kwargs['queryset'] = v.queryset + + if getattr(v, 'choices', None) is not None: + kwargs['choices'] = v.choices + + if getattr(v, 'regex', None) is not None: + kwargs['regex'] = v.regex + + if getattr(v, 'widget', None): + widget = copy.deepcopy(v.widget) + kwargs['widget'] = widget + + if getattr(v, 'default', None) is not None: + kwargs['initial'] = v.default + + if getattr(v, 'label', None) is not None: + kwargs['label'] = v.label + + if getattr(v, 'help_text', None) is not None: + kwargs['help_text'] = v.help_text + + fields[k] = v.form_field_class(**kwargs) + + return fields + + def render(self, serializer, obj, request): + fields = self.serializer_to_form_fields(serializer) + + # Creating an on the fly form see: + # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python + OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) + data = (obj is not None) and serializer.data or None + form_instance = OnTheFlyForm(data) + + template = loader.get_template(self.template) + context = RequestContext(request, {'form': form_instance}) + + return template.render(context) + + class BrowsableAPIRenderer(BaseRenderer): """ HTML renderer used to self-document the API. @@ -371,41 +424,6 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def serializer_to_form_fields(self, serializer): - fields = {} - for k, v in serializer.get_fields().items(): - if getattr(v, 'read_only', True): - continue - - kwargs = {} - kwargs['required'] = v.required - - #if getattr(v, 'queryset', None): - # kwargs['queryset'] = v.queryset - - if getattr(v, 'choices', None) is not None: - kwargs['choices'] = v.choices - - if getattr(v, 'regex', None) is not None: - kwargs['regex'] = v.regex - - if getattr(v, 'widget', None): - widget = copy.deepcopy(v.widget) - kwargs['widget'] = widget - - if getattr(v, 'default', None) is not None: - kwargs['initial'] = v.default - - if getattr(v, 'label', None) is not None: - kwargs['label'] = v.label - - if getattr(v, 'help_text', None) is not None: - kwargs['help_text'] = v.help_text - - fields[k] = v.form_field_class(**kwargs) - - return fields - def _get_form(self, view, method, request): # We need to impersonate a request with the correct method, # so that eg. any dynamic get_serializer_class methods return the @@ -447,14 +465,7 @@ class BrowsableAPIRenderer(BaseRenderer): return serializer = view.get_serializer(instance=obj) - fields = self.serializer_to_form_fields(serializer) - - # Creating an on the fly form see: - # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python - OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) - data = (obj is not None) and serializer.data or None - form_instance = OnTheFlyForm(data) - return form_instance + return HTMLFormRenderer().render(serializer, obj, request) def get_raw_data_form(self, view, method, request, media_types): """ diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 51f9c291..6ae47563 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -136,9 +136,9 @@ {% if post_form %}
{% with form=post_form %} -
+
- {% include "rest_framework/form.html" %} + {{ post_form }}
@@ -174,16 +174,14 @@
{% if put_form %}
- {% with form=put_form %} - +
- {% include "rest_framework/form.html" %} + {{ put_form }}
- {% endwith %}
{% endif %}
-- cgit v1.2.3 From 005f475c6af023cc7c75cf38d3a89e22638e5d84 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 14:58:06 +0100 Subject: Don't consume .json style suffixes with routers. When trailing slash is false, the lookup regex should not consume '.' characters. Fixes #1057. --- rest_framework/routers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 930011d3..3fee1e49 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -189,7 +189,11 @@ class SimpleRouter(BaseRouter): Given a viewset, return the portion of URL regex that is used to match against a single instance. """ - base_regex = '(?P<{lookup_field}>[^/]+)' + if self.trailing_slash: + base_regex = '(?P<{lookup_field}>[^/]+)' + else: + # Don't consume `.json` style suffixes + base_regex = '(?P<{lookup_field}>[^/.]+)' lookup_field = getattr(viewset, 'lookup_field', 'pk') return base_regex.format(lookup_field=lookup_field) -- cgit v1.2.3 From 1c935cd3d271efd06f1621c9dddb9e1cd0333e20 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 15:18:47 +0100 Subject: Fix failing test for router with no trailing slash --- rest_framework/tests/test_routers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index 5fcccb74..e723f7d4 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -146,7 +146,7 @@ class TestTrailingSlashRemoved(TestCase): self.urls = self.router.urls def test_urls_can_have_trailing_slash_removed(self): - expected = ['^notes$', '^notes/(?P[^/]+)$'] + expected = ['^notes$', '^notes/(?P[^/.]+)$'] for idx in range(len(expected)): self.assertEqual(expected[idx], self.urls[idx].regex.pattern) -- cgit v1.2.3 From 10d386ec6a4822402b5ffea46bdd9e7d72db519b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 16:10:20 +0100 Subject: Cleanup and dealing with empty form data. --- rest_framework/relations.py | 2 + rest_framework/renderers.py | 103 ++++++++++++++++++++++-------------------- rest_framework/serializers.py | 3 +- 3 files changed, 58 insertions(+), 50 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 7408758e..3ad16ee5 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -244,6 +244,8 @@ class PrimaryKeyRelatedField(RelatedField): source = self.source or field_name queryset = obj for component in source.split('.'): + if queryset is None: + return [] queryset = get_component(queryset, component) # Forward relationship diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a73b2d73..a8670546 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -319,52 +319,53 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): class HTMLFormRenderer(BaseRenderer): template = 'rest_framework/form.html' - def serializer_to_form_fields(self, serializer): + def data_to_form_fields(self, data): fields = {} - for k, v in serializer.get_fields().items(): - if getattr(v, 'read_only', True): + for key, val in data.fields.items(): + if getattr(val, 'read_only', True): continue kwargs = {} - kwargs['required'] = v.required + kwargs['required'] = val.required #if getattr(v, 'queryset', None): # kwargs['queryset'] = v.queryset - if getattr(v, 'choices', None) is not None: - kwargs['choices'] = v.choices + if getattr(val, 'choices', None) is not None: + kwargs['choices'] = val.choices - if getattr(v, 'regex', None) is not None: - kwargs['regex'] = v.regex + if getattr(val, 'regex', None) is not None: + kwargs['regex'] = val.regex - if getattr(v, 'widget', None): - widget = copy.deepcopy(v.widget) + if getattr(val, 'widget', None): + widget = copy.deepcopy(val.widget) kwargs['widget'] = widget - if getattr(v, 'default', None) is not None: - kwargs['initial'] = v.default + if getattr(val, 'default', None) is not None: + kwargs['initial'] = val.default - if getattr(v, 'label', None) is not None: - kwargs['label'] = v.label + if getattr(val, 'label', None) is not None: + kwargs['label'] = val.label - if getattr(v, 'help_text', None) is not None: - kwargs['help_text'] = v.help_text + if getattr(val, 'help_text', None) is not None: + kwargs['help_text'] = val.help_text - fields[k] = v.form_field_class(**kwargs) + fields[key] = val.form_field_class(**kwargs) return fields - def render(self, serializer, obj, request): - fields = self.serializer_to_form_fields(serializer) + def render(self, data, accepted_media_type=None, renderer_context=None): + self.renderer_context = renderer_context or {} + request = renderer_context['request'] # Creating an on the fly form see: # http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python - OnTheFlyForm = type(str("OnTheFlyForm"), (forms.Form,), fields) - data = (obj is not None) and serializer.data or None - form_instance = OnTheFlyForm(data) + fields = self.data_to_form_fields(data) + DynamicForm = type(str('DynamicForm'), (forms.Form,), fields) + data = None if data.empty else data template = loader.get_template(self.template) - context = RequestContext(request, {'form': form_instance}) + context = RequestContext(request, {'form': DynamicForm(data)}) return template.render(context) @@ -377,6 +378,7 @@ class BrowsableAPIRenderer(BaseRenderer): format = 'api' template = 'rest_framework/api.html' charset = 'utf-8' + form_renderer_class = HTMLFormRenderer def get_default_renderer(self, view): """ @@ -424,19 +426,7 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def _get_form(self, view, method, request): - # We need to impersonate a request with the correct method, - # so that eg. any dynamic get_serializer_class methods return the - # correct form for each method. - restore = view.request - request = clone_request(request, method) - view.request = request - try: - return self.get_form(view, method, request) - finally: - view.request = restore - - def _get_raw_data_form(self, view, method, request, media_types): + def _get_rendered_html_form(self, view, method, request): # We need to impersonate a request with the correct method, # so that eg. any dynamic get_serializer_class methods return the # correct form for each method. @@ -444,15 +434,16 @@ class BrowsableAPIRenderer(BaseRenderer): request = clone_request(request, method) view.request = request try: - return self.get_raw_data_form(view, method, request, media_types) + return self.get_rendered_html_form(view, method, request) finally: view.request = restore - def get_form(self, view, method, request): + def get_rendered_html_form(self, view, method, request): """ - Get a form, possibly bound to either the input or output data. - In the absence on of the Resource having an associated form then - provide a form that can be used to submit arbitrary content. + Return a string representing a rendered HTML form, possibly bound to + either the input or output data. + + In the absence of the View having an associated form then return None. """ obj = getattr(view, 'object', None) if not self.show_form_for_method(view, method, request, obj): @@ -465,7 +456,21 @@ class BrowsableAPIRenderer(BaseRenderer): return serializer = view.get_serializer(instance=obj) - return HTMLFormRenderer().render(serializer, obj, request) + data = serializer.data + form_renderer = self.form_renderer_class() + return form_renderer.render(data, self.accepted_media_type, self.renderer_context) + + def _get_raw_data_form(self, view, method, request, media_types): + # We need to impersonate a request with the correct method, + # so that eg. any dynamic get_serializer_class methods return the + # correct form for each method. + restore = view.request + request = clone_request(request, method) + view.request = request + try: + return self.get_raw_data_form(view, method, request, media_types) + finally: + view.request = restore def get_raw_data_form(self, view, method, request, media_types): """ @@ -520,8 +525,8 @@ class BrowsableAPIRenderer(BaseRenderer): """ Render the HTML for the browsable API representation. """ - accepted_media_type = accepted_media_type or '' - renderer_context = renderer_context or {} + self.accepted_media_type = accepted_media_type or '' + self.renderer_context = renderer_context or {} view = renderer_context['view'] request = renderer_context['request'] @@ -531,11 +536,11 @@ class BrowsableAPIRenderer(BaseRenderer): renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self._get_form(view, 'PUT', request) - post_form = self._get_form(view, 'POST', request) - patch_form = self._get_form(view, 'PATCH', request) - delete_form = self._get_form(view, 'DELETE', request) - options_form = self._get_form(view, 'OPTIONS', request) + put_form = self._get_rendered_html_form(view, 'PUT', request) + post_form = self._get_rendered_html_form(view, 'POST', request) + patch_form = self._get_rendered_html_form(view, 'PATCH', request) + delete_form = self._get_rendered_html_form(view, 'DELETE', request) + options_form = self._get_rendered_html_form(view, 'OPTIONS', request) raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index fde06d83..97e0a005 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -300,7 +300,8 @@ class BaseSerializer(WritableField): Serialize objects -> primitives. """ ret = self._dict_class() - ret.fields = {} + ret.fields = self._dict_class() + ret.empty = obj is None for field_name, field in self.fields.items(): field.initialize(parent=self, field_name=field_name) -- cgit v1.2.3 From e23d5888522f98c30418452c0f833cf11589e0c1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 16:16:41 +0100 Subject: Adding standard renderer attributes and documenting --- rest_framework/renderers.py | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index a8670546..9885c8dd 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -317,7 +317,18 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): class HTMLFormRenderer(BaseRenderer): + """ + Renderers serializer data into an HTML form. + + If the serializer was instantiated without an object then this will + return an HTML form not bound to any object, + otherwise it will return an HTML form with the appropriate initial data + populated from the object. + """ + media_type = 'text/html' + format = 'form' template = 'rest_framework/form.html' + charset = 'utf-8' def data_to_form_fields(self, data): fields = {} -- cgit v1.2.3 From 436e66a42db21b52fd5e1582011d2f0f7f81f9c7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 16:45:55 +0100 Subject: JSON responses should not include a charset --- rest_framework/renderers.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1006e26c..c87014e2 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -36,6 +36,7 @@ class BaseRenderer(object): media_type = None format = None charset = 'utf-8' + render_style = 'text' def render(self, data, accepted_media_type=None, renderer_context=None): raise NotImplemented('Renderer class requires .render() to be implemented') @@ -51,16 +52,17 @@ class JSONRenderer(BaseRenderer): format = 'json' encoder_class = encoders.JSONEncoder ensure_ascii = True - charset = 'utf-8' - # Note that JSON encodings must be utf-8, utf-16 or utf-32. + charset = None + # JSON is a binary encoding, that can be encoded as utf-8, utf-16 or utf-32. # See: http://www.ietf.org/rfc/rfc4627.txt + # Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/ def render(self, data, accepted_media_type=None, renderer_context=None): """ Render `data` into JSON. """ if data is None: - return '' + return bytes() # If 'indent' is provided in the context, then pretty print the result. # E.g. If we're being called by the BrowsableAPIRenderer. @@ -85,13 +87,12 @@ class JSONRenderer(BaseRenderer): # and may (or may not) be unicode. # On python 3.x json.dumps() returns unicode strings. if isinstance(ret, six.text_type): - return bytes(ret.encode(self.charset)) + return bytes(ret.encode('utf-8')) return ret class UnicodeJSONRenderer(JSONRenderer): ensure_ascii = False - charset = 'utf-8' """ Renderer which serializes to JSON. Does *not* apply JSON's character escaping for non-ascii characters. @@ -108,6 +109,7 @@ class JSONPRenderer(JSONRenderer): format = 'jsonp' callback_parameter = 'callback' default_callback = 'callback' + charset = 'utf-8' def get_callback(self, renderer_context): """ @@ -348,7 +350,10 @@ class BrowsableAPIRenderer(BaseRenderer): renderer_context['indent'] = 4 content = renderer.render(data, accepted_media_type, renderer_context) - if renderer.charset is None: + render_style = getattr(renderer, 'render_style', 'text') + assert render_style in ['text', 'binary'], 'Expected .render_style ' \ + '"text" or "binary", but got "%s"' % render_style + if render_style == 'binary': return '[%d bytes of binary content]' % len(content) return content -- cgit v1.2.3 From be0f5850c398b7f7397d66eaed26d6b78163b259 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Aug 2013 16:51:34 +0100 Subject: Extra docs --- rest_framework/renderers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index c07b1652..b30f2ea9 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -326,6 +326,8 @@ class HTMLFormRenderer(BaseRenderer): return an HTML form not bound to any object, otherwise it will return an HTML form with the appropriate initial data populated from the object. + + Note that rendering of field and form errors is not currently supported. """ media_type = 'text/html' format = 'form' @@ -368,6 +370,18 @@ class HTMLFormRenderer(BaseRenderer): return fields def render(self, data, accepted_media_type=None, renderer_context=None): + """ + Render serializer data and return an HTML form, as a string. + """ + # The HTMLFormRenderer currently uses something of a hack to render + # the content, by translating each of the serializer fields into + # an html form field, creating a dynamic form using those fields, + # and then rendering that form. + + # This isn't strictly neccessary, as we could render the serilizer + # fields to HTML directly. The implementation is historical and will + # likely change at some point. + self.renderer_context = renderer_context or {} request = renderer_context['request'] -- cgit v1.2.3 From 9d3fae27fd9c3236dfd9c26ae9b830deb6fa4e9b Mon Sep 17 00:00:00 2001 From: Eric Buehl Date: Fri, 23 Aug 2013 16:48:32 +0000 Subject: parameterize identity field class to allow for easier subclassing --- rest_framework/serializers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 31cfa344..abb96941 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -903,6 +903,7 @@ class HyperlinkedModelSerializer(ModelSerializer): _options_class = HyperlinkedModelSerializerOptions _default_view_name = '%(model_name)s-detail' _hyperlink_field_class = HyperlinkedRelatedField + _hyperlink_identify_field_class = HyperlinkedIdentityField def get_default_fields(self): fields = super(HyperlinkedModelSerializer, self).get_default_fields() @@ -911,7 +912,7 @@ class HyperlinkedModelSerializer(ModelSerializer): self.opts.view_name = self._get_default_view_name(self.opts.model) if 'url' not in fields: - url_field = HyperlinkedIdentityField( + url_field = self._hyperlink_identify_field_class( view_name=self.opts.view_name, lookup_field=self.opts.lookup_field ) -- cgit v1.2.3 From 316de3a8a314162e3d6ec081344eabca3a4d91b9 Mon Sep 17 00:00:00 2001 From: Alexander Akhmetov Date: Mon, 26 Aug 2013 20:05:36 +0400 Subject: Added max_paginate_by parameter --- rest_framework/generics.py | 10 +++++-- rest_framework/settings.py | 1 + rest_framework/tests/test_pagination.py | 46 +++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 5ecf6310..33affee8 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -56,6 +56,7 @@ class GenericAPIView(views.APIView): # Pagination settings paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM + max_paginate_by = api_settings.MAX_PAGINATE_BY pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS page_kwarg = 'page' @@ -207,11 +208,16 @@ class GenericAPIView(views.APIView): if self.paginate_by_param: query_params = self.request.QUERY_PARAMS try: - return int(query_params[self.paginate_by_param]) + paginate_by_param = int(query_params[self.paginate_by_param]) except (KeyError, ValueError): pass + else: + if self.max_paginate_by: + return min(self.max_paginate_by, paginate_by_param) + else: + return paginate_by_param - return self.paginate_by + return min(self.max_paginate_by, self.paginate_by) or self.paginate_by def get_serializer_class(self): """ diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 7d25e513..b8e40bfa 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -68,6 +68,7 @@ DEFAULTS = { # Pagination 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, + 'MAX_PAGINATE_BY': None, # View configuration 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index 85d4640e..cbed1604 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -42,6 +42,16 @@ class PaginateByParamView(generics.ListAPIView): paginate_by_param = 'page_size' +class MaxPaginateByView(generics.ListAPIView): + """ + View for testing custom max_paginate_by usage + """ + model = BasicModel + paginate_by = 5 + max_paginate_by = 3 + paginate_by_param = 'page_size' + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -313,6 +323,42 @@ class TestCustomPaginateByParam(TestCase): self.assertEqual(response.data['results'], self.data[:5]) +class TestMaxPaginateByParam(TestCase): + """ + Tests for list views with max_paginate_by kwarg + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = MaxPaginateByView.as_view() + + def test_max_paginate_by(self): + """ + If max_paginate_by is set and it less than paginate_by, new kwarg should limit requests for review. + """ + request = factory.get('/?page_size=10') + response = self.view(request).render() + self.assertEqual(response.data['count'], 13) + self.assertEqual(response.data['results'], self.data[:3]) + + def test_max_paginate_by_without_page_size_param(self): + """ + If max_paginate_by is set, new kwarg should limit requests for review. + """ + request = factory.get('/') + response = self.view(request).render() + self.assertEqual(response.data['results'], self.data[:3]) + + ### Tests for context in pagination serializers class CustomField(serializers.Field): -- cgit v1.2.3 From 8d590ebfded0968e458f8e3a87efabec8384586e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 27 Aug 2013 11:22:19 +0100 Subject: First hacky pass at displaying raw data --- rest_framework/parsers.py | 9 +++++++-- rest_framework/renderers.py | 27 +++++++++++++++++++++++++-- rest_framework/serializers.py | 2 ++ rest_framework/tests/test_serializer.py | 2 +- 4 files changed, 35 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 96bfac84..c635505a 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -10,9 +10,9 @@ from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter -from rest_framework.compat import yaml, etree +from rest_framework.compat import etree, six, yaml from rest_framework.exceptions import ParseError -from rest_framework.compat import six +from rest_framework.renderers import UnicodeJSONRenderer import json import datetime import decimal @@ -32,6 +32,8 @@ class BaseParser(object): media_type = None + supports_html_forms = False + def parse(self, stream, media_type=None, parser_context=None): """ Given a stream to read from, return the parsed representation. @@ -47,6 +49,7 @@ class JSONParser(BaseParser): """ media_type = 'application/json' + renderer_class = UnicodeJSONRenderer def parse(self, stream, media_type=None, parser_context=None): """ @@ -91,6 +94,7 @@ class FormParser(BaseParser): """ media_type = 'application/x-www-form-urlencoded' + supports_html_forms = True def parse(self, stream, media_type=None, parser_context=None): """ @@ -109,6 +113,7 @@ class MultiPartParser(BaseParser): """ media_type = 'multipart/form-data' + supports_html_forms = True def parse(self, stream, media_type=None, parser_context=None): """ diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index b30f2ea9..cc8de959 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,7 +24,7 @@ from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework import exceptions, parsers, status, VERSION +from rest_framework import exceptions, status, VERSION class BaseRenderer(object): @@ -482,7 +482,7 @@ class BrowsableAPIRenderer(BaseRenderer): if method in ('DELETE', 'OPTIONS'): return True # Don't actually need to return a form - if not getattr(view, 'get_serializer', None) or not parsers.FormParser in view.parser_classes: + if not getattr(view, 'get_serializer', None) or not any(parser.supports_html_forms for parser in view.parser_classes): return serializer = view.get_serializer(instance=obj) @@ -561,6 +561,29 @@ class BrowsableAPIRenderer(BaseRenderer): view = renderer_context['view'] request = renderer_context['request'] response = renderer_context['response'] + + obj = getattr(view, 'object', None) + if getattr(view, 'get_serializer', None): + serializer = view.get_serializer(instance=obj) + else: + serializer = None + + parsers = [] + for parser_class in view.parser_classes: + content = None + renderer_class = getattr(parser_class, 'renderer_class', None) + if renderer_class and serializer: + renderer = renderer_class() + context = renderer_context.copy() + context['indent'] = 4 + content = renderer.render(serializer.data, accepted_media_type, context) + print content + parsers.append({ + 'media_type': parser_class.media_type, + 'content': content + }) + + media_types = [parser.media_type for parser in view.parser_classes] renderer = self.get_default_renderer(view) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index abff6898..202d3a09 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -304,6 +304,8 @@ class BaseSerializer(WritableField): ret.empty = obj is None for field_name, field in self.fields.items(): + if obj is None and field.read_only: + continue field.initialize(parent=self, field_name=field_name) key = self.get_field_key(field_name) value = field.field_to_native(obj, field_name) diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index c2497660..7c2a276e 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -158,7 +158,7 @@ class BasicTests(TestCase): 'email': '', 'content': '', 'created': None, - 'sub_comment': '' + #'sub_comment': '' } self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From dce47a11d3d65a697ea8aa322455d626190bc1e5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 27 Aug 2013 12:32:13 +0100 Subject: Move settings into more sensible ordering --- rest_framework/settings.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 7d25e513..2ee15ac7 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -48,7 +48,6 @@ DEFAULTS = { ), 'DEFAULT_THROTTLE_CLASSES': ( ), - 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', @@ -69,14 +68,14 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, - # View configuration - 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', - 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', - # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # View configuration + 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', + 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', + # Testing 'TEST_REQUEST_RENDERER_CLASSES': ( 'rest_framework.renderers.MultiPartRenderer', -- cgit v1.2.3 From b430503fa657330b606a9c632ea0decc4254163e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 27 Aug 2013 12:32:33 +0100 Subject: Move exception handler out of main view --- rest_framework/views.py | 79 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 22 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/views.py b/rest_framework/views.py index 727a9f95..7cb71ccf 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -15,8 +15,14 @@ from rest_framework.settings import api_settings from rest_framework.utils import formatting -def get_view_name(cls, suffix=None): - name = cls.__name__ +def get_view_name(view_cls, suffix=None): + """ + Given a view class, return a textual name to represent the view. + This name is used in the browsable API, and in OPTIONS responses. + + This function is the default for the `VIEW_NAME_FUNCTION` setting. + """ + name = view_cls.__name__ name = formatting.remove_trailing_string(name, 'View') name = formatting.remove_trailing_string(name, 'ViewSet') name = formatting.camelcase_to_spaces(name) @@ -25,14 +31,53 @@ def get_view_name(cls, suffix=None): return name -def get_view_description(cls, html=False): - description = cls.__doc__ or '' +def get_view_description(view_cls, html=False): + """ + Given a view class, return a textual description to represent the view. + This name is used in the browsable API, and in OPTIONS responses. + + This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. + """ + description = view_cls.__doc__ or '' description = formatting.dedent(smart_text(description)) if html: return formatting.markup_description(description) return description +def exception_handler(exc): + """ + Returns the response that should be used for any given exception. + + By default we handle the REST framework `APIException`, and also + Django's builtin `Http404` and `PermissionDenied` exceptions. + + Any unhandled exceptions may return `None`, which will cause a 500 error + to be raised. + """ + if isinstance(exc, exceptions.APIException): + headers = {} + if getattr(exc, 'auth_header', None): + headers['WWW-Authenticate'] = exc.auth_header + if getattr(exc, 'wait', None): + headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + + return Response({'detail': exc.detail}, + status=exc.status_code, + headers=headers) + + elif isinstance(exc, Http404): + return Response({'detail': 'Not found'}, + status=status.HTTP_404_NOT_FOUND) + + elif isinstance(exc, PermissionDenied): + return Response({'detail': 'Permission denied'}, + status=status.HTTP_403_FORBIDDEN) + + # Note: Unhandled exceptions will raise a 500 error. + return None + + class APIView(View): settings = api_settings @@ -303,33 +348,23 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, exceptions.Throttled) and exc.wait is not None: - # Throttle wait header - self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait - if isinstance(exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)): # WWW-Authenticate header for 401 responses, else coerce to 403 auth_header = self.get_authenticate_header(self.request) if auth_header: - self.headers['WWW-Authenticate'] = auth_header + exc.auth_header = auth_header else: exc.status_code = status.HTTP_403_FORBIDDEN - if isinstance(exc, exceptions.APIException): - return Response({'detail': exc.detail}, - status=exc.status_code, - exception=True) - elif isinstance(exc, Http404): - return Response({'detail': 'Not found'}, - status=status.HTTP_404_NOT_FOUND, - exception=True) - elif isinstance(exc, PermissionDenied): - return Response({'detail': 'Permission denied'}, - status=status.HTTP_403_FORBIDDEN, - exception=True) - raise + response = exception_handler(exc) + + if response is None: + raise + + response.exception = True + return response # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. -- cgit v1.2.3 From b54cbd292c5680f4de0e028ff1cb2a9ab1cd34ff Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 27 Aug 2013 12:36:06 +0100 Subject: Use view.settings for API settings, to make testing easier. --- rest_framework/views.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/views.py b/rest_framework/views.py index 7cb71ccf..4cff0422 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -79,8 +79,8 @@ def exception_handler(exc): class APIView(View): - settings = api_settings + # The following policies may be set at either globally, or per-view. renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES parser_classes = api_settings.DEFAULT_PARSER_CLASSES authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES @@ -88,6 +88,9 @@ class APIView(View): permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + # Allow dependancy injection of other settings to make testing easier. + settings = api_settings + @classmethod def as_view(cls, **initkwargs): """ @@ -178,7 +181,7 @@ class APIView(View): Return the view name, as used in OPTIONS responses and in the browsable API. """ - func = api_settings.VIEW_NAME_FUNCTION + func = self.settings.VIEW_NAME_FUNCTION return func(self.__class__, getattr(self, 'suffix', None)) def get_view_description(self, html=False): @@ -186,7 +189,7 @@ class APIView(View): Return some descriptive text for the view, as used in OPTIONS responses and in the browsable API. """ - func = api_settings.VIEW_DESCRIPTION_FUNCTION + func = self.settings.VIEW_DESCRIPTION_FUNCTION return func(self.__class__, html) # API policy instantiation methods -- cgit v1.2.3 From 7fb3f078f0973acc1d108d8c617b26b6845599f7 Mon Sep 17 00:00:00 2001 From: Alexander Akhmetov Date: Tue, 27 Aug 2013 17:38:41 +0400 Subject: fix for python3 --- rest_framework/generics.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 33affee8..ce6c462a 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -212,12 +212,15 @@ class GenericAPIView(views.APIView): except (KeyError, ValueError): pass else: - if self.max_paginate_by: + if self.max_paginate_by is not None: return min(self.max_paginate_by, paginate_by_param) else: return paginate_by_param - return min(self.max_paginate_by, self.paginate_by) or self.paginate_by + if self.max_paginate_by: + return min(self.max_paginate_by, self.paginate_by) + else: + return self.paginate_by def get_serializer_class(self): """ -- cgit v1.2.3 From 4c53fb883fe719c3ca6244aeb8c405a24eb89a40 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 28 Aug 2013 12:52:38 +0100 Subject: Tweak MAX_PAGINATE_BY behavior in edge case. Always respect `paginate_by` settings if client does not specify page size. (Even if the developer has misconfigured, so that `paginate_by > max`.) --- rest_framework/generics.py | 20 ++++++++------------ rest_framework/tests/test_pagination.py | 11 ++++++----- 2 files changed, 14 insertions(+), 17 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/generics.py b/rest_framework/generics.py index ce6c462a..14feed20 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -14,13 +14,15 @@ from rest_framework.settings import api_settings import warnings -def strict_positive_int(integer_string): +def strict_positive_int(integer_string, cutoff=None): """ Cast a string to a strictly positive integer. """ ret = int(integer_string) if ret <= 0: raise ValueError() + if cutoff: + ret = min(ret, cutoff) return ret def get_object_or_404(queryset, **filter_kwargs): @@ -206,21 +208,15 @@ class GenericAPIView(views.APIView): PendingDeprecationWarning, stacklevel=2) if self.paginate_by_param: - query_params = self.request.QUERY_PARAMS try: - paginate_by_param = int(query_params[self.paginate_by_param]) + return strict_positive_int( + self.request.QUERY_PARAMS[self.paginate_by_param], + cutoff=self.max_paginate_by + ) except (KeyError, ValueError): pass - else: - if self.max_paginate_by is not None: - return min(self.max_paginate_by, paginate_by_param) - else: - return paginate_by_param - if self.max_paginate_by: - return min(self.max_paginate_by, self.paginate_by) - else: - return self.paginate_by + return self.paginate_by def get_serializer_class(self): """ diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cbed1604..4170d4b6 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -47,8 +47,8 @@ class MaxPaginateByView(generics.ListAPIView): View for testing custom max_paginate_by usage """ model = BasicModel - paginate_by = 5 - max_paginate_by = 3 + paginate_by = 3 + max_paginate_by = 5 paginate_by_param = 'page_size' @@ -343,16 +343,17 @@ class TestMaxPaginateByParam(TestCase): def test_max_paginate_by(self): """ - If max_paginate_by is set and it less than paginate_by, new kwarg should limit requests for review. + If max_paginate_by is set, it should limit page size for the view. """ request = factory.get('/?page_size=10') response = self.view(request).render() self.assertEqual(response.data['count'], 13) - self.assertEqual(response.data['results'], self.data[:3]) + self.assertEqual(response.data['results'], self.data[:5]) def test_max_paginate_by_without_page_size_param(self): """ - If max_paginate_by is set, new kwarg should limit requests for review. + If max_paginate_by is set, but client does not specifiy page_size, + standard `paginate_by` behavior should be used. """ request = factory.get('/') response = self.view(request).render() -- cgit v1.2.3 From 97b52156cc0e96c2edb7e1b176838bfd9c22321a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 28 Aug 2013 13:34:14 +0100 Subject: Added `.cache` attribute on throttles. Closes #1066. More localised than a new settings key, and more flexible in that different throttles can use different behavior. Thanks to @chicheng for the report! :) --- rest_framework/throttling.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 65b45593..8943f22c 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -2,7 +2,7 @@ Provides various throttling policies. """ from __future__ import unicode_literals -from django.core.cache import cache +from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings import time @@ -39,6 +39,7 @@ class SimpleRateThrottle(BaseThrottle): Previous request information used for throttling is stored in the cache. """ + cache = default_cache timer = time.time cache_format = 'throtte_%(scope)s_%(ident)s' scope = None @@ -99,7 +100,7 @@ class SimpleRateThrottle(BaseThrottle): if self.key is None: return True - self.history = cache.get(self.key, []) + self.history = self.cache.get(self.key, []) self.now = self.timer() # Drop any requests from the history which have now passed the @@ -116,7 +117,7 @@ class SimpleRateThrottle(BaseThrottle): into the cache. """ self.history.insert(0, self.now) - cache.set(self.key, self.history, self.duration) + self.cache.set(self.key, self.history, self.duration) return True def throttle_failure(self): -- cgit v1.2.3 From 2d5e14a8d39a53c8a2e6d28fb8ae7debb5fbd388 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 28 Aug 2013 15:32:41 +0100 Subject: Throttles now use HTTP_X_FORWARDED_FOR, falling back to REMOTE_ADDR to identify anonymous requests --- rest_framework/throttling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 8943f22c..a946d837 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -152,7 +152,9 @@ class AnonRateThrottle(SimpleRateThrottle): if request.user.is_authenticated(): return None # Only throttle unauthenticated requests. - ident = request.META.get('REMOTE_ADDR', None) + ident = request.META.get('HTTP_X_FORWARDED_FOR') + if ident is None: + ident = request.META.get('REMOTE_ADDR') return self.cache_format % { 'scope': self.scope, -- cgit v1.2.3 From 18007d68464b0cfab970e2a60aed0d41c4de4dac Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 28 Aug 2013 21:52:56 +0100 Subject: Simplifying raw data renderering support --- rest_framework/parsers.py | 10 +++------- rest_framework/renderers.py | 10 ++++++++-- rest_framework/serializers.py | 2 -- rest_framework/tests/test_serializer.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index c635505a..23387dff 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -12,7 +12,7 @@ from django.http.multipartparser import MultiPartParser as DjangoMultiPartParser from django.http.multipartparser import MultiPartParserError, parse_header, ChunkIter from rest_framework.compat import etree, six, yaml from rest_framework.exceptions import ParseError -from rest_framework.renderers import UnicodeJSONRenderer +from rest_framework import renderers import json import datetime import decimal @@ -32,8 +32,6 @@ class BaseParser(object): media_type = None - supports_html_forms = False - def parse(self, stream, media_type=None, parser_context=None): """ Given a stream to read from, return the parsed representation. @@ -49,7 +47,7 @@ class JSONParser(BaseParser): """ media_type = 'application/json' - renderer_class = UnicodeJSONRenderer + renderer_class = renderers.UnicodeJSONRenderer def parse(self, stream, media_type=None, parser_context=None): """ @@ -94,7 +92,6 @@ class FormParser(BaseParser): """ media_type = 'application/x-www-form-urlencoded' - supports_html_forms = True def parse(self, stream, media_type=None, parser_context=None): """ @@ -113,7 +110,6 @@ class MultiPartParser(BaseParser): """ media_type = 'multipart/form-data' - supports_html_forms = True def parse(self, stream, media_type=None, parser_context=None): """ @@ -134,7 +130,7 @@ class MultiPartParser(BaseParser): data, files = parser.parse() return DataAndFiles(data, files) except MultiPartParserError as exc: - raise ParseError('Multipart form parse error - %s' % six.u(exc)) + raise ParseError('Multipart form parse error - %s' % six.u(exc.strerror)) class XMLParser(BaseParser): diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index cc8de959..cd55c783 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -21,7 +21,7 @@ from rest_framework.compat import six from rest_framework.compat import smart_text from rest_framework.compat import yaml from rest_framework.settings import api_settings -from rest_framework.request import clone_request +from rest_framework.request import clone_request, is_form_media_type from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework import exceptions, status, VERSION @@ -482,7 +482,7 @@ class BrowsableAPIRenderer(BaseRenderer): if method in ('DELETE', 'OPTIONS'): return True # Don't actually need to return a form - if not getattr(view, 'get_serializer', None) or not any(parser.supports_html_forms for parser in view.parser_classes): + if not getattr(view, 'get_serializer', None) or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes): return serializer = view.get_serializer(instance=obj) @@ -565,11 +565,16 @@ class BrowsableAPIRenderer(BaseRenderer): obj = getattr(view, 'object', None) if getattr(view, 'get_serializer', None): serializer = view.get_serializer(instance=obj) + for field_name, field in serializer.fields.items(): + if field.read_only: + del serializer.fields[field_name] else: serializer = None parsers = [] for parser_class in view.parser_classes: + if is_form_media_type(parser_class.media_type): + continue content = None renderer_class = getattr(parser_class, 'renderer_class', None) if renderer_class and serializer: @@ -650,3 +655,4 @@ class MultiPartRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): return encode_multipart(self.BOUNDARY, data) + diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 202d3a09..abff6898 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -304,8 +304,6 @@ class BaseSerializer(WritableField): ret.empty = obj is None for field_name, field in self.fields.items(): - if obj is None and field.read_only: - continue field.initialize(parent=self, field_name=field_name) key = self.get_field_key(field_name) value = field.field_to_native(obj, field_name) diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 7c2a276e..c2497660 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -158,7 +158,7 @@ class BasicTests(TestCase): 'email': '', 'content': '', 'created': None, - #'sub_comment': '' + 'sub_comment': '' } self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From 37e2720a40d39688f5e6ebb3b5c5aad68b8c25d4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 12:55:56 +0100 Subject: Add `override_method` context manager and cleanup. --- rest_framework/renderers.py | 165 ++++++++++++++++---------------------------- rest_framework/request.py | 23 ++++++ 2 files changed, 81 insertions(+), 107 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index cd55c783..34860f6a 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -21,7 +21,7 @@ from rest_framework.compat import six from rest_framework.compat import smart_text from rest_framework.compat import yaml from rest_framework.settings import api_settings -from rest_framework.request import clone_request, is_form_media_type +from rest_framework.request import is_form_media_type, override_method from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework import exceptions, status, VERSION @@ -456,18 +456,6 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def _get_rendered_html_form(self, view, method, request): - # We need to impersonate a request with the correct method, - # so that eg. any dynamic get_serializer_class methods return the - # correct form for each method. - restore = view.request - request = clone_request(request, method) - view.request = request - try: - return self.get_rendered_html_form(view, method, request) - finally: - view.request = restore - def get_rendered_html_form(self, view, method, request): """ Return a string representing a rendered HTML form, possibly bound to @@ -475,32 +463,22 @@ class BrowsableAPIRenderer(BaseRenderer): In the absence of the View having an associated form then return None. """ - obj = getattr(view, 'object', None) - if not self.show_form_for_method(view, method, request, obj): - return - - if method in ('DELETE', 'OPTIONS'): - return True # Don't actually need to return a form - - if not getattr(view, 'get_serializer', None) or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes): - return - - serializer = view.get_serializer(instance=obj) - data = serializer.data - form_renderer = self.form_renderer_class() - return form_renderer.render(data, self.accepted_media_type, self.renderer_context) - - def _get_raw_data_form(self, view, method, request, media_types): - # We need to impersonate a request with the correct method, - # so that eg. any dynamic get_serializer_class methods return the - # correct form for each method. - restore = view.request - request = clone_request(request, method) - view.request = request - try: - return self.get_raw_data_form(view, method, request, media_types) - finally: - view.request = restore + with override_method(view, request, method) as request: + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return + + if method in ('DELETE', 'OPTIONS'): + return True # Don't actually need to return a form + + if (not getattr(view, 'get_serializer', None) + or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)): + return + + serializer = view.get_serializer(instance=obj) + data = serializer.data + form_renderer = self.form_renderer_class() + return form_renderer.render(data, self.accepted_media_type, self.renderer_context) def get_raw_data_form(self, view, method, request, media_types): """ @@ -508,39 +486,39 @@ class BrowsableAPIRenderer(BaseRenderer): via standard HTML forms. (Which are typically application/x-www-form-urlencoded) """ - - # If we're not using content overloading there's no point in supplying a generic form, - # as the view won't treat the form's value as the content of the request. - if not (api_settings.FORM_CONTENT_OVERRIDE - and api_settings.FORM_CONTENTTYPE_OVERRIDE): - return None - - # Check permissions - obj = getattr(view, 'object', None) - if not self.show_form_for_method(view, method, request, obj): - return - - content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE - content_field = api_settings.FORM_CONTENT_OVERRIDE - choices = [(media_type, media_type) for media_type in media_types] - initial = media_types[0] - - # NB. http://jacobian.org/writing/dynamic-form-generation/ - class GenericContentForm(forms.Form): - def __init__(self): - super(GenericContentForm, self).__init__() - - self.fields[content_type_field] = forms.ChoiceField( - label='Media type', - choices=choices, - initial=initial - ) - self.fields[content_field] = forms.CharField( - label='Content', - widget=forms.Textarea - ) - - return GenericContentForm() + with override_method(view, request, method) as request: + # If we're not using content overloading there's no point in supplying a generic form, + # as the view won't treat the form's value as the content of the request. + if not (api_settings.FORM_CONTENT_OVERRIDE + and api_settings.FORM_CONTENTTYPE_OVERRIDE): + return None + + # Check permissions + obj = getattr(view, 'object', None) + if not self.show_form_for_method(view, method, request, obj): + return + + content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE + content_field = api_settings.FORM_CONTENT_OVERRIDE + choices = [(media_type, media_type) for media_type in media_types] + initial = media_types[0] + + # NB. http://jacobian.org/writing/dynamic-form-generation/ + class GenericContentForm(forms.Form): + def __init__(self): + super(GenericContentForm, self).__init__() + + self.fields[content_type_field] = forms.ChoiceField( + label='Media type', + choices=choices, + initial=initial + ) + self.fields[content_field] = forms.CharField( + label='Content', + widget=forms.Textarea + ) + + return GenericContentForm() def get_name(self, view): return view.get_view_name() @@ -562,47 +540,20 @@ class BrowsableAPIRenderer(BaseRenderer): request = renderer_context['request'] response = renderer_context['response'] - obj = getattr(view, 'object', None) - if getattr(view, 'get_serializer', None): - serializer = view.get_serializer(instance=obj) - for field_name, field in serializer.fields.items(): - if field.read_only: - del serializer.fields[field_name] - else: - serializer = None - - parsers = [] - for parser_class in view.parser_classes: - if is_form_media_type(parser_class.media_type): - continue - content = None - renderer_class = getattr(parser_class, 'renderer_class', None) - if renderer_class and serializer: - renderer = renderer_class() - context = renderer_context.copy() - context['indent'] = 4 - content = renderer.render(serializer.data, accepted_media_type, context) - print content - parsers.append({ - 'media_type': parser_class.media_type, - 'content': content - }) - - media_types = [parser.media_type for parser in view.parser_classes] renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) - put_form = self._get_rendered_html_form(view, 'PUT', request) - post_form = self._get_rendered_html_form(view, 'POST', request) - patch_form = self._get_rendered_html_form(view, 'PATCH', request) - delete_form = self._get_rendered_html_form(view, 'DELETE', request) - options_form = self._get_rendered_html_form(view, 'OPTIONS', request) + put_form = self.get_rendered_html_form(view, 'PUT', request) + post_form = self.get_rendered_html_form(view, 'POST', request) + patch_form = self.get_rendered_html_form(view, 'PATCH', request) + delete_form = self.get_rendered_html_form(view, 'DELETE', request) + options_form = self.get_rendered_html_form(view, 'OPTIONS', request) - raw_data_put_form = self._get_raw_data_form(view, 'PUT', request, media_types) - raw_data_post_form = self._get_raw_data_form(view, 'POST', request, media_types) - raw_data_patch_form = self._get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types) + raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types) + raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) diff --git a/rest_framework/request.py b/rest_framework/request.py index 919716f4..977d4d96 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -28,6 +28,29 @@ def is_form_media_type(media_type): base_media_type == 'multipart/form-data') +class override_method(object): + """ + A context manager that temporarily overrides the method on a request, + additionally setting the `view.request` attribute. + + Usage: + + with override_method(view, request, 'POST') as request: + ... # Do stuff with `view` and `request` + """ + def __init__(self, view, request, method): + self.view = view + self.request = request + self.method = method + + def __enter__(self): + self.view.request = clone_request(self.request, self.method) + return self.view.request + + def __exit__(self, *args, **kwarg): + self.view.request = self.request + + class Empty(object): """ Placeholder for unset attributes. -- cgit v1.2.3 From 11071499a777ecfee6edfb7e92ecf9a12d35eeb7 Mon Sep 17 00:00:00 2001 From: Mathieu Pillard Date: Thu, 29 Aug 2013 18:10:47 +0200 Subject: Make ChoiceField.from_native() follow IntegerField behaviour on empty values --- rest_framework/fields.py | 5 +++++ rest_framework/tests/test_fields.py | 8 ++++++++ 2 files changed, 13 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3e0ca1a1..210c2537 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -514,6 +514,11 @@ class ChoiceField(WritableField): return True return False + def from_native(self, value): + if value in validators.EMPTY_VALUES: + return None + return super(ChoiceField, self).from_native(value) + class EmailField(CharField): type_name = 'EmailField' diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index ebccba7d..34fbab9c 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -688,6 +688,14 @@ class ChoiceFieldTests(TestCase): f = serializers.ChoiceField(required=False, choices=self.SAMPLE_CHOICES) self.assertEqual(f.choices, models.fields.BLANK_CHOICE_DASH + self.SAMPLE_CHOICES) + def test_from_native_empty(self): + """ + Make sure from_native() returns None on empty param. + """ + f = serializers.ChoiceField(choices=self.SAMPLE_CHOICES) + result = f.from_native('') + self.assertEqual(result, None) + class EmailFieldTests(TestCase): """ -- cgit v1.2.3 From c7f3b8bebef33093d4e949f797565c4cbcd2695d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 17:23:26 +0100 Subject: Include serialized content in raw data form. --- rest_framework/renderers.py | 43 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 34860f6a..077d6ebe 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -480,15 +480,16 @@ class BrowsableAPIRenderer(BaseRenderer): form_renderer = self.form_renderer_class() return form_renderer.render(data, self.accepted_media_type, self.renderer_context) - def get_raw_data_form(self, view, method, request, media_types): + def get_raw_data_form(self, view, method, request): """ Returns a form that allows for arbitrary content types to be tunneled via standard HTML forms. (Which are typically application/x-www-form-urlencoded) """ with override_method(view, request, method) as request: - # If we're not using content overloading there's no point in supplying a generic form, - # as the view won't treat the form's value as the content of the request. + # If we're not using content overloading there's no point in + # supplying a generic form, as the view won't treat the form's + # value as the content of the request. if not (api_settings.FORM_CONTENT_OVERRIDE and api_settings.FORM_CONTENTTYPE_OVERRIDE): return None @@ -498,8 +499,33 @@ class BrowsableAPIRenderer(BaseRenderer): if not self.show_form_for_method(view, method, request, obj): return + # If possible, serialize the initial content for the generic form + default_parser = view.parser_classes[0] + renderer_class = getattr(default_parser, 'renderer_class', None) + if (hasattr(view, 'get_serializer') and renderer_class): + # View has a serializer defined and parser class has a + # corresponding renderer that can be used to render the data. + + # Get a read-only version of the serializer + serializer = view.get_serializer(instance=obj) + for field_name, field in serializer.fields.items(): + if field.read_only: + del serializer.fields[field_name] + + # Render the raw data content + renderer = renderer_class() + accepted = self.accepted_media_type + context = self.renderer_context.copy().update({'indent': 4}) + content = renderer.render(serializer.data, accepted, context) + else: + content = None + + # Generate a generic form that includes a content type field, + # and a content field. content_type_field = api_settings.FORM_CONTENTTYPE_OVERRIDE content_field = api_settings.FORM_CONTENT_OVERRIDE + + media_types = [parser.media_type for parser in view.parser_classes] choices = [(media_type, media_type) for media_type in media_types] initial = media_types[0] @@ -515,7 +541,8 @@ class BrowsableAPIRenderer(BaseRenderer): ) self.fields[content_field] = forms.CharField( label='Content', - widget=forms.Textarea + widget=forms.Textarea, + initial=content ) return GenericContentForm() @@ -540,8 +567,6 @@ class BrowsableAPIRenderer(BaseRenderer): request = renderer_context['request'] response = renderer_context['response'] - media_types = [parser.media_type for parser in view.parser_classes] - renderer = self.get_default_renderer(view) content = self.get_content(renderer, data, accepted_media_type, renderer_context) @@ -551,9 +576,9 @@ class BrowsableAPIRenderer(BaseRenderer): delete_form = self.get_rendered_html_form(view, 'DELETE', request) options_form = self.get_rendered_html_form(view, 'OPTIONS', request) - raw_data_put_form = self.get_raw_data_form(view, 'PUT', request, media_types) - raw_data_post_form = self.get_raw_data_form(view, 'POST', request, media_types) - raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request, media_types) + raw_data_put_form = self.get_raw_data_form(view, 'PUT', request) + raw_data_post_form = self.get_raw_data_form(view, 'POST', request) + raw_data_patch_form = self.get_raw_data_form(view, 'PATCH', request) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form name = self.get_name(view) -- cgit v1.2.3 From 1fa2d823cc9f2dcf301b0e3ce7f47acfcdfcb305 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 20:35:59 +0100 Subject: Preserve tab preference in cookies. --- rest_framework/static/rest_framework/js/default.js | 45 +++++++++++++++++++++- rest_framework/templates/rest_framework/base.html | 4 +- 2 files changed, 46 insertions(+), 3 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js index c74829d7..a57b1cb8 100644 --- a/rest_framework/static/rest_framework/js/default.js +++ b/rest_framework/static/rest_framework/js/default.js @@ -1,13 +1,56 @@ +function getCookie(c_name) +{ + // From http://www.w3schools.com/js/js_cookies.asp + var c_value = document.cookie; + var c_start = c_value.indexOf(" " + c_name + "="); + if (c_start == -1) { + c_start = c_value.indexOf(c_name + "="); + } + if (c_start == -1) { + c_value = null; + } else { + c_start = c_value.indexOf("=", c_start) + 1; + var c_end = c_value.indexOf(";", c_start); + if (c_end == -1) { + c_end = c_value.length; + } + c_value = unescape(c_value.substring(c_start,c_end)); + } + return c_value; +} + +// JSON highlighting. prettyPrint(); +// Bootstrap tooltips. $('.js-tooltip').tooltip({ delay: 1000 }); +// Deal with rounded tab styling after tab clicks. $('a[data-toggle="tab"]:first').on('shown', function (e) { $(e.target).parents('.tabbable').addClass('first-tab-active'); }); $('a[data-toggle="tab"]:not(:first)').on('shown', function (e) { $(e.target).parents('.tabbable').removeClass('first-tab-active'); }); -$('.form-switcher a:first').tab('show'); + +$('a[data-toggle="tab"]').click(function(){ + document.cookie="tab=" + this.name; +}); + +// Store tab preference in cookies & display appropriate tab on load. +var selectedTab = null; +var selectedTabName = getCookie('tab'); + +if (selectedTabName) { + selectedTab = $('.form-switcher a[name=' + selectedTabName + ']'); +} + +if (selectedTab && selectedTab.length > 0) { + // Display whichever tab is selected. + selectedTab.tab('show'); +} else { + // If no tab selected, display rightmost tab. + $('.form-switcher a:first').tab('show'); +} diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 6ae47563..81697063 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -128,8 +128,8 @@
{% if post_form %} {% endif %}
-- cgit v1.2.3 From 44f8d1bef22d5f308fdbdfc29e6418816c3c27dd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 20:38:55 +0100 Subject: Fix tab preferences on PUT forms --- rest_framework/templates/rest_framework/base.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 81697063..aa90e90c 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -167,8 +167,8 @@
{% if put_form %} {% endif %}
-- cgit v1.2.3 From e4d2f54529bcf538be93da5770e05b88a32da1c7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 20:39:05 +0100 Subject: Fix indenting on raw data forms --- rest_framework/renderers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 077d6ebe..525e44d5 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -515,7 +515,8 @@ class BrowsableAPIRenderer(BaseRenderer): # Render the raw data content renderer = renderer_class() accepted = self.accepted_media_type - context = self.renderer_context.copy().update({'indent': 4}) + context = self.renderer_context.copy() + context['indent'] = 4 content = renderer.render(serializer.data, accepted, context) else: content = None -- cgit v1.2.3 From 02b6836ee88498861521dfff743467b0456ad109 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 20:51:51 +0100 Subject: Fix breadcrumb view names --- rest_framework/utils/breadcrumbs.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 0384faba..e6690d17 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -8,8 +8,11 @@ def get_breadcrumbs(url): tuple of (name, url). """ + from rest_framework.settings import api_settings from rest_framework.views import APIView + view_name_func = api_settings.VIEW_NAME_FUNCTION + def breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen): """ Add tuples of (name, url) to the breadcrumbs list, @@ -28,8 +31,8 @@ def get_breadcrumbs(url): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - instance = view.cls() - name = instance.get_view_name() + suffix = getattr(view, 'suffix', None) + name = view_name_func(cls, suffix) breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) -- cgit v1.2.3 From 2247fd68e9b3bbc91075a11f44db16fc40497b2a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Aug 2013 21:24:29 +0100 Subject: Fix multipart error when used via content-type overloading --- rest_framework/parsers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 23387dff..98fc0341 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -122,7 +122,8 @@ class MultiPartParser(BaseParser): parser_context = parser_context or {} request = parser_context['request'] encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) - meta = request.META + meta = request.META.copy() + meta['CONTENT_TYPE'] = media_type upload_handlers = request.upload_handlers try: @@ -130,7 +131,7 @@ class MultiPartParser(BaseParser): data, files = parser.parse() return DataAndFiles(data, files) except MultiPartParserError as exc: - raise ParseError('Multipart form parse error - %s' % six.u(exc.strerror)) + raise ParseError('Multipart form parse error - %s' % str(exc)) class XMLParser(BaseParser): -- cgit v1.2.3 From 3fba60e99c75dda4e14f7fe4f941d6fc84e4c986 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Aug 2013 09:02:54 +0100 Subject: Drop broken placeholder serializations. --- rest_framework/renderers.py | 13 ++++++++++--- rest_framework/serializers.py | 3 ++- rest_framework/tests/test_serializer.py | 1 - 3 files changed, 12 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 525e44d5..fca67eee 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -338,6 +338,11 @@ class HTMLFormRenderer(BaseRenderer): fields = {} for key, val in data.fields.items(): if getattr(val, 'read_only', True): + # Don't include read-only fields. + continue + + if getattr(val, 'fields', None): + # Nested data not supported by HTML forms. continue kwargs = {} @@ -476,6 +481,7 @@ class BrowsableAPIRenderer(BaseRenderer): return serializer = view.get_serializer(instance=obj) + data = serializer.data form_renderer = self.form_renderer_class() return form_renderer.render(data, self.accepted_media_type, self.renderer_context) @@ -508,9 +514,10 @@ class BrowsableAPIRenderer(BaseRenderer): # Get a read-only version of the serializer serializer = view.get_serializer(instance=obj) - for field_name, field in serializer.fields.items(): - if field.read_only: - del serializer.fields[field_name] + if obj is None: + for name, field in serializer.fields.items(): + if getattr(field, 'read_only', None): + del serializer.fields[name] # Render the raw data content renderer = renderer_class() diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index abff6898..a63c7f6c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -334,13 +334,14 @@ class BaseSerializer(WritableField): if self.source == '*': return self.to_native(obj) + # Get the raw field value try: source = self.source or field_name value = obj for component in source.split('.'): if value is None: - return self.to_native(None) + break value = get_component(value, component) except ObjectDoesNotExist: return None diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index c2497660..957e3bd2 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -158,7 +158,6 @@ class BasicTests(TestCase): 'email': '', 'content': '', 'created': None, - 'sub_comment': '' } self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From cba972911a90bdc0050bc48397bc70e1a062040d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Aug 2013 09:12:39 +0100 Subject: Fix failing empty serializer test --- rest_framework/tests/test_serializer.py | 1 + 1 file changed, 1 insertion(+) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 957e3bd2..c2497660 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -158,6 +158,7 @@ class BasicTests(TestCase): 'email': '', 'content': '', 'created': None, + 'sub_comment': '' } self.assertEqual(serializer.data, expected) -- cgit v1.2.3 From f3ab0b2b1d5734314dbe3cdd13cd7c4f0531bf7d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Aug 2013 09:20:12 +0100 Subject: Browsable API tab preferences should be site-wide --- rest_framework/static/rest_framework/js/default.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/static/rest_framework/js/default.js b/rest_framework/static/rest_framework/js/default.js index a57b1cb8..bcb1964d 100644 --- a/rest_framework/static/rest_framework/js/default.js +++ b/rest_framework/static/rest_framework/js/default.js @@ -36,12 +36,12 @@ $('a[data-toggle="tab"]:not(:first)').on('shown', function (e) { }); $('a[data-toggle="tab"]').click(function(){ - document.cookie="tab=" + this.name; + document.cookie="tabstyle=" + this.name + "; path=/"; }); // Store tab preference in cookies & display appropriate tab on load. var selectedTab = null; -var selectedTabName = getCookie('tab'); +var selectedTabName = getCookie('tabstyle'); if (selectedTabName) { selectedTab = $('.form-switcher a[name=' + selectedTabName + ']'); -- cgit v1.2.3