aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2012-10-03 12:16:30 +0100
committerTom Christie2012-10-03 12:16:30 +0100
commit58c1263267e5947f8243568edb33273effdc2787 (patch)
treee6c6d0c213d05e32240eabe07f408ab672514781
parentcab3b2f3f8f82087cf162dd7c62f18e9d8bb208a (diff)
downloaddjango-rest-framework-58c1263267e5947f8243568edb33273effdc2787.tar.bz2
Use either PrimaryKeyRelatedField or ManyPrimaryKeyRelatedField as appropriate (fixes test)
-rw-r--r--rest_framework/fields.py12
-rw-r--r--rest_framework/serializers.py2
-rw-r--r--rest_framework/tests/serializer.py1
3 files changed, 9 insertions, 6 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 7cb2950c..b51d70a8 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -221,13 +221,13 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
try:
- obj = obj.serializable_value(self.source or field_name)
+ pk = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name)
return self.to_native(obj.pk)
# Forward relationship
- return self.to_native(obj)
+ return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
@@ -237,13 +237,13 @@ class PrimaryKeyRelatedField(RelatedField):
class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
def field_to_native(self, obj, field_name):
try:
- obj = obj.serializable_value(self.source or field_name)
+ queryset = obj.serializable_value(self.source or field_name)
except AttributeError:
# RelatedManager (reverse relationship)
- obj = getattr(obj, self.source or field_name)
- return [self.to_native(item.pk) for item in obj.all()]
+ queryset = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in queryset.all()]
# Forward relationship
- return [self.to_native(item.pk) for item in obj.all()]
+ return [self.to_native(item.pk) for item in queryset.all()]
def field_from_native(self, data, field_name, into):
try:
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 683b9efc..d3ae9b8a 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -351,6 +351,8 @@ class ModelSerializer(RelatedField, Serializer):
"""
Creates a default instance of a flat relational field.
"""
+ if isinstance(model_field, models.fields.related.ManyToManyField):
+ return ManyPrimaryKeyRelatedField()
return PrimaryKeyRelatedField()
def get_field(self, model_field):
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index f90dce16..db342c9e 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -201,6 +201,7 @@ class ManyToManyTests(TestCase):
self.assertEquals(len(ManyToManyModel.objects.all()), 2)
self.assertEquals(instance.pk, 2)
self.assertEquals(list(instance.rel.all()), [])
+
# def test_deserialization_for_update(self):
# serializer = self.serializer_class(self.data, instance=self.instance)
# expected = self.instance