diff options
| author | Tom Christie | 2013-07-01 13:59:05 +0100 |
|---|---|---|
| committer | Tom Christie | 2013-07-01 13:59:05 +0100 |
| commit | 0a722de171b0e80ac26d8c77b8051a4170bdb4c6 (patch) | |
| tree | 2ab02448965b7cc288fced2d3a1185d70050fac9 /rest_framework | |
| parent | d31d7c18676b6292e8dc688b61913d572eccde91 (diff) | |
| download | django-rest-framework-0a722de171b0e80ac26d8c77b8051a4170bdb4c6.tar.bz2 | |
Complete testing docs
Diffstat (limited to 'rest_framework')
| -rw-r--r-- | rest_framework/response.py | 2 | ||||
| -rw-r--r-- | rest_framework/settings.py | 8 | ||||
| -rw-r--r-- | rest_framework/test.py | 70 | ||||
| -rw-r--r-- | rest_framework/tests/test_testing.py | 55 |
4 files changed, 103 insertions, 32 deletions
diff --git a/rest_framework/response.py b/rest_framework/response.py index c4b2aaa6..5877c8a3 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -50,7 +50,7 @@ class Response(SimpleTemplateResponse): charset = renderer.charset content_type = self.content_type - if content_type is None and charset is not None and ';' not in media_type: + if content_type is None and charset is not None: content_type = "{0}; charset={1}".format(media_type, charset) elif content_type is None: content_type = media_type diff --git a/rest_framework/settings.py b/rest_framework/settings.py index beb511ac..8fd177d5 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -73,6 +73,13 @@ DEFAULTS = { 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # Testing + 'TEST_REQUEST_RENDERER_CLASSES': ( + 'rest_framework.renderers.MultiPartRenderer', + 'rest_framework.renderers.JSONRenderer' + ), + 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', + # Browser enhancements 'FORM_METHOD_OVERRIDE': '_method', 'FORM_CONTENT_OVERRIDE': '_content', @@ -115,6 +122,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', + 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/rest_framework/test.py b/rest_framework/test.py index 2f658a56..29d017ee 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,22 +1,31 @@ # -- coding: utf-8 -- -# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler +from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six -from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +def force_authenticate(request, user=None, token=None): + request._force_auth_user = user + request._force_auth_token = token class APIRequestFactory(DjangoRequestFactory): - renderer_classes = { - 'json': JSONRenderer, - 'multipart': MultiPartRenderer - } - default_format = 'multipart' + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super(APIRequestFactory, self).__init__(**defaults) def _encode_data(self, data, format=None, content_type=None): """ @@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory): ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) else: - # Use format and render the data into a bytestring format = format or self.default_format + + assert format in self.renderer_classes, ("Invalid format '{0}'. " + "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " + "to enable extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) + ) + ) + + # Use format and render the data into a bytestring renderer = self.renderer_classes[format]() ret = renderer.render(data) # Determine the content-type header from the renderer - if ';' in renderer.media_type: - content_type = renderer.media_type - else: - content_type = "{0}; charset={1}".format( - renderer.media_type, renderer.charset - ) + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) # Coerce text to bytes if required. if isinstance(ret, six.text_type): @@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory): data, content_type = self._encode_data(data, format, content_type) return self.generic('OPTIONS', path, data, content_type, **extra) + def request(self, **kwargs): + request = super(APIRequestFactory, self).request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + class ForceAuthClientHandler(ClientHandler): """ @@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler): """ def __init__(self, *args, **kwargs): - self._force_auth_user = None - self._force_auth_token = None + self._force_user = None + self._force_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. - request._force_auth_user = self._force_auth_user - request._force_auth_token = self._force_auth_token + force_authenticate(request, self._force_user, self._force_token) return super(ForceAuthClientHandler, self).get_response(request) class APIClient(APIRequestFactory, DjangoClient): def __init__(self, enforce_csrf_checks=False, **defaults): - # Note that our super call skips Client.__init__ - # since we don't need to instantiate a regular ClientHandler - super(DjangoClient, self).__init__(**defaults) + super(APIClient, self).__init__(**defaults) self.handler = ForceAuthClientHandler(enforce_csrf_checks) - self.exc_info = None self._credentials = {} def credentials(self, **kwargs): @@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient): Forcibly authenticates outgoing requests with the given user and/or token. """ - self.handler._force_auth_user = user - self.handler._force_auth_token = token + self.handler._force_user = user + self.handler._force_token = token - def request(self, **request): + def request(self, **kwargs): # Ensure that any credentials set get added to every request. - request.update(self._credentials) - return super(APIClient, self).request(**request) + kwargs.update(self._credentials) + return super(APIClient, self).request(**kwargs) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 3706f38c..49d45fc2 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -6,11 +6,11 @@ from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.decorators import api_view from rest_framework.response import Response -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate @api_view(['GET', 'POST']) -def mirror(request): +def view(request): return Response({ 'auth': request.META.get('HTTP_AUTHORIZATION', b''), 'user': request.user.username @@ -18,11 +18,11 @@ def mirror(request): urlpatterns = patterns('', - url(r'^view/$', mirror), + url(r'^view/$', view), ) -class CheckTestClient(TestCase): +class TestAPITestClient(TestCase): urls = 'rest_framework.tests.test_testing' def setUp(self): @@ -66,3 +66,50 @@ class CheckTestClient(TestCase): expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} self.assertEqual(response.status_code, 403) self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory() + request = factory.post('/view/') + request.user = user + response = view(request) + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory(enforce_csrf_checks=True) + request = factory.post('/view/') + request.user = user + response = view(request) + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + def test_invalid_format(self): + """ + Attempting to use a format that is not configured will raise an + assertion error. + """ + factory = APIRequestFactory() + self.assertRaises(AssertionError, factory.post, + path='/view/', data={'example': 1}, format='xml' + ) + + def test_force_authenticate(self): + """ + Setting `force_authenticate()` forcibly authenticates the request. + """ + user = User.objects.create_user('example', 'example@example.com') + factory = APIRequestFactory() + request = factory.get('/view') + force_authenticate(request, user=user) + response = view(request) + self.assertEqual(response.data['user'], 'example') |
