aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/relations.py')
-rw-r--r--rest_framework/relations.py122
1 files changed, 92 insertions, 30 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index fc5054b2..edaf76d6 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -8,10 +8,11 @@ from __future__ import unicode_literals
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
from django import forms
+from django.db.models.fields import BLANK_CHOICE_DASH
from django.forms import widgets
from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component
+from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
from rest_framework.reverse import reverse
from rest_framework.compat import urlparse
from rest_framework.compat import smart_text
@@ -47,7 +48,7 @@ class RelatedField(WritableField):
DeprecationWarning, stacklevel=2)
kwargs['required'] = not kwargs.pop('null')
- self.queryset = kwargs.pop('queryset', None)
+ queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', self.many)
if self.many:
self.widget = self.many_widget
@@ -56,6 +57,11 @@ class RelatedField(WritableField):
kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
super(RelatedField, self).__init__(*args, **kwargs)
+ if not self.required:
+ self.empty_label = BLANK_CHOICE_DASH[0][1]
+
+ self.queryset = queryset
+
def initialize(self, parent, field_name):
super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only:
@@ -66,7 +72,6 @@ class RelatedField(WritableField):
else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all()
except Exception:
- raise
msg = ('Serializer related fields must include a `queryset`' +
' argument or set `read_only=True')
raise Exception(msg)
@@ -139,7 +144,12 @@ class RelatedField(WritableField):
return None
if self.many:
- return [self.to_native(item) for item in value.all()]
+ if is_simple_callable(getattr(value, 'all', None)):
+ return [self.to_native(item) for item in value.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item) for item in value]
return self.to_native(value)
def field_from_native(self, data, files, field_name, into):
@@ -221,15 +231,28 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
if self.many:
# To-many relationship
- try:
+
+ queryset = None
+ if not self.source:
# Prefer obj.serializable_value for performance reasons
- queryset = obj.serializable_value(self.source or field_name)
- except AttributeError:
+ try:
+ queryset = obj.serializable_value(field_name)
+ except AttributeError:
+ pass
+ if queryset is None:
# RelatedManager (reverse relationship)
- queryset = getattr(obj, self.source or field_name)
+ source = self.source or field_name
+ queryset = obj
+ for component in source.split('.'):
+ queryset = get_component(queryset, component)
# Forward relationship
- return [self.to_native(item.pk) for item in queryset.all()]
+ if is_simple_callable(getattr(queryset, 'all', None)):
+ return [self.to_native(item.pk) for item in queryset.all()]
+ else:
+ # Also support non-queryset iterables.
+ # This allows us to also support plain lists of related items.
+ return [self.to_native(item.pk) for item in queryset]
# To-one relationship
try:
@@ -434,7 +457,7 @@ class HyperlinkedRelatedField(RelatedField):
raise Exception('Writable related fields must include a `queryset` argument')
try:
- http_prefix = value.startswith('http:') or value.startswith('https:')
+ http_prefix = value.startswith(('http:', 'https:'))
except AttributeError:
msg = self.error_messages['incorrect_type']
raise ValidationError(msg % type(value).__name__)
@@ -465,17 +488,35 @@ class HyperlinkedIdentityField(Field):
"""
Represents the instance, or a property on the instance, using hyperlinking.
"""
+ lookup_field = 'pk'
+ read_only = True
+
+ # These are all pending deprecation
pk_url_kwarg = 'pk'
slug_field = 'slug'
slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
- read_only = True
def __init__(self, *args, **kwargs):
- # TODO: Make view_name mandatory, and have the
- # HyperlinkedModelSerializer set it on-the-fly
- self.view_name = kwargs.pop('view_name', None)
- # Optionally the format of the target hyperlink may be specified
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except KeyError:
+ msg = "HyperlinkedIdentityField requires 'view_name' argument"
+ raise ValueError(msg)
+
self.format = kwargs.pop('format', None)
+ lookup_field = kwargs.pop('lookup_field', None)
+ self.lookup_field = lookup_field or self.lookup_field
+
+ # These are pending deprecation
+ if 'pk_url_kwarg' in kwargs:
+ msg = 'pk_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_url_kwarg' in kwargs:
+ msg = 'slug_url_kwarg is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
+ if 'slug_field' in kwargs:
+ msg = 'slug_field is pending deprecation. Use lookup_field instead.'
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
self.slug_field = kwargs.pop('slug_field', self.slug_field)
default_slug_kwarg = self.slug_url_kwarg or self.slug_field
@@ -487,8 +528,7 @@ class HyperlinkedIdentityField(Field):
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
format = self.context.get('format', None)
- view_name = self.view_name or self.parent.opts.view_name
- kwargs = {self.pk_url_kwarg: obj.pk}
+ view_name = self.view_name
if request is None:
warnings.warn("Using `HyperlinkedIdentityField` without including the "
@@ -508,29 +548,51 @@ class HyperlinkedIdentityField(Field):
if format and self.format and self.format != format:
format = self.format
+ # Return the hyperlink, or error if incorrectly configured.
try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ return self.get_url(obj, view_name, request, format)
except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
+ msg = (
+ 'Could not resolve URL for hyperlinked relationship using '
+ 'view name "%s". You may have failed to include the related '
+ 'model in your API, or incorrectly configured the '
+ '`lookup_field` attribute on this field.'
+ )
+ raise Exception(msg % view_name)
- if not slug:
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ def get_url(self, obj, view_name, request, format):
+ """
+ Given an object, return the URL that hyperlinks to the object.
- kwargs = {self.slug_url_kwarg: slug}
+ May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
+ attributes are not configured to correctly match the URL conf.
+ """
+ lookup_field = getattr(obj, self.lookup_field)
+ kwargs = {self.lookup_field: lookup_field}
try:
return reverse(view_name, kwargs=kwargs, request=request, format=format)
except NoReverseMatch:
pass
- kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
+ if self.pk_url_kwarg != 'pk':
+ # Only try pk lookup if it has been explicitly set.
+ # Otherwise, the default `lookup_field = 'pk'` has us covered.
+ kwargs = {self.pk_url_kwarg: obj.pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
- raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
+ slug = getattr(obj, self.slug_field, None)
+ if slug:
+ # Only use slug lookup if a slug field exists on the model
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except NoReverseMatch:
+ pass
+
+ raise NoReverseMatch()
### Old-style many classes for backwards compat