aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/relations.py')
-rw-r--r--rest_framework/relations.py446
1 files changed, 446 insertions, 0 deletions
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
new file mode 100644
index 00000000..9b3a7790
--- /dev/null
+++ b/rest_framework/relations.py
@@ -0,0 +1,446 @@
+from django.core.exceptions import ObjectDoesNotExist, ValidationError
+from django.core.urlresolvers import resolve, get_script_prefix
+from django import forms
+from django.forms import widgets
+from django.forms.models import ModelChoiceIterator
+from django.utils.encoding import smart_unicode
+from rest_framework.fields import Field, WritableField
+from rest_framework.reverse import reverse
+from urlparse import urlparse
+
+##### Relational fields #####
+
+
+# Not actually Writable, but subclasses may need to be.
+class RelatedField(WritableField):
+ """
+ Base class for related model fields.
+
+ If not overridden, this represents a to-one relationship, using the unicode
+ representation of the target.
+ """
+ widget = widgets.Select
+ cache_choices = False
+ empty_label = None
+ default_read_only = True # TODO: Remove this
+
+ def __init__(self, *args, **kwargs):
+ self.queryset = kwargs.pop('queryset', None)
+ self.null = kwargs.pop('null', False)
+ super(RelatedField, self).__init__(*args, **kwargs)
+ self.read_only = kwargs.pop('read_only', self.default_read_only)
+
+ def initialize(self, parent, field_name):
+ super(RelatedField, self).initialize(parent, field_name)
+ if self.queryset is None and not self.read_only:
+ try:
+ manager = getattr(self.parent.opts.model, self.source or field_name)
+ if hasattr(manager, 'related'): # Forward
+ self.queryset = manager.related.model._default_manager.all()
+ else: # Reverse
+ self.queryset = manager.field.rel.to._default_manager.all()
+ except:
+ raise
+ msg = ('Serializer related fields must include a `queryset`' +
+ ' argument or set `read_only=True')
+ raise Exception(msg)
+
+ ### We need this stuff to make form choices work...
+
+ # def __deepcopy__(self, memo):
+ # result = super(RelatedField, self).__deepcopy__(memo)
+ # result.queryset = result.queryset
+ # return result
+
+ def prepare_value(self, obj):
+ return self.to_native(obj)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def _get_queryset(self):
+ return self._queryset
+
+ def _set_queryset(self, queryset):
+ self._queryset = queryset
+ self.widget.choices = self.choices
+
+ queryset = property(_get_queryset, _set_queryset)
+
+ def _get_choices(self):
+ # If self._choices is set, then somebody must have manually set
+ # the property self.choices. In this case, just return self._choices.
+ if hasattr(self, '_choices'):
+ return self._choices
+
+ # Otherwise, execute the QuerySet in self.queryset to determine the
+ # choices dynamically. Return a fresh ModelChoiceIterator that has not been
+ # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
+ # time _get_choices() is called (and, thus, each time self.choices is
+ # accessed) so that we can ensure the QuerySet has not been consumed. This
+ # construct might look complicated but it allows for lazy evaluation of
+ # the queryset.
+ return ModelChoiceIterator(self)
+
+ def _set_choices(self, value):
+ # Setting choices also sets the choices on the widget.
+ # choices can be any iterable, but we call list() on it because
+ # it will be consumed more than once.
+ self._choices = self.widget.choices = list(value)
+
+ choices = property(_get_choices, _set_choices)
+
+ ### Regular serializer stuff...
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(obj, self.source or field_name)
+ return self.to_native(value)
+
+ def field_from_native(self, data, files, field_name, into):
+ if self.read_only:
+ return
+
+ try:
+ value = data[field_name]
+ except KeyError:
+ if self.required:
+ raise ValidationError(self.error_messages['required'])
+ return
+
+ if value in (None, '') and not self.null:
+ raise ValidationError('Value may not be null')
+ elif value in (None, '') and self.null:
+ into[(self.source or field_name)] = None
+ else:
+ into[(self.source or field_name)] = self.from_native(value)
+
+
+class ManyRelatedMixin(object):
+ """
+ Mixin to convert a related field to a many related field.
+ """
+ widget = widgets.SelectMultiple
+
+ def field_to_native(self, obj, field_name):
+ value = getattr(obj, self.source or field_name)
+ return [self.to_native(item) for item in value.all()]
+
+ def field_from_native(self, data, files, field_name, into):
+ if self.read_only:
+ return
+
+ try:
+ # Form data
+ value = data.getlist(self.source or field_name)
+ except:
+ # Non-form data
+ value = data.get(self.source or field_name)
+ else:
+ if value == ['']:
+ value = []
+
+ into[field_name] = [self.from_native(item) for item in value]
+
+
+class ManyRelatedField(ManyRelatedMixin, RelatedField):
+ """
+ Base class for related model managers.
+
+ If not overridden, this represents a to-many relationship, using the unicode
+ representations of the target, and is read-only.
+ """
+ pass
+
+
+### PrimaryKey relationships
+
+class PrimaryKeyRelatedField(RelatedField):
+ """
+ Represents a to-one relationship as a pk value.
+ """
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ # TODO: Remove these field hacks...
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ # TODO: Possibly change this to just take `obj`, through prob less performant
+ def to_native(self, pk):
+ return pk
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
+ raise ValidationError(msg)
+
+ def field_to_native(self, obj, field_name):
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ pk = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedObject (reverse relationship)
+ obj = getattr(obj, self.source or field_name)
+ return self.to_native(obj.pk)
+ # Forward relationship
+ return self.to_native(pk)
+
+
+class ManyPrimaryKeyRelatedField(ManyRelatedField):
+ """
+ Represents a to-many relationship as a pk value.
+ """
+ default_read_only = False
+ form_field_class = forms.MultipleChoiceField
+
+ def prepare_value(self, obj):
+ return self.to_native(obj.pk)
+
+ def label_from_instance(self, obj):
+ """
+ Return a readable representation for use with eg. select widgets.
+ """
+ desc = smart_unicode(obj)
+ ident = smart_unicode(self.to_native(obj.pk))
+ if desc == ident:
+ return desc
+ return "%s - %s" % (desc, ident)
+
+ def to_native(self, pk):
+ return pk
+
+ def field_to_native(self, obj, field_name):
+ try:
+ # Prefer obj.serializable_value for performance reasons
+ queryset = obj.serializable_value(self.source or field_name)
+ except AttributeError:
+ # RelatedManager (reverse relationship)
+ queryset = getattr(obj, self.source or field_name)
+ return [self.to_native(item.pk) for item in queryset.all()]
+ # Forward relationship
+ return [self.to_native(item.pk) for item in queryset.all()]
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(pk=data)
+ except ObjectDoesNotExist:
+ msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data)
+ raise ValidationError(msg)
+
+### Slug relationships
+
+
+class SlugRelatedField(RelatedField):
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ def __init__(self, *args, **kwargs):
+ self.slug_field = kwargs.pop('slug_field', None)
+ assert self.slug_field, 'slug_field is required'
+ super(SlugRelatedField, self).__init__(*args, **kwargs)
+
+ def to_native(self, obj):
+ return getattr(obj, self.slug_field)
+
+ def from_native(self, data):
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ try:
+ return self.queryset.get(**{self.slug_field: data})
+ except ObjectDoesNotExist:
+ raise ValidationError('Object with %s=%s does not exist.' %
+ (self.slug_field, unicode(data)))
+
+
+class ManySlugRelatedField(ManyRelatedMixin, SlugRelatedField):
+ form_field_class = forms.MultipleChoiceField
+
+
+### Hyperlinked relationships
+
+class HyperlinkedRelatedField(RelatedField):
+ """
+ Represents a to-one relationship, using hyperlinking.
+ """
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+ default_read_only = False
+ form_field_class = forms.ChoiceField
+
+ def __init__(self, *args, **kwargs):
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except:
+ raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ self.format = kwargs.pop('format', None)
+ super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
+
+ def get_slug_field(self):
+ """
+ Get the name of a slug field to be used to look up by slug.
+ """
+ return self.slug_field
+
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+ pk = getattr(obj, 'pk', None)
+ if pk is None:
+ return
+ kwargs = {self.pk_url_kwarg: pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ def from_native(self, value):
+ # Convert URL -> model instance pk
+ # TODO: Use values_list
+ if self.queryset is None:
+ raise Exception('Writable related fields must include a `queryset` argument')
+
+ if value.startswith('http:') or value.startswith('https:'):
+ # If needed convert absolute URLs to relative path
+ value = urlparse(value).path
+ prefix = get_script_prefix()
+ if value.startswith(prefix):
+ value = '/' + value[len(prefix):]
+
+ try:
+ match = resolve(value)
+ except:
+ raise ValidationError('Invalid hyperlink - No URL match')
+
+ if match.url_name != self.view_name:
+ raise ValidationError('Invalid hyperlink - Incorrect URL match')
+
+ pk = match.kwargs.get(self.pk_url_kwarg, None)
+ slug = match.kwargs.get(self.slug_url_kwarg, None)
+
+ # Try explicit primary key.
+ if pk is not None:
+ queryset = self.queryset.filter(pk=pk)
+ # Next, try looking up by slug.
+ elif slug is not None:
+ slug_field = self.get_slug_field()
+ queryset = self.queryset.filter(**{slug_field: slug})
+ # If none of those are defined, it's an error.
+ else:
+ raise ValidationError('Invalid hyperlink')
+
+ try:
+ obj = queryset.get()
+ except ObjectDoesNotExist:
+ raise ValidationError('Invalid hyperlink - object does not exist.')
+ return obj
+
+
+class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
+ """
+ Represents a to-many relationship, using hyperlinking.
+ """
+ form_field_class = forms.MultipleChoiceField
+
+
+class HyperlinkedIdentityField(Field):
+ """
+ Represents the instance, or a property on the instance, using hyperlinking.
+ """
+ pk_url_kwarg = 'pk'
+ slug_field = 'slug'
+ slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
+
+ def __init__(self, *args, **kwargs):
+ # TODO: Make view_name mandatory, and have the
+ # HyperlinkedModelSerializer set it on-the-fly
+ self.view_name = kwargs.pop('view_name', None)
+ self.format = kwargs.pop('format', None)
+
+ self.slug_field = kwargs.pop('slug_field', self.slug_field)
+ default_slug_kwarg = self.slug_url_kwarg or self.slug_field
+ self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
+ self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
+
+ super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
+
+ def field_to_native(self, obj, field_name):
+ request = self.context.get('request', None)
+ format = self.format or self.context.get('format', None)
+ view_name = self.view_name or self.parent.opts.view_name
+ kwargs = {self.pk_url_kwarg: obj.pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request, format=format)
+ except:
+ pass
+
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)