aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/relations.py
diff options
context:
space:
mode:
authorJosé Padilla2014-11-28 12:14:40 -0400
committerJosé Padilla2014-11-28 12:14:40 -0400
commit0cc990792c63caa8fa8fea62cea53b0d28157b55 (patch)
tree7ea80a203cc8718150cd55e4403f3f4771160281 /rest_framework/relations.py
parent1aa77830955dcdf829f65a9001b6b8900dfc8755 (diff)
parent3a5b3772fefc3c2f2c0899947cbc07bfe6e6b5d2 (diff)
downloaddjango-rest-framework-0cc990792c63caa8fa8fea62cea53b0d28157b55.tar.bz2
Merge branch 'version-3.1' into oauth_as_package
Conflicts: requirements-test.txt rest_framework/compat.py tests/settings.py tox.ini
Diffstat (limited to 'rest_framework/relations.py')
-rw-r--r--rest_framework/relations.py154
1 files changed, 133 insertions, 21 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 5aa1f8bd..d1ea497a 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,12 +1,32 @@
from rest_framework.compat import smart_text, urlparse
-from rest_framework.fields import Field
+from rest_framework.fields import get_attribute, empty, Field
from rest_framework.reverse import reverse
+from rest_framework.utils import html
from django.core.exceptions import ObjectDoesNotExist, ImproperlyConfigured
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch, Resolver404
from django.db.models.query import QuerySet
+from django.utils import six
from django.utils.translation import ugettext_lazy as _
+class PKOnlyObject(object):
+ """
+ This is a mock object, used for when we only need the pk of the object
+ instance, but still want to return an object with a .pk attribute,
+ in order to keep the same interface as a regular model instance.
+ """
+ def __init__(self, pk):
+ self.pk = pk
+
+
+# We assume that 'validators' are intended for the child serializer,
+# rather than the parent serializer.
+MANY_RELATION_KWARGS = (
+ 'read_only', 'write_only', 'required', 'default', 'initial', 'source',
+ 'label', 'help_text', 'style', 'error_messages'
+)
+
+
class RelatedField(Field):
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', None)
@@ -22,14 +42,40 @@ class RelatedField(Field):
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
- # `ManyRelation` classes instead when `many=True` is set.
+ # `ManyRelatedField` classes instead when `many=True` is set.
if kwargs.pop('many', False):
- return ManyRelation(
- child_relation=cls(*args, **kwargs),
- read_only=kwargs.get('read_only', False)
- )
+ return cls.many_init(*args, **kwargs)
return super(RelatedField, cls).__new__(cls, *args, **kwargs)
+ @classmethod
+ def many_init(cls, *args, **kwargs):
+ """
+ This method handles creating a parent `ManyRelatedField` instance
+ when the `many=True` keyword argument is passed.
+
+ Typically you won't need to override this method.
+
+ Note that we're over-cautious in passing most arguments to both parent
+ and child classes in order to try to cover the general case. If you're
+ overriding this method you'll probably want something much simpler, eg:
+
+ @classmethod
+ def many_init(cls, *args, **kwargs):
+ kwargs['child'] = cls()
+ return CustomManyRelatedField(*args, **kwargs)
+ """
+ list_kwargs = {'child_relation': cls(*args, **kwargs)}
+ for key in kwargs.keys():
+ if key in MANY_RELATION_KWARGS:
+ list_kwargs[key] = kwargs[key]
+ return ManyRelatedField(**list_kwargs)
+
+ def run_validation(self, data=empty):
+ # We force empty strings to None values for relational fields.
+ if data == '':
+ data = None
+ return super(RelatedField, self).run_validation(data)
+
def get_queryset(self):
queryset = self.queryset
if isinstance(queryset, QuerySet):
@@ -37,8 +83,22 @@ class RelatedField(Field):
queryset = queryset.all()
return queryset
+ def get_iterable(self, instance, source_attrs):
+ relationship = get_attribute(instance, source_attrs)
+ return relationship.all() if (hasattr(relationship, 'all')) else relationship
+
+ @property
+ def choices(self):
+ return dict([
+ (
+ str(self.to_representation(item)),
+ str(item)
+ )
+ for item in self.queryset.all()
+ ])
+
-class StringRelatedField(Field):
+class StringRelatedField(RelatedField):
"""
A read only field that represents its targets using their
plain string representation.
@@ -49,7 +109,7 @@ class StringRelatedField(Field):
super(StringRelatedField, self).__init__(**kwargs)
def to_representation(self, value):
- return str(value)
+ return six.text_type(value)
class PrimaryKeyRelatedField(RelatedField):
@@ -67,6 +127,32 @@ class PrimaryKeyRelatedField(RelatedField):
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
+ def get_attribute(self, instance):
+ # We customize `get_attribute` here for performance reasons.
+ # For relationships the instance will already have the pk of
+ # the related object. We return this directly instead of returning the
+ # object itself, which would require a database lookup.
+ try:
+ instance = get_attribute(instance, self.source_attrs[:-1])
+ return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
+ except AttributeError:
+ return get_attribute(instance, self.source_attrs)
+
+ def get_iterable(self, instance, source_attrs):
+ # For consistency with `get_attribute` we're using `serializable_value()`
+ # here. Typically there won't be any difference, but some custom field
+ # types might return a non-primative value for the pk otherwise.
+ #
+ # We could try to get smart with `values_list('pk', flat=True)`, which
+ # would be better in some case, but would actually end up with *more*
+ # queries if the developer is using `prefetch_related` across the
+ # relationship.
+ relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs)
+ return [
+ PKOnlyObject(pk=item.serializable_value('pk'))
+ for item in relationship
+ ]
+
def to_representation(self, value):
return value.pk
@@ -89,9 +175,9 @@ class HyperlinkedRelatedField(RelatedField):
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None)
- # We include these simply for dependancy injection in tests.
+ # We include these simply for dependency injection in tests.
# We can't add them as class attributes or they would expect an
- # implict `self` argument to be passed.
+ # implicit `self` argument to be passed.
self.reverse = reverse
self.resolve = resolve
@@ -227,26 +313,33 @@ class SlugRelatedField(RelatedField):
return getattr(obj, self.slug_field)
-class ManyRelation(Field):
+class ManyRelatedField(Field):
"""
Relationships with `many=True` transparently get coerced into instead being
- a ManyRelation with a child relationship.
+ a ManyRelatedField with a child relationship.
- The `ManyRelation` class is responsible for handling iterating through
+ The `ManyRelatedField` class is responsible for handling iterating through
the values and passing each one to the child relationship.
- You shouldn't need to be using this class directly yourself.
+ This class is treated as private API.
+ You shouldn't generally need to be using this class directly yourself,
+ and should instead simply set 'many=True' on the relationship.
"""
+ initial = []
+ default_empty_html = []
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
assert child_relation is not None, '`child_relation` is a required argument.'
- super(ManyRelation, self).__init__(*args, **kwargs)
+ super(ManyRelatedField, self).__init__(*args, **kwargs)
+ self.child_relation.bind(field_name='', parent=self)
- def bind(self, field_name, parent, root):
- # ManyRelation needs to provide the current context to the child relation.
- super(ManyRelation, self).bind(field_name, parent, root)
- self.child_relation.bind(field_name, parent, root)
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if html.is_html_input(dictionary):
+ return dictionary.getlist(self.field_name)
+ return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
return [
@@ -254,8 +347,27 @@ class ManyRelation(Field):
for item in data
]
- def to_representation(self, obj):
+ def get_attribute(self, instance):
+ return self.child_relation.get_iterable(instance, self.source_attrs)
+
+ def to_representation(self, iterable):
return [
self.child_relation.to_representation(value)
- for value in obj.all()
+ for value in iterable
]
+
+ @property
+ def choices(self):
+ queryset = self.child_relation.queryset
+ iterable = queryset.all() if (hasattr(queryset, 'all')) else queryset
+ items_and_representations = [
+ (item, self.child_relation.to_representation(item))
+ for item in iterable
+ ]
+ return dict([
+ (
+ str(item_representation),
+ str(item) + ' - ' + str(item_representation)
+ )
+ for item, item_representation in items_and_representations
+ ])