aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/relations.py16
-rw-r--r--rest_framework/tests/test_serializer.py44
2 files changed, 57 insertions, 3 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index e3675b51..edaf76d6 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -12,7 +12,7 @@ from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component
+from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
@@ -144,7 +144,12 @@ class RelatedField(WritableField):
return None
if self.many:
- return [self.to_native(item) for item in value.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
@@ -242,7 +247,12 @@ class PrimaryKeyRelatedField(RelatedField):
queryset = get_component(queryset, component)
# Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
+ if is_simple_callable(getattr(queryset, 'all', None)):
+ return [self.to_native(item.pk) for item in queryset.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item.pk) for item in queryset]
# To-one relationship
try:
diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py
index ecaa7255..8b87a084 100644
--- a/rest_framework/tests/test_serializer.py
+++ b/rest_framework/tests/test_serializer.py
@@ -1558,6 +1558,8 @@ class MetadataSerializerTestCase(TestCase):
self.assertEqual(expected, metadata)
+### Regression test for #840
+
class SimpleModel(models.Model):
text = models.CharField(max_length=100)
@@ -1573,6 +1575,7 @@ class SimpleModelSerializer(serializers.ModelSerializer):
del attrs['other']
return attrs
+
class FieldValidationRemovingAttr(TestCase):
def test_removing_non_model_field_in_validation(self):
"""
@@ -1587,3 +1590,44 @@ class FieldValidationRemovingAttr(TestCase):
self.assertTrue(serializer.is_valid())
serializer.save()
self.assertEqual(serializer.object.text, 'foo')
+
+
+### Regression test for #878
+
+class SimpleTargetModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class SimplePKSourceModelSerializer(serializers.Serializer):
+ targets = serializers.PrimaryKeyRelatedField(queryset=SimpleTargetModel.objects.all(), many=True)
+ text = serializers.CharField()
+
+
+class SimpleSlugSourceModelSerializer(serializers.Serializer):
+ targets = serializers.SlugRelatedField(queryset=SimpleTargetModel.objects.all(), many=True, slug_field='pk')
+ text = serializers.CharField()
+
+
+class SerializerSupportsManyRelationships(TestCase):
+ def setUp(self):
+ SimpleTargetModel.objects.create(text='foo')
+ SimpleTargetModel.objects.create(text='bar')
+
+ def test_serializer_supports_pk_many_relationships(self):
+ """
+ Regression test for #878.
+
+ Note that pk behavior has a different code path to usual cases,
+ for performance reasons.
+ """
+ serializer = SimplePKSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})
+
+ def test_serializer_supports_slug_many_relationships(self):
+ """
+ Regression test for #878.
+ """
+ serializer = SimpleSlugSourceModelSerializer(data={'text': 'foo', 'targets': [1, 2]})
+ self.assertTrue(serializer.is_valid())
+ self.assertEqual(serializer.data, {'text': 'foo', 'targets': [1, 2]})