diff options
| author | Tom Christie | 2012-10-04 14:14:56 -0700 |
|---|---|---|
| committer | Tom Christie | 2012-10-04 14:14:56 -0700 |
| commit | ad5e6eb16f4db928e1fc8d0a6af4f9f4584f7b08 (patch) | |
| tree | ae049236abc6868c0b48803a04e8dc7cd4d5040c /rest_framework/fields.py | |
| parent | 42b3fdbdc26927e55713db31548a410870d82949 (diff) | |
| parent | 693892ed0104b8ce8cd801e7bec6107feeb88782 (diff) | |
| download | django-rest-framework-ad5e6eb16f4db928e1fc8d0a6af4f9f4584f7b08.tar.bz2 | |
Merge pull request #280 from tomchristie/hyperlinked-relationships
Hyperlinked relationships
Diffstat (limited to 'rest_framework/fields.py')
| -rw-r--r-- | rest_framework/fields.py | 108 |
1 files changed, 99 insertions, 9 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 32f2d122..b9ac3776 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -4,9 +4,9 @@ import inspect import warnings from django.core import validators -from django.core.exceptions import ValidationError +from django.core.exceptions import ObjectDoesNotExist, ValidationError +from django.core.urlresolvers import resolve from django.conf import settings -from django.db import DEFAULT_DB_ALIAS from django.utils.encoding import is_protected_type, smart_unicode from django.utils.translation import ugettext_lazy as _ from rest_framework.reverse import reverse @@ -27,6 +27,7 @@ def is_simple_callable(obj): class Field(object): creation_counter = 0 empty = '' + type_name = None def __init__(self, source=None): self.parent = None @@ -82,6 +83,10 @@ class Field(object): if is_protected_type(value): return value + + all_callable = getattr(value, 'all', None) + if is_simple_callable(all_callable): + return [self.to_native(item) for item in value.all()] elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): return [self.to_native(item) for item in value] return smart_unicode(value) @@ -90,7 +95,7 @@ class Field(object): """ Returns a dictionary of attributes to be used when serializing to xml. """ - if getattr(self, 'type_name', None): + if self.type_name: return {'type': self.type_name} return {} @@ -196,7 +201,7 @@ class ModelField(WritableField): value = self.model_field._get_val_from_obj(obj) if is_protected_type(value): return value - return self.model_field.value_to_string(self.obj) + return self.model_field.value_to_string(obj) def attributes(self): return { @@ -223,9 +228,9 @@ class RelatedField(WritableField): into[(self.source or field_name) + '_id'] = self.from_native(value) -class ManyRelatedField(RelatedField): +class ManyRelatedMixin(object): """ - Base class for related model managers. + Mixin to convert a related field to a many related field. """ def field_to_native(self, obj, field_name): value = getattr(obj, self.source or field_name) @@ -233,8 +238,10 @@ class ManyRelatedField(RelatedField): def field_from_native(self, data, field_name, into): try: + # Form data value = data.getlist(self.source or field_name) except: + # Non-form data value = data.get(self.source or field_name) else: if value == ['']: @@ -242,6 +249,15 @@ class ManyRelatedField(RelatedField): into[field_name] = [self.from_native(item) for item in value] +class ManyRelatedField(ManyRelatedMixin, RelatedField): + """ + Base class for related model managers. + """ + pass + + +### PrimaryKey relationships + class PrimaryKeyRelatedField(RelatedField): """ Serializes a related field or related object to a pk value. @@ -281,13 +297,87 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField): return [self.to_native(item.pk) for item in queryset.all()] +### Hyperlinked relationships + +class HyperlinkedRelatedField(RelatedField): + pk_url_kwarg = 'pk' + slug_url_kwarg = 'slug' + slug_field = 'slug' + + def __init__(self, *args, **kwargs): + try: + self.view_name = kwargs.pop('view_name') + except: + raise ValueError("Hyperlinked field requires 'view_name' kwarg") + super(HyperlinkedRelatedField, self).__init__(*args, **kwargs) + + def to_native(self, obj): + view_name = self.view_name + request = self.context.get('request', None) + kwargs = {self.pk_url_kwarg: obj.pk} + try: + return reverse(view_name, kwargs=kwargs, request=request) + except: + pass + + slug = getattr(obj, self.slug_field, None) + + if not slug: + raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) + + kwargs = {self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request) + except: + pass + + kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug} + try: + return reverse(self.view_name, kwargs=kwargs, request=request) + except: + pass + + raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + + def from_native(self, value): + # Convert URL -> model instance pk + try: + match = resolve(value) + except: + raise ValidationError('Invalid hyperlink - No URL match') + + if match.url_name != self.view_name: + raise ValidationError('Invalid hyperlink - Incorrect URL match') + + pk = match.kwargs.get(self.pk_url_kwarg, None) + slug = match.kwargs.get(self.slug_url_kwarg, None) + + # Try explicit primary key. + if pk is not None: + return pk + # Next, try looking up by slug. + elif slug is not None: + slug_field = self.get_slug_field() + queryset = self.queryset.filter(**{slug_field: slug}) + # If none of those are defined, it's an error. + else: + raise ValidationError('Invalid hyperlink') + + try: + obj = queryset.get() + except ObjectDoesNotExist: + raise ValidationError('Invalid hyperlink - object does not exist.') + return obj.pk + + +class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField): + pass + + class HyperlinkedIdentityField(Field): """ A field that represents the model's identity using a hyperlink. """ - def __init__(self, *args, **kwargs): - pass - def field_to_native(self, obj, field_name): request = self.context.get('request', None) view_name = self.parent.opts.view_name |
