From 0a722de171b0e80ac26d8c77b8051a4170bdb4c6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Jul 2013 13:59:05 +0100 Subject: Complete testing docs --- rest_framework/test.py | 70 +++++++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 27 deletions(-) (limited to 'rest_framework/test.py') diff --git a/rest_framework/test.py b/rest_framework/test.py index 2f658a56..29d017ee 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,22 +1,31 @@ # -- coding: utf-8 -- -# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler +from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six -from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +def force_authenticate(request, user=None, token=None): + request._force_auth_user = user + request._force_auth_token = token class APIRequestFactory(DjangoRequestFactory): - renderer_classes = { - 'json': JSONRenderer, - 'multipart': MultiPartRenderer - } - default_format = 'multipart' + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super(APIRequestFactory, self).__init__(**defaults) def _encode_data(self, data, format=None, content_type=None): """ @@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory): ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) else: - # Use format and render the data into a bytestring format = format or self.default_format + + assert format in self.renderer_classes, ("Invalid format '{0}'. " + "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " + "to enable extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) + ) + ) + + # Use format and render the data into a bytestring renderer = self.renderer_classes[format]() ret = renderer.render(data) # Determine the content-type header from the renderer - if ';' in renderer.media_type: - content_type = renderer.media_type - else: - content_type = "{0}; charset={1}".format( - renderer.media_type, renderer.charset - ) + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) # Coerce text to bytes if required. if isinstance(ret, six.text_type): @@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory): data, content_type = self._encode_data(data, format, content_type) return self.generic('OPTIONS', path, data, content_type, **extra) + def request(self, **kwargs): + request = super(APIRequestFactory, self).request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + class ForceAuthClientHandler(ClientHandler): """ @@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler): """ def __init__(self, *args, **kwargs): - self._force_auth_user = None - self._force_auth_token = None + self._force_user = None + self._force_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. - request._force_auth_user = self._force_auth_user - request._force_auth_token = self._force_auth_token + force_authenticate(request, self._force_user, self._force_token) return super(ForceAuthClientHandler, self).get_response(request) class APIClient(APIRequestFactory, DjangoClient): def __init__(self, enforce_csrf_checks=False, **defaults): - # Note that our super call skips Client.__init__ - # since we don't need to instantiate a regular ClientHandler - super(DjangoClient, self).__init__(**defaults) + super(APIClient, self).__init__(**defaults) self.handler = ForceAuthClientHandler(enforce_csrf_checks) - self.exc_info = None self._credentials = {} def credentials(self, **kwargs): @@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient): Forcibly authenticates outgoing requests with the given user and/or token. """ - self.handler._force_auth_user = user - self.handler._force_auth_token = token + self.handler._force_user = user + self.handler._force_token = token - def request(self, **request): + def request(self, **kwargs): # Ensure that any credentials set get added to every request. - request.update(self._credentials) - return super(APIClient, self).request(**request) + kwargs.update(self._credentials) + return super(APIClient, self).request(**kwargs) -- cgit v1.2.3