aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework
diff options
context:
space:
mode:
authorTom Christie2013-07-01 13:59:05 +0100
committerTom Christie2013-07-01 13:59:05 +0100
commit0a722de171b0e80ac26d8c77b8051a4170bdb4c6 (patch)
tree2ab02448965b7cc288fced2d3a1185d70050fac9 /rest_framework
parentd31d7c18676b6292e8dc688b61913d572eccde91 (diff)
downloaddjango-rest-framework-0a722de171b0e80ac26d8c77b8051a4170bdb4c6.tar.bz2
Complete testing docs
Diffstat (limited to 'rest_framework')
-rw-r--r--rest_framework/response.py2
-rw-r--r--rest_framework/settings.py8
-rw-r--r--rest_framework/test.py70
-rw-r--r--rest_framework/tests/test_testing.py55
4 files changed, 103 insertions, 32 deletions
diff --git a/rest_framework/response.py b/rest_framework/response.py
index c4b2aaa6..5877c8a3 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -50,7 +50,7 @@ class Response(SimpleTemplateResponse):
charset = renderer.charset
content_type = self.content_type
- if content_type is None and charset is not None and ';' not in media_type:
+ if content_type is None and charset is not None:
content_type = "{0}; charset={1}".format(media_type, charset)
elif content_type is None:
content_type = media_type
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index beb511ac..8fd177d5 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -73,6 +73,13 @@ DEFAULTS = {
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
+ # Testing
+ 'TEST_REQUEST_RENDERER_CLASSES': (
+ 'rest_framework.renderers.MultiPartRenderer',
+ 'rest_framework.renderers.JSONRenderer'
+ ),
+ 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
+
# Browser enhancements
'FORM_METHOD_OVERRIDE': '_method',
'FORM_CONTENT_OVERRIDE': '_content',
@@ -115,6 +122,7 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
'FILTER_BACKEND',
+ 'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)
diff --git a/rest_framework/test.py b/rest_framework/test.py
index 2f658a56..29d017ee 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -1,22 +1,31 @@
# -- coding: utf-8 --
-# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order
+# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
+from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
-from rest_framework.renderers import JSONRenderer, MultiPartRenderer
+
+
+def force_authenticate(request, user=None, token=None):
+ request._force_auth_user = user
+ request._force_auth_token = token
class APIRequestFactory(DjangoRequestFactory):
- renderer_classes = {
- 'json': JSONRenderer,
- 'multipart': MultiPartRenderer
- }
- default_format = 'multipart'
+ renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
+ default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
+
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ self.renderer_classes = {}
+ for cls in self.renderer_classes_list:
+ self.renderer_classes[cls.format] = cls
+ super(APIRequestFactory, self).__init__(**defaults)
def _encode_data(self, data, format=None, content_type=None):
"""
@@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory):
ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
else:
- # Use format and render the data into a bytestring
format = format or self.default_format
+
+ assert format in self.renderer_classes, ("Invalid format '{0}'. "
+ "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
+ "to enable extra request formats.".format(
+ format,
+ ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
+ )
+ )
+
+ # Use format and render the data into a bytestring
renderer = self.renderer_classes[format]()
ret = renderer.render(data)
# Determine the content-type header from the renderer
- if ';' in renderer.media_type:
- content_type = renderer.media_type
- else:
- content_type = "{0}; charset={1}".format(
- renderer.media_type, renderer.charset
- )
+ content_type = "{0}; charset={1}".format(
+ renderer.media_type, renderer.charset
+ )
# Coerce text to bytes if required.
if isinstance(ret, six.text_type):
@@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra)
+ def request(self, **kwargs):
+ request = super(APIRequestFactory, self).request(**kwargs)
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
+ return request
+
class ForceAuthClientHandler(ClientHandler):
"""
@@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler):
"""
def __init__(self, *args, **kwargs):
- self._force_auth_user = None
- self._force_auth_token = None
+ self._force_user = None
+ self._force_token = None
super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
def get_response(self, request):
# This is the simplest place we can hook into to patch the
# request object.
- request._force_auth_user = self._force_auth_user
- request._force_auth_token = self._force_auth_token
+ force_authenticate(request, self._force_user, self._force_token)
return super(ForceAuthClientHandler, self).get_response(request)
class APIClient(APIRequestFactory, DjangoClient):
def __init__(self, enforce_csrf_checks=False, **defaults):
- # Note that our super call skips Client.__init__
- # since we don't need to instantiate a regular ClientHandler
- super(DjangoClient, self).__init__(**defaults)
+ super(APIClient, self).__init__(**defaults)
self.handler = ForceAuthClientHandler(enforce_csrf_checks)
- self.exc_info = None
self._credentials = {}
def credentials(self, **kwargs):
@@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient):
Forcibly authenticates outgoing requests with the given
user and/or token.
"""
- self.handler._force_auth_user = user
- self.handler._force_auth_token = token
+ self.handler._force_user = user
+ self.handler._force_token = token
- def request(self, **request):
+ def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
- request.update(self._credentials)
- return super(APIClient, self).request(**request)
+ kwargs.update(self._credentials)
+ return super(APIClient, self).request(**kwargs)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 3706f38c..49d45fc2 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -6,11 +6,11 @@ from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
from rest_framework.response import Response
-from rest_framework.test import APIClient
+from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
@api_view(['GET', 'POST'])
-def mirror(request):
+def view(request):
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
'user': request.user.username
@@ -18,11 +18,11 @@ def mirror(request):
urlpatterns = patterns('',
- url(r'^view/$', mirror),
+ url(r'^view/$', view),
)
-class CheckTestClient(TestCase):
+class TestAPITestClient(TestCase):
urls = 'rest_framework.tests.test_testing'
def setUp(self):
@@ -66,3 +66,50 @@ class CheckTestClient(TestCase):
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected)
+
+
+class TestAPIRequestFactory(TestCase):
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory()
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory(enforce_csrf_checks=True)
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+ def test_invalid_format(self):
+ """
+ Attempting to use a format that is not configured will raise an
+ assertion error.
+ """
+ factory = APIRequestFactory()
+ self.assertRaises(AssertionError, factory.post,
+ path='/view/', data={'example': 1}, format='xml'
+ )
+
+ def test_force_authenticate(self):
+ """
+ Setting `force_authenticate()` forcibly authenticates the request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ factory = APIRequestFactory()
+ request = factory.get('/view')
+ force_authenticate(request, user=user)
+ response = view(request)
+ self.assertEqual(response.data['user'], 'example')