diff options
| author | Tom Christie | 2013-07-01 13:59:05 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-07-01 13:59:05 +0100 | 
| commit | 0a722de171b0e80ac26d8c77b8051a4170bdb4c6 (patch) | |
| tree | 2ab02448965b7cc288fced2d3a1185d70050fac9 /rest_framework/test.py | |
| parent | d31d7c18676b6292e8dc688b61913d572eccde91 (diff) | |
| download | django-rest-framework-0a722de171b0e80ac26d8c77b8051a4170bdb4c6.tar.bz2 | |
Complete testing docs
Diffstat (limited to 'rest_framework/test.py')
| -rw-r--r-- | rest_framework/test.py | 70 | 
1 files changed, 43 insertions, 27 deletions
| diff --git a/rest_framework/test.py b/rest_framework/test.py index 2f658a56..29d017ee 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,22 +1,31 @@  # -- coding: utf-8 -- -# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order  # to make it harder for the user to import the wrong thing without realizing.  from __future__ import unicode_literals  from django.conf import settings  from django.test.client import Client as DjangoClient  from django.test.client import ClientHandler +from rest_framework.settings import api_settings  from rest_framework.compat import RequestFactory as DjangoRequestFactory  from rest_framework.compat import force_bytes_or_smart_bytes, six -from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +def force_authenticate(request, user=None, token=None): +    request._force_auth_user = user +    request._force_auth_token = token  class APIRequestFactory(DjangoRequestFactory): -    renderer_classes = { -        'json': JSONRenderer, -        'multipart': MultiPartRenderer -    } -    default_format = 'multipart' +    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES +    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + +    def __init__(self, enforce_csrf_checks=False, **defaults): +        self.enforce_csrf_checks = enforce_csrf_checks +        self.renderer_classes = {} +        for cls in self.renderer_classes_list: +            self.renderer_classes[cls.format] = cls +        super(APIRequestFactory, self).__init__(**defaults)      def _encode_data(self, data, format=None, content_type=None):          """ @@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory):              ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)          else: -            # Use format and render the data into a bytestring              format = format or self.default_format + +            assert format in self.renderer_classes, ("Invalid format '{0}'. " +                "Available formats are {1}.  Set TEST_REQUEST_RENDERER_CLASSES " +                "to enable extra request formats.".format( +                    format, +                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) +                ) +            ) + +            # Use format and render the data into a bytestring              renderer = self.renderer_classes[format]()              ret = renderer.render(data)              # Determine the content-type header from the renderer -            if ';' in renderer.media_type: -                content_type = renderer.media_type -            else: -                content_type = "{0}; charset={1}".format( -                    renderer.media_type, renderer.charset -                ) +            content_type = "{0}; charset={1}".format( +                renderer.media_type, renderer.charset +            )              # Coerce text to bytes if required.              if isinstance(ret, six.text_type): @@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory):          data, content_type = self._encode_data(data, format, content_type)          return self.generic('OPTIONS', path, data, content_type, **extra) +    def request(self, **kwargs): +        request = super(APIRequestFactory, self).request(**kwargs) +        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks +        return request +  class ForceAuthClientHandler(ClientHandler):      """ @@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler):      """      def __init__(self, *args, **kwargs): -        self._force_auth_user = None -        self._force_auth_token = None +        self._force_user = None +        self._force_token = None          super(ForceAuthClientHandler, self).__init__(*args, **kwargs)      def get_response(self, request):          # This is the simplest place we can hook into to patch the          # request object. -        request._force_auth_user = self._force_auth_user -        request._force_auth_token = self._force_auth_token +        force_authenticate(request, self._force_user, self._force_token)          return super(ForceAuthClientHandler, self).get_response(request)  class APIClient(APIRequestFactory, DjangoClient):      def __init__(self, enforce_csrf_checks=False, **defaults): -        # Note that our super call skips Client.__init__ -        # since we don't need to instantiate a regular ClientHandler -        super(DjangoClient, self).__init__(**defaults) +        super(APIClient, self).__init__(**defaults)          self.handler = ForceAuthClientHandler(enforce_csrf_checks) -        self.exc_info = None          self._credentials = {}      def credentials(self, **kwargs): @@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient):          Forcibly authenticates outgoing requests with the given          user and/or token.          """ -        self.handler._force_auth_user = user -        self.handler._force_auth_token = token +        self.handler._force_user = user +        self.handler._force_token = token -    def request(self, **request): +    def request(self, **kwargs):          # Ensure that any credentials set get added to every request. -        request.update(self._credentials) -        return super(APIClient, self).request(**request) +        kwargs.update(self._credentials) +        return super(APIClient, self).request(**kwargs) | 
