diff options
| author | Tom Christie | 2013-06-29 21:02:58 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-06-29 21:02:58 +0100 | 
| commit | 664f8c63655770cd90bdbd510b315bcd045b380a (patch) | |
| tree | 2145a39de36701bc67cad67f2b303594a76d23e9 /rest_framework/test.py | |
| parent | 35022ca9213939a2f40c82facffa908a818efe0b (diff) | |
| download | django-rest-framework-664f8c63655770cd90bdbd510b315bcd045b380a.tar.bz2 | |
Added APIClient.authenticate()
Diffstat (limited to 'rest_framework/test.py')
| -rw-r--r-- | rest_framework/test.py | 39 | 
1 files changed, 35 insertions, 4 deletions
| diff --git a/rest_framework/test.py b/rest_framework/test.py index 8115fa0d..08de2297 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -5,6 +5,7 @@  from __future__ import unicode_literals  from django.conf import settings  from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler  from rest_framework.compat import RequestFactory as DjangoRequestFactory  from rest_framework.compat import force_bytes_or_smart_bytes, six  from rest_framework.renderers import JSONRenderer, MultiPartRenderer @@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer  class APIRequestFactory(DjangoRequestFactory):      renderer_classes = {          'json': JSONRenderer, -        'form': MultiPartRenderer +        'multipart': MultiPartRenderer      } -    default_format = 'form' +    default_format = 'multipart'      def _encode_data(self, data, format=None, content_type=None):          """ @@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):          return self.generic('OPTIONS', path, data, content_type, **extra) -class APIClient(APIRequestFactory, DjangoClient): +class ForceAuthClientHandler(ClientHandler): +    """ +    A patched version of ClientHandler that can enforce authentication +    on the outgoing requests. +    """ +      def __init__(self, *args, **kwargs): +        self._force_auth_user = None +        self._force_auth_token = None +        super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + +    def force_authenticate(self, user=None, token=None): +        self._force_auth_user = user +        self._force_auth_token = token + +    def get_response(self, request): +        # This is the simplest place we can hook into to patch the +        # request object. +        request._force_auth_user = self._force_auth_user +        request._force_auth_token = self._force_auth_token +        return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): +    def __init__(self, enforce_csrf_checks=False, **defaults): +        # Note that our super call skips Client.__init__ +        # since we don't need to instantiate a regular ClientHandler +        super(DjangoClient, self).__init__(**defaults) +        self.handler = ForceAuthClientHandler(enforce_csrf_checks) +        self.exc_info = None          self._credentials = {} -        super(APIClient, self).__init__(*args, **kwargs)      def credentials(self, **kwargs):          self._credentials = kwargs +    def authenticate(self, user=None, token=None): +        self.handler.force_authenticate(user, token) +      def get(self, path, data={}, follow=False, **extra):          extra.update(self._credentials)          response = super(APIClient, self).get(path, data=data, **extra) | 
