aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/fields.py
diff options
context:
space:
mode:
authorTom Christie2012-10-03 13:28:22 -0700
committerTom Christie2012-10-03 13:28:22 -0700
commit0a769f261e79272cf1be6add1bf96aaeec59fb05 (patch)
treedd7f32d5ce01e88a2e435c406bbd120799340799 /rest_framework/fields.py
parent89ec0b275039868668080be740c46ebef92cff1e (diff)
parenta02707e12f750fd0d325e528f7b0fbcd7079db73 (diff)
downloaddjango-rest-framework-0a769f261e79272cf1be6add1bf96aaeec59fb05.tar.bz2
Merge pull request #277 from tomchristie/related-field-fixes
Related field fixes
Diffstat (limited to 'rest_framework/fields.py')
-rw-r--r--rest_framework/fields.py71
1 files changed, 43 insertions, 28 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 85ee5430..edc77e1a 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -7,7 +7,6 @@ from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
-from django.db.models.related import RelatedObject
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import parse_date, parse_datetime
@@ -181,6 +180,9 @@ class RelatedField(Field):
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
"""
+ def __init__(self, *args, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ super(RelatedField, self).__init__(*args, **kwargs)
def field_to_native(self, obj, field_name):
obj = getattr(obj, self.source or field_name)
@@ -200,48 +202,61 @@ class RelatedField(Field):
class PrimaryKeyRelatedField(RelatedField):
"""
- Serializes a model related field or related manager to a pk value.
+ Serializes a related field or related object to a pk value.
"""
- # Note the we use ModelRelatedField's implementation, as we want to get the
- # raw database value directly, since that won't involve another
- # database lookup.
- #
- # An alternative implementation would simply be this...
- #
- # class PrimaryKeyRelatedField(RelatedField):
- # def to_native(self, obj):
- # return obj.pk
-
def to_native(self, pk):
"""
- Simply returns the object's pk. You can subclass this method to
- provide different serialization behavior of the pk.
- (For example returning a URL based on the model's pk.)
+ You can subclass this method to provide different serialization
+ behavior based on the pk.
"""
return pk
def field_to_native(self, obj, field_name):
+ # This is only implemented for performance reasons
+ #
+ # We could leave the default `RelatedField.field_to_native()` in place,
+ # and inside just implement `to_native()` as `return obj.pk`
+ #
+ # That would involve an extra database lookup.
try:
- obj = obj.serializable_value(self.source or field_name)
+ pk = obj.serializable_value(self.source or field_name)
except AttributeError:
- field = obj._meta.get_field_by_name(field_name)[0]
+ # RelatedObject (reverse relationship)
obj = getattr(obj, self.source or field_name)
- if obj.__class__.__name__ == 'RelatedManager':
- return [self.to_native(item.pk) for item in obj.all()]
- elif isinstance(field, RelatedObject):
- return self.to_native(obj.pk)
- raise
- if obj.__class__.__name__ == 'ManyRelatedManager':
- return [self.to_native(item.pk) for item in obj.all()]
- return self.to_native(obj)
+ return self.to_native(obj.pk)
+ # Forward relationship
+ return self.to_native(pk)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
- if hasattr(value, '__iter__'):
- into[field_name] = [self.from_native(item) for item in value]
+ into[field_name + '_id'] = self.from_native(value)
+
+
+class ManyPrimaryKeyRelatedField(PrimaryKeyRelatedField):
+ """
+ Serializes a to-many related field or related manager to a pk value.
+ """
+
+ def field_to_native(self, obj, field_name):
+ try:
+ queryset = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ queryset = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in queryset.all()]
+ # Forward relationship
+ return [self.to_native(item.pk) for item in queryset.all()]
+
+ def field_from_native(self, data, field_name, into):
+ try:
+ value = data.getlist(field_name)
+ except:
+ value = data.get(field_name)
else:
- into[field_name + '_id'] = self.from_native(value)
+ if value == ['']:
+ value = []
+ into[field_name] = [self.from_native(item) for item in value]
class NaturalKeyRelatedField(RelatedField):