diff options
Diffstat (limited to 'rest_framework/test.py')
| -rw-r--r-- | rest_framework/test.py | 139 |
1 files changed, 139 insertions, 0 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py new file mode 100644 index 00000000..29d017ee --- /dev/null +++ b/rest_framework/test.py @@ -0,0 +1,139 @@ +# -- coding: utf-8 -- + +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order +# to make it harder for the user to import the wrong thing without realizing. +from __future__ import unicode_literals +from django.conf import settings +from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler +from rest_framework.settings import api_settings +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six + + +def force_authenticate(request, user=None, token=None): + request._force_auth_user = user + request._force_auth_token = token + + +class APIRequestFactory(DjangoRequestFactory): + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super(APIRequestFactory, self).__init__(**defaults) + + def _encode_data(self, data, format=None, content_type=None): + """ + Encode the data returning a two tuple of (bytes, content_type) + """ + + if not data: + return ('', None) + + assert format is None or content_type is None, ( + 'You may not set both `format` and `content_type`.' + ) + + if content_type: + # Content type specified explicitly, treat data as a raw bytestring + ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + + else: + format = format or self.default_format + + assert format in self.renderer_classes, ("Invalid format '{0}'. " + "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " + "to enable extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) + ) + ) + + # Use format and render the data into a bytestring + renderer = self.renderer_classes[format]() + ret = renderer.render(data) + + # Determine the content-type header from the renderer + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) + + # Coerce text to bytes if required. + if isinstance(ret, six.text_type): + ret = bytes(ret.encode(renderer.charset)) + + return ret, content_type + + def post(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('POST', path, data, content_type, **extra) + + def put(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PUT', path, data, content_type, **extra) + + def patch(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PATCH', path, data, content_type, **extra) + + def delete(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('DELETE', path, data, content_type, **extra) + + def options(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('OPTIONS', path, data, content_type, **extra) + + def request(self, **kwargs): + request = super(APIRequestFactory, self).request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + + +class ForceAuthClientHandler(ClientHandler): + """ + A patched version of ClientHandler that can enforce authentication + on the outgoing requests. + """ + + def __init__(self, *args, **kwargs): + self._force_user = None + self._force_token = None + super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + + def get_response(self, request): + # This is the simplest place we can hook into to patch the + # request object. + force_authenticate(request, self._force_user, self._force_token) + return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, enforce_csrf_checks=False, **defaults): + super(APIClient, self).__init__(**defaults) + self.handler = ForceAuthClientHandler(enforce_csrf_checks) + self._credentials = {} + + def credentials(self, **kwargs): + """ + Sets headers that will be used on every outgoing request. + """ + self._credentials = kwargs + + def force_authenticate(self, user=None, token=None): + """ + Forcibly authenticates outgoing requests with the given + user and/or token. + """ + self.handler._force_user = user + self.handler._force_token = token + + def request(self, **kwargs): + # Ensure that any credentials set get added to every request. + kwargs.update(self._credentials) + return super(APIClient, self).request(**kwargs) |
