aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/fields.py
diff options
context:
space:
mode:
authorTom Christie2012-10-03 12:07:34 +0100
committerTom Christie2012-10-03 12:07:34 +0100
commitcab3b2f3f8f82087cf162dd7c62f18e9d8bb208a (patch)
tree62d042adeca5da396dc91c04fa3e43116619de70 /rest_framework/fields.py
parentf1f7f5d4e3cd67730c6fb2233a5e4d6afaeae636 (diff)
downloaddjango-rest-framework-cab3b2f3f8f82087cf162dd7c62f18e9d8bb208a.tar.bz2
Split out PrimaryKeyRelatedField and ManyPrimaryKeyRelatedField
Diffstat (limited to 'rest_framework/fields.py')
-rw-r--r--rest_framework/fields.py37
1 files changed, 23 insertions, 14 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 85ee5430..7cb2950c 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -181,7 +181,6 @@ class RelatedField(Field):
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
"""
-
def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
@@ -202,7 +201,6 @@ class PrimaryKeyRelatedField(RelatedField):
"""
Serializes a model related field or related manager to a pk value.
"""
-
# Note the we use ModelRelatedField's implementation, as we want to get the
# raw database value directly, since that won't involve another
# database lookup.
@@ -225,23 +223,34 @@ class PrimaryKeyRelatedField(RelatedField):
try:
obj = obj.serializable_value(self.source or field_name)
except AttributeError:
- field = obj._meta.get_field_by_name(field_name)[0]
+ # RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name)
- if obj.__class__.__name__ == 'RelatedManager':
- return [self.to_native(item.pk) for item in obj.all()]
- elif isinstance(field, RelatedObject):
- return self.to_native(obj.pk)
- raise
- if obj.__class__.__name__ == 'ManyRelatedManager':
- return [self.to_native(item.pk) for item in obj.all()]
+ return self.to_native(obj.pk)
+ # Forward relationship
return self.to_native(obj)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
- if hasattr(value, '__iter__'):
- into[field_name] = [self.from_native(item) for item in value]
- else:
- into[field_name + '_id'] = self.from_native(value)
+ into[field_name + '_id'] = self.from_native(value)
+
+
+class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
+ def field_to_native(self, obj, field_name):
+ try:
+ obj = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ obj = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in obj.all()]
+ # Forward relationship
+ return [self.to_native(item.pk) for item in obj.all()]
+
+ def field_from_native(self, data, field_name, into):
+ try:
+ value = data.getlist(field_name)
+ except:
+ value = data.get(field_name)
+ into[field_name] = [self.from_native(item) for item in value]
class NaturalKeyRelatedField(RelatedField):