aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/__init__.py2
-rw-r--r--rest_framework/authentication.py12
-rw-r--r--rest_framework/compat.py10
-rw-r--r--rest_framework/generics.py21
-rw-r--r--rest_framework/permissions.py2
-rw-r--r--rest_framework/serializers.py103
-rw-r--r--rest_framework/templates/rest_framework/base.html2
-rw-r--r--rest_framework/templatetags/rest_framework.py13
-rw-r--r--rest_framework/tests/generics.py42
-rw-r--r--rest_framework/tests/relations_hyperlink.py16
-rw-r--r--rest_framework/tests/relations_nested.py24
-rw-r--r--rest_framework/tests/relations_pk.py17
-rw-r--r--rest_framework/tests/serializer.py37
-rw-r--r--rest_framework/views.py8
14 files changed, 229 insertions, 80 deletions
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index c86403d8..856badc6 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '2.2.5'
+__version__ = '2.2.7'
VERSION = __version__ # synonym
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 145d4295..1eebb5b9 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -10,7 +10,7 @@ from django.core.exceptions import ImproperlyConfigured
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
-from rest_framework.compat import oauth2_provider, oauth2_provider_forms
+from rest_framework.compat import oauth2_provider
from rest_framework.authtoken.models import Token
@@ -230,7 +230,7 @@ class OAuthAuthentication(BaseAuthentication):
try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
- except oauth_provider_store.InvalidConsumerError as err:
+ except oauth_provider.store.InvalidConsumerError as err:
raise exceptions.AuthenticationFailed(err)
if consumer.status != oauth_provider.consts.ACCEPTED:
@@ -240,7 +240,7 @@ class OAuthAuthentication(BaseAuthentication):
try:
token_param = oauth_request.get_parameter('oauth_token')
token = oauth_provider_store.get_access_token(request, oauth_request, consumer, token_param)
- except oauth_provider_store.InvalidTokenError:
+ except oauth_provider.store.InvalidTokenError:
msg = 'Invalid access token: %s' % oauth_request.get_parameter('oauth_token')
raise exceptions.AuthenticationFailed(msg)
@@ -325,11 +325,13 @@ class OAuth2Authentication(BaseAuthentication):
except oauth2_provider.models.AccessToken.DoesNotExist:
raise exceptions.AuthenticationFailed('Invalid token')
- if not token.user.is_active:
+ user = token.user
+
+ if not user.is_active:
msg = 'User inactive or deleted: %s' % user.username
raise exceptions.AuthenticationFailed(msg)
- return (token.user, token)
+ return (user, token)
def authenticate_header(self, request):
"""
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 8bfebe68..3828555b 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -404,19 +404,23 @@ except ImportError:
try:
from django.utils.html import smart_urlquote
except ImportError:
+ import re
+ from django.utils.encoding import smart_str
try:
from urllib.parse import quote, urlsplit, urlunsplit
except ImportError: # Python 2
from urllib import quote
from urlparse import urlsplit, urlunsplit
+ unquoted_percents_re = re.compile(r'%(?![0-9A-Fa-f]{2})')
+
def smart_urlquote(url):
"Quotes a URL if it isn't already quoted."
# Handle IDN before quoting.
scheme, netloc, path, query, fragment = urlsplit(url)
try:
- netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
- except UnicodeError: # invalid domain part
+ netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
+ except UnicodeError: # invalid domain part
pass
else:
url = urlunsplit((scheme, netloc, path, query, fragment))
@@ -425,7 +429,7 @@ except ImportError:
# contains a % not followed by two hexadecimal digits. See #9655.
if '%' not in url or unquoted_percents_re.search(url):
# See http://bugs.python.org/issue2637
- url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~')
+ url = quote(smart_str(url), safe=b'!*\'();:@&=+$,/?#[]~')
return force_text(url)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index ea62123d..ae03060b 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -21,19 +21,21 @@ class GenericAPIView(views.APIView):
queryset = None
serializer_class = None
+ # Shortcut which may be used in place of `queryset`/`serializer_class`
+ model = None
+
filter_backend = api_settings.FILTER_BACKEND
paginate_by = api_settings.PAGINATE_BY
paginate_by_param = api_settings.PAGINATE_BY_PARAM
pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS
+ model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
page_kwarg = 'page'
- lookup_kwarg = 'pk'
+ lookup_field = 'pk'
allow_empty = True
######################################
- # These are all pending deprecation...
+ # These are pending deprecation...
- model = None
- model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
pk_url_kwarg = 'pk'
slug_url_kwarg = 'slug'
slug_field = 'slug'
@@ -116,6 +118,11 @@ class GenericAPIView(views.APIView):
def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
+
+ You are unlikely to want to override this method, although you may need
+ to call it either from a list view, or from a custom `get_object`
+ method if you want to apply the configured filtering backend to the
+ default queryset.
"""
if not self.filter_backend:
return queryset
@@ -164,7 +171,7 @@ class GenericAPIView(views.APIView):
# TODO: Deprecation warning
return self.model._default_manager.all()
- raise ImproperlyConfigured("'%s' must define 'queryset'"
+ raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'"
% self.__class__.__name__)
def get_object(self, queryset=None):
@@ -184,10 +191,10 @@ class GenericAPIView(views.APIView):
# Perform the lookup filtering.
pk = self.kwargs.get(self.pk_url_kwarg, None)
slug = self.kwargs.get(self.slug_url_kwarg, None)
- lookup = self.kwargs.get(self.lookup_kwarg, None)
+ lookup = self.kwargs.get(self.lookup_field, None)
if lookup is not None:
- filter_kwargs = {self.lookup_kwarg: lookup}
+ filter_kwargs = {self.lookup_field: lookup}
elif pk is not None:
# TODO: Deprecation warning
filter_kwargs = {'pk': pk}
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index ae895f39..2aa45c71 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -25,7 +25,7 @@ class BasePermission(object):
"""
Return `True` if permission is granted, `False` otherwise.
"""
- if len(inspect.getargspec(self.has_permission)[0]) == 4:
+ if len(inspect.getargspec(self.has_permission).args) == 4:
warnings.warn('The `obj` argument in `has_permission` is due to be deprecated. '
'Use `has_object_permission()` instead for object permissions.',
PendingDeprecationWarning, stacklevel=2)
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 1b2b0821..b4327af1 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -568,36 +568,73 @@ class ModelSerializer(Serializer):
assert cls is not None, \
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__
opts = get_concrete_model(cls)._meta
- pk_field = opts.pk
+ ret = SortedDict()
+ nested = bool(self.opts.depth)
- # If model is a child via multitable inheritance, use parent's pk
+ # Deal with adding the primary key field
+ pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
+ # If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
- fields = [pk_field]
- fields += [field for field in opts.fields if field.serialize]
- fields += [field for field in opts.many_to_many if field.serialize]
+ field = self.get_pk_field(pk_field)
+ if field:
+ ret[pk_field.name] = field
- ret = SortedDict()
- nested = bool(self.opts.depth)
- is_pk = True # First field in the list is the pk
-
- for model_field in fields:
- if is_pk:
- field = self.get_pk_field(model_field)
- is_pk = False
- elif model_field.rel and nested:
- field = self.get_nested_field(model_field)
- elif model_field.rel:
+ # Deal with forward relationships
+ forward_rels = [field for field in opts.fields if field.serialize]
+ forward_rels += [field for field in opts.many_to_many if field.serialize]
+
+ for model_field in forward_rels:
+ if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
- field = self.get_related_field(model_field, to_many=to_many)
+ related_model = model_field.rel.to
+
+ if model_field.rel and nested:
+ if len(inspect.getargspec(self.get_nested_field).args) == 2:
+ # TODO: deprecation warning
+ field = self.get_nested_field(model_field)
+ else:
+ field = self.get_nested_field(model_field, related_model, to_many)
+ elif model_field.rel:
+ if len(inspect.getargspec(self.get_nested_field).args) == 3:
+ # TODO: deprecation warning
+ field = self.get_related_field(model_field, to_many=to_many)
+ else:
+ field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
if field:
ret[model_field.name] = field
+ # Deal with reverse relationships
+ if not self.opts.fields:
+ reverse_rels = []
+ else:
+ # Reverse relationships are only included if they are explicitly
+ # present in the `fields` option on the serializer
+ reverse_rels = opts.get_all_related_objects()
+ reverse_rels += opts.get_all_related_many_to_many_objects()
+
+ for relation in reverse_rels:
+ accessor_name = relation.get_accessor_name()
+ if accessor_name not in self.opts.fields:
+ continue
+ related_model = relation.model
+ to_many = relation.field.rel.multiple
+
+ if nested:
+ field = self.get_nested_field(None, related_model, to_many)
+ else:
+ field = self.get_related_field(None, related_model, to_many)
+
+ if field:
+ ret[accessor_name] = field
+
+ # Add the `read_only` flag to any fields that have bee specified
+ # in the `read_only_fields` option
for field_name in self.opts.read_only_fields:
assert field_name in ret, \
"read_only_fields on '%s' included invalid item '%s'" % \
@@ -612,27 +649,30 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
- def get_nested_field(self, model_field):
+ def get_nested_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a nested relational field.
"""
class NestedModelSerializer(ModelSerializer):
class Meta:
- model = model_field.rel.to
- return NestedModelSerializer()
+ model = related_model
+ return NestedModelSerializer(many=to_many)
- def get_related_field(self, model_field, to_many=False):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
+
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': model_field.rel.to._default_manager,
+ 'queryset': related_model._default_manager,
'many': to_many
}
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
@@ -741,7 +781,7 @@ class ModelSerializer(Serializer):
Override the default method to also include model field validation.
"""
instance = super(ModelSerializer, self).from_native(data, files)
- if instance:
+ if not self._errors:
return self.full_clean(instance)
def save_object(self, obj, **kwargs):
@@ -797,21 +837,24 @@ class HyperlinkedModelSerializer(ModelSerializer):
return self._default_view_name % format_kwargs
def get_pk_field(self, model_field):
- return None
+ if self.opts.fields and model_field.name in self.opts.fields:
+ return self.get_field(model_field)
- def get_related_field(self, model_field, to_many):
+ def get_related_field(self, model_field, related_model, to_many):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- rel = model_field.rel.to
kwargs = {
- 'required': not(model_field.null or model_field.blank),
- 'queryset': rel._default_manager,
- 'view_name': self._get_default_view_name(rel),
+ 'queryset': related_model._default_manager,
+ 'view_name': self._get_default_view_name(related_model),
'many': to_many
}
+
+ if model_field:
+ kwargs['required'] = not(model_field.null or model_field.blank)
+
return HyperlinkedRelatedField(**kwargs)
def get_identity(self, data):
diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html
index 44633f5a..4410f285 100644
--- a/rest_framework/templates/rest_framework/base.html
+++ b/rest_framework/templates/rest_framework/base.html
@@ -115,7 +115,7 @@
</div>
<div class="response-info">
<pre class="prettyprint"><div class="meta nocode"><b>HTTP {{ response.status_code }} {{ response.status_text }}</b>{% autoescape off %}
-{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|urlize_quoted_links }}</span>
+{% for key, val in response.items %}<b>{{ key }}:</b> <span class="lit">{{ val|break_long_headers|urlize_quoted_links }}</span>
{% endfor %}
</div>{{ content|urlize_quoted_links }}</pre>{% endautoescape %}
</div>
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index b6ab2de3..c86b6456 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -181,7 +181,7 @@ TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
('"', '"'), ("'", "'")]
word_split_re = re.compile(r'(\s+)')
-simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE)
+simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
@@ -260,3 +260,14 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
elif autoescape:
words[i] = escape(word)
return ''.join(words)
+
+
+@register.filter
+def break_long_headers(header):
+ """
+ Breaks headers longer than 160 characters (~page length)
+ when possible (are comma separated)
+ """
+ if len(header) > 160 and ',' in header:
+ header = mark_safe('<br> ' + ', <br>'.join(header.split(',')))
+ return header
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
index f564890c..4a13389a 100644
--- a/rest_framework/tests/generics.py
+++ b/rest_framework/tests/generics.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals
from django.db import models
+from django.shortcuts import get_object_or_404
from django.test import TestCase
from rest_framework import generics, serializers, status
from rest_framework.tests.utils import RequestFactory
@@ -302,6 +303,47 @@ class TestInstanceView(TestCase):
self.assertEqual(new_obj.text, 'foobar')
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ model = BasicModel
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
diff --git a/rest_framework/tests/relations_hyperlink.py b/rest_framework/tests/relations_hyperlink.py
index b5702a48..b1eed9a7 100644
--- a/rest_framework/tests/relations_hyperlink.py
+++ b/rest_framework/tests/relations_hyperlink.py
@@ -26,42 +26,44 @@ urlpatterns = patterns('',
)
+# ManyToMany
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='manytomanysource-detail')
-
class Meta:
model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
- sources = serializers.HyperlinkedRelatedField(many=True, view_name='foreignkeysource-detail')
-
class Meta:
model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('url', 'name', 'target')
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
- nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail')
-
class Meta:
model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
diff --git a/rest_framework/tests/relations_nested.py b/rest_framework/tests/relations_nested.py
index a125ba65..f6d006b3 100644
--- a/rest_framework/tests/relations_nested.py
+++ b/rest_framework/tests/relations_nested.py
@@ -6,38 +6,30 @@ from rest_framework.tests.models import ForeignKeyTarget, ForeignKeySource, Null
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
- model = ForeignKeySource
-
-
-class FlatForeignKeySourceSerializer(serializers.ModelSerializer):
- class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+ depth = 1
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = FlatForeignKeySourceSerializer(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+ depth = 1
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
- depth = 1
model = NullableForeignKeySource
-
-
-class NullableOneToOneSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableOneToOneSource
+ fields = ('id', 'name', 'target')
+ depth = 1
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = NullableOneToOneSourceSerializer()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+ depth = 1
class ReverseForeignKeyTests(TestCase):
diff --git a/rest_framework/tests/relations_pk.py b/rest_framework/tests/relations_pk.py
index f08e1808..5ce8b567 100644
--- a/rest_framework/tests/relations_pk.py
+++ b/rest_framework/tests/relations_pk.py
@@ -5,41 +5,44 @@ from rest_framework.tests.models import ManyToManyTarget, ManyToManySource, Fore
from rest_framework.compat import six
+# ManyToMany
class ManyToManyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
- sources = serializers.PrimaryKeyRelatedField(many=True)
-
class Meta:
model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
-# OneToOne
+# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
- nullable_source = serializers.PrimaryKeyRelatedField()
-
class Meta:
model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
# TODO: Add test that .data cannot be accessed prior to .is_valid
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index 05217f35..3a94fad5 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected)
+ def test_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_depth_include_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+ fields = ('id', 'title', 'blogpostcomment_set')
+ depth = 1
+
+ serializer = BlogPostSerializer(instance=post)
+ expected = {
+ 'id': 1, 'title': 'Test blog post',
+ 'blogpostcomment_set': [
+ {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
+ {'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
+ ]
+ }
+ self.assertEqual(serializer.data, expected)
+
def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post")
diff --git a/rest_framework/views.py b/rest_framework/views.py
index d7d3a2e2..b8e948e0 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -3,7 +3,7 @@ Provides an APIView class that is used as the base of all class-based views.
"""
from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied
-from django.http import Http404
+from django.http import Http404, HttpResponse
from django.views.decorators.csrf import csrf_exempt
from rest_framework import status, exceptions
from rest_framework.compat import View
@@ -256,6 +256,12 @@ class APIView(View):
"""
Returns the final response object.
"""
+ # Make the error obvious if a proper response is not returned
+ assert isinstance(response, HttpResponse), (
+ 'Expected a `Response` to be returned from the view, '
+ 'but received a `%s`' % type(response)
+ )
+
if isinstance(response, Response):
if not getattr(request, 'accepted_renderer', None):
neg = self.perform_content_negotiation(request, force=True)