aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/fields.py108
-rw-r--r--rest_framework/renderers.py2
-rw-r--r--rest_framework/serializers.py26
-rw-r--r--rest_framework/tests/genericrelations.py33
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py74
-rw-r--r--rest_framework/tests/models.py24
6 files changed, 249 insertions, 18 deletions
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 32f2d122..b9ac3776 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -4,9 +4,9 @@ import inspect
import warnings
from django.core import validators
-from django.core.exceptions import ValidationError
+from django.core.exceptions import ObjectDoesNotExist, ValidationError
+from django.core.urlresolvers import resolve
from django.conf import settings
-from django.db import DEFAULT_DB_ALIAS
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse
@@ -27,6 +27,7 @@ def is_simple_callable(obj):
class Field(object):
creation_counter = 0
empty = ''
+ type_name = None
def __init__(self, source=None):
self.parent = None
@@ -82,6 +83,10 @@ class Field(object):
if is_protected_type(value):
return value
+
+ all_callable = getattr(value, 'all', None)
+ if is_simple_callable(all_callable):
+ return [self.to_native(item) for item in value.all()]
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)):
return [self.to_native(item) for item in value]
return smart_unicode(value)
@@ -90,7 +95,7 @@ class Field(object):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
- if getattr(self, 'type_name', None):
+ if self.type_name:
return {'type': self.type_name}
return {}
@@ -196,7 +201,7 @@ class ModelField(WritableField):
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):
return value
- return self.model_field.value_to_string(self.obj)
+ return self.model_field.value_to_string(obj)
def attributes(self):
return {
@@ -223,9 +228,9 @@ class RelatedField(WritableField):
into[(self.source or field_name) + '_id'] = self.from_native(value)
-class ManyRelatedField(RelatedField):
+class ManyRelatedMixin(object):
"""
- Base class for related model managers.
+ Mixin to convert a related field to a many related field.
"""
def field_to_native(self, obj, field_name):
value = getattr(obj, self.source or field_name)
@@ -233,8 +238,10 @@ class ManyRelatedField(RelatedField):
def field_from_native(self, data, field_name, into):
try:
+ # Form data
value = data.getlist(self.source or field_name)
except:
+ # Non-form data
value = data.get(self.source or field_name)
else:
if value == ['']:
@@ -242,6 +249,15 @@ class ManyRelatedField(RelatedField):
into[field_name] = [self.from_native(item) for item in value]
+class ManyRelatedField(ManyRelatedMixin, RelatedField):
+ """
+ Base class for related model managers.
+ """
+ pass
+
+
+### PrimaryKey relationships
+
class PrimaryKeyRelatedField(RelatedField):
"""
Serializes a related field or related object to a pk value.
@@ -281,13 +297,87 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
return [self.to_native(item.pk) for item in queryset.all()]
+### Hyperlinked relationships
+
+class HyperlinkedRelatedField(RelatedField):
+ pk_url_kwarg = 'pk'
+ slug_url_kwarg = 'slug'
+ slug_field = 'slug'
+
+ def __init__(self, *args, **kwargs):
+ try:
+ self.view_name = kwargs.pop('view_name')
+ except:
+ raise ValueError("Hyperlinked field requires 'view_name' kwarg")
+ super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
+
+ def to_native(self, obj):
+ view_name = self.view_name
+ request = self.context.get('request', None)
+ kwargs = {self.pk_url_kwarg: obj.pk}
+ try:
+ return reverse(view_name, kwargs=kwargs, request=request)
+ except:
+ pass
+
+ slug = getattr(obj, self.slug_field, None)
+
+ if not slug:
+ raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
+
+ kwargs = {self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request)
+ except:
+ pass
+
+ kwargs = {self.pk_url_kwarg: obj.pk, self.slug_url_kwarg: slug}
+ try:
+ return reverse(self.view_name, kwargs=kwargs, request=request)
+ except:
+ pass
+
+ raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
+
+ def from_native(self, value):
+ # Convert URL -> model instance pk
+ try:
+ match = resolve(value)
+ except:
+ raise ValidationError('Invalid hyperlink - No URL match')
+
+ if match.url_name != self.view_name:
+ raise ValidationError('Invalid hyperlink - Incorrect URL match')
+
+ pk = match.kwargs.get(self.pk_url_kwarg, None)
+ slug = match.kwargs.get(self.slug_url_kwarg, None)
+
+ # Try explicit primary key.
+ if pk is not None:
+ return pk
+ # Next, try looking up by slug.
+ elif slug is not None:
+ slug_field = self.get_slug_field()
+ queryset = self.queryset.filter(**{slug_field: slug})
+ # If none of those are defined, it's an error.
+ else:
+ raise ValidationError('Invalid hyperlink')
+
+ try:
+ obj = queryset.get()
+ except ObjectDoesNotExist:
+ raise ValidationError('Invalid hyperlink - object does not exist.')
+ return obj.pk
+
+
+class ManyHyperlinkedRelatedField(ManyRelatedMixin, HyperlinkedRelatedField):
+ pass
+
+
class HyperlinkedIdentityField(Field):
"""
A field that represents the model's identity using a hyperlink.
"""
- def __init__(self, *args, **kwargs):
- pass
-
def field_to_native(self, obj, field_name):
request = self.context.get('request', None)
view_name = self.parent.opts.view_name
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 5bc5d5f8..e33fa30e 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -260,7 +260,7 @@ class DocumentingHTMLRenderer(BaseRenderer):
serializer = view.get_serializer(instance=obj)
for k, v in serializer.get_fields(True).items():
print k, v
- if v.readonly:
+ if getattr(v, 'readonly', True):
continue
kwargs = {}
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index ae0b3cdf..ba8bf8ad 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -353,7 +353,9 @@ class ModelSerializer(Serializer):
"""
Creates a default instance of a flat relational field.
"""
- queryset = model_field.rel.to._default_manager # .using(db).complex_filter(self.rel.limit_choices_to)
+ # TODO: filter queryset using:
+ # .using(db).complex_filter(self.rel.limit_choices_to)
+ queryset = model_field.rel.to._default_manager
if isinstance(model_field, models.fields.related.ManyToManyField):
return ManyPrimaryKeyRelatedField(queryset=queryset)
return PrimaryKeyRelatedField(queryset=queryset)
@@ -420,13 +422,13 @@ class HyperlinkedModelSerializer(ModelSerializer):
def __init__(self, *args, **kwargs):
super(HyperlinkedModelSerializer, self).__init__(*args, **kwargs)
if self.opts.view_name is None:
- self.opts.view_name = self._get_default_view_name()
+ self.opts.view_name = self._get_default_view_name(self.opts.model)
- def _get_default_view_name(self):
+ def _get_default_view_name(self, model):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
"""
- model_meta = self.opts.model._meta
+ model_meta = model._meta
format_kwargs = {
'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower()
@@ -435,3 +437,19 @@ class HyperlinkedModelSerializer(ModelSerializer):
def get_pk_field(self, model_field):
return None
+
+ def get_related_field(self, model_field):
+ """
+ Creates a default instance of a flat relational field.
+ """
+ # TODO: filter queryset using:
+ # .using(db).complex_filter(self.rel.limit_choices_to)
+ rel = model_field.rel.to
+ queryset = rel._default_manager
+ kwargs = {
+ 'queryset': queryset,
+ 'view_name': self._get_default_view_name(rel)
+ }
+ if isinstance(model_field, models.fields.related.ManyToManyField):
+ return ManyHyperlinkedRelatedField(**kwargs)
+ return HyperlinkedRelatedField(**kwargs)
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
new file mode 100644
index 00000000..d88a6c06
--- /dev/null
+++ b/rest_framework/tests/genericrelations.py
@@ -0,0 +1,33 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import *
+
+
+class TestGenericRelations(TestCase):
+ def setUp(self):
+ bookmark = Bookmark(url='https://www.djangoproject.com/')
+ bookmark.save()
+ django = Tag(tag_name='django')
+ django.save()
+ python = Tag(tag_name='python')
+ python.save()
+ t1 = TaggedItem(content_object=bookmark, tag=django)
+ t1.save()
+ t2 = TaggedItem(content_object=bookmark, tag=python)
+ t2.save()
+ self.bookmark = bookmark
+
+ def test_reverse_generic_relation(self):
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = serializers.Field(source='tags')
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ serializer = BookmarkSerializer(instance=self.bookmark)
+ expected = {
+ 'tags': [u'django', u'python'],
+ 'url': u'https://www.djangoproject.com/'
+ }
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
index 4f9393aa..5532a8ee 100644
--- a/rest_framework/tests/hyperlinkedserializers.py
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -2,7 +2,7 @@ from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status, serializers
-from rest_framework.tests.models import BasicModel
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel
factory = RequestFactory()
@@ -17,13 +17,31 @@ class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer
+class AnchorDetail(generics.RetrieveAPIView):
+ model = Anchor
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyList(generics.ListAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyDetail(generics.RetrieveAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
+ url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
+ url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
+ url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
)
-class TestHyperlinkedView(TestCase):
+class TestBasicHyperlinkedView(TestCase):
urls = 'rest_framework.tests.hyperlinkedserializers'
def setUp(self):
@@ -45,7 +63,55 @@ class TestHyperlinkedView(TestCase):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
- request = factory.get('/')
+ request = factory.get('/basic/')
+ response = self.list_view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/1')
+ response = self.detail_view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data[0])
+
+
+class TestManyToManyHyperlinkedView(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ anchors = []
+ for item in items:
+ anchor = Anchor(text=item)
+ anchor.save()
+ anchors.append(anchor)
+
+ manytomany = ManyToManyModel()
+ manytomany.save()
+ manytomany.rel.add(*anchors)
+
+ self.data = [{
+ 'url': 'http://testserver/manytomany/1/',
+ 'rel': [
+ 'http://testserver/anchor/1/',
+ 'http://testserver/anchor/2/',
+ 'http://testserver/anchor/3/',
+ ]
+ }]
+ self.list_view = ManyToManyList.as_view()
+ self.detail_view = ManyToManyDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/')
response = self.list_view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data)
@@ -54,7 +120,7 @@ class TestHyperlinkedView(TestCase):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
- request = factory.get('/1')
+ request = factory.get('/manytomany/1/')
response = self.detail_view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data[0])
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 969c8297..7c7f485b 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -1,4 +1,7 @@
from django.db import models
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
+
# from django.contrib.auth.models import Group
@@ -59,3 +62,24 @@ class CallableDefaultValueModel(RESTFrameworkModel):
class ManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
+
+# Models to test generic relations
+
+
+class Tag(RESTFrameworkModel):
+ tag_name = models.SlugField()
+
+
+class TaggedItem(RESTFrameworkModel):
+ tag = models.ForeignKey(Tag, related_name='items')
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ content_object = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag.tag_name
+
+
+class Bookmark(RESTFrameworkModel):
+ url = models.URLField()
+ tags = GenericRelation(TaggedItem)