aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/test.py')
-rw-r--r--rest_framework/test.py139
1 files changed, 139 insertions, 0 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py
new file mode 100644
index 00000000..29d017ee
--- /dev/null
+++ b/rest_framework/test.py
@@ -0,0 +1,139 @@
+# -- coding: utf-8 --
+
+# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order
+# to make it harder for the user to import the wrong thing without realizing.
+from __future__ import unicode_literals
+from django.conf import settings
+from django.test.client import Client as DjangoClient
+from django.test.client import ClientHandler
+from rest_framework.settings import api_settings
+from rest_framework.compat import RequestFactory as DjangoRequestFactory
+from rest_framework.compat import force_bytes_or_smart_bytes, six
+
+
+def force_authenticate(request, user=None, token=None):
+ request._force_auth_user = user
+ request._force_auth_token = token
+
+
+class APIRequestFactory(DjangoRequestFactory):
+ renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
+ default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
+
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ self.enforce_csrf_checks = enforce_csrf_checks
+ self.renderer_classes = {}
+ for cls in self.renderer_classes_list:
+ self.renderer_classes[cls.format] = cls
+ super(APIRequestFactory, self).__init__(**defaults)
+
+ def _encode_data(self, data, format=None, content_type=None):
+ """
+ Encode the data returning a two tuple of (bytes, content_type)
+ """
+
+ if not data:
+ return ('', None)
+
+ assert format is None or content_type is None, (
+ 'You may not set both `format` and `content_type`.'
+ )
+
+ if content_type:
+ # Content type specified explicitly, treat data as a raw bytestring
+ ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
+
+ else:
+ format = format or self.default_format
+
+ assert format in self.renderer_classes, ("Invalid format '{0}'. "
+ "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
+ "to enable extra request formats.".format(
+ format,
+ ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
+ )
+ )
+
+ # Use format and render the data into a bytestring
+ renderer = self.renderer_classes[format]()
+ ret = renderer.render(data)
+
+ # Determine the content-type header from the renderer
+ content_type = "{0}; charset={1}".format(
+ renderer.media_type, renderer.charset
+ )
+
+ # Coerce text to bytes if required.
+ if isinstance(ret, six.text_type):
+ ret = bytes(ret.encode(renderer.charset))
+
+ return ret, content_type
+
+ def post(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('POST', path, data, content_type, **extra)
+
+ def put(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('PUT', path, data, content_type, **extra)
+
+ def patch(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('PATCH', path, data, content_type, **extra)
+
+ def delete(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('DELETE', path, data, content_type, **extra)
+
+ def options(self, path, data=None, format=None, content_type=None, **extra):
+ data, content_type = self._encode_data(data, format, content_type)
+ return self.generic('OPTIONS', path, data, content_type, **extra)
+
+ def request(self, **kwargs):
+ request = super(APIRequestFactory, self).request(**kwargs)
+ request._dont_enforce_csrf_checks = not self.enforce_csrf_checks
+ return request
+
+
+class ForceAuthClientHandler(ClientHandler):
+ """
+ A patched version of ClientHandler that can enforce authentication
+ on the outgoing requests.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self._force_user = None
+ self._force_token = None
+ super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
+
+ def get_response(self, request):
+ # This is the simplest place we can hook into to patch the
+ # request object.
+ force_authenticate(request, self._force_user, self._force_token)
+ return super(ForceAuthClientHandler, self).get_response(request)
+
+
+class APIClient(APIRequestFactory, DjangoClient):
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ super(APIClient, self).__init__(**defaults)
+ self.handler = ForceAuthClientHandler(enforce_csrf_checks)
+ self._credentials = {}
+
+ def credentials(self, **kwargs):
+ """
+ Sets headers that will be used on every outgoing request.
+ """
+ self._credentials = kwargs
+
+ def force_authenticate(self, user=None, token=None):
+ """
+ Forcibly authenticates outgoing requests with the given
+ user and/or token.
+ """
+ self.handler._force_user = user
+ self.handler._force_token = token
+
+ def request(self, **kwargs):
+ # Ensure that any credentials set get added to every request.
+ kwargs.update(self._credentials)
+ return super(APIClient, self).request(**kwargs)