aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/test.py
diff options
context:
space:
mode:
authorTom Christie2013-07-01 13:59:05 +0100
committerTom Christie2013-07-01 13:59:05 +0100
commit0a722de171b0e80ac26d8c77b8051a4170bdb4c6 (patch)
tree2ab02448965b7cc288fced2d3a1185d70050fac9 /rest_framework/test.py
parentd31d7c18676b6292e8dc688b61913d572eccde91 (diff)
downloaddjango-rest-framework-0a722de171b0e80ac26d8c77b8051a4170bdb4c6.tar.bz2
Complete testing docs
Diffstat (limited to 'rest_framework/test.py')
-rw-r--r--rest_framework/test.py70
1 files changed, 43 insertions, 27 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py
index 2f658a56..29d017ee 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -1,22 +1,31 @@
# -- coding: utf-8 --
-# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order
+# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
+from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
-from rest_framework.renderers import JSONRenderer, MultiPartRenderer
+
+
+def force_authenticate(request, user=None, token=None):
+ request._force_auth_user = user
+ request._force_auth_token = token
class APIRequestFactory(DjangoRequestFactory):
- renderer_classes = {
- 'json': JSONRenderer,
- 'multipart': MultiPartRenderer
- }
- default_format = 'multipart'
+ renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
+ default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
+
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ self.renderer_classes = {}
+ for cls in self.renderer_classes_list:
+ self.renderer_classes[cls.format] = cls
+ super(APIRequestFactory, self).__init__(**defaults)
def _encode_data(self, data, format=None, content_type=None):
"""
@@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory):
ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
else:
- # Use format and render the data into a bytestring
format = format or self.default_format
+
+ assert format in self.renderer_classes, ("Invalid format '{0}'. "
+ "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
+ "to enable extra request formats.".format(
+ format,
+ ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
+ )
+ )
+
+ # Use format and render the data into a bytestring
renderer = self.renderer_classes[format]()
ret = renderer.render(data)
# Determine the content-type header from the renderer
- if ';' in renderer.media_type:
- content_type = renderer.media_type
- else:
- content_type = "{0}; charset={1}".format(
- renderer.media_type, renderer.charset
- )
+ content_type = "{0}; charset={1}".format(
+ renderer.media_type, renderer.charset
+ )
# Coerce text to bytes if required.
if isinstance(ret, six.text_type):
@@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra)
+ def request(self, **kwargs):
+ request = super(APIRequestFactory, self).request(**kwargs)
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
+ return request
+
class ForceAuthClientHandler(ClientHandler):
"""
@@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler):
"""
def __init__(self, *args, **kwargs):
- self._force_auth_user = None
- self._force_auth_token = None
+ self._force_user = None
+ self._force_token = None
super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
def get_response(self, request):
# This is the simplest place we can hook into to patch the
# request object.
- request._force_auth_user = self._force_auth_user
- request._force_auth_token = self._force_auth_token
+ force_authenticate(request, self._force_user, self._force_token)
return super(ForceAuthClientHandler, self).get_response(request)
class APIClient(APIRequestFactory, DjangoClient):
def __init__(self, enforce_csrf_checks=False, **defaults):
- # Note that our super call skips Client.__init__
- # since we don't need to instantiate a regular ClientHandler
- super(DjangoClient, self).__init__(**defaults)
+ super(APIClient, self).__init__(**defaults)
self.handler = ForceAuthClientHandler(enforce_csrf_checks)
- self.exc_info = None
self._credentials = {}
def credentials(self, **kwargs):
@@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient):
Forcibly authenticates outgoing requests with the given
user and/or token.
"""
- self.handler._force_auth_user = user
- self.handler._force_auth_token = token
+ self.handler._force_user = user
+ self.handler._force_token = token
- def request(self, **request):
+ def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
- request.update(self._credentials)
- return super(APIClient, self).request(**request)
+ kwargs.update(self._credentials)
+ return super(APIClient, self).request(**kwargs)