diff options
| author | Tom Christie | 2013-06-29 21:02:58 +0100 |
|---|---|---|
| committer | Tom Christie | 2013-06-29 21:02:58 +0100 |
| commit | 664f8c63655770cd90bdbd510b315bcd045b380a (patch) | |
| tree | 2145a39de36701bc67cad67f2b303594a76d23e9 /rest_framework/test.py | |
| parent | 35022ca9213939a2f40c82facffa908a818efe0b (diff) | |
| download | django-rest-framework-664f8c63655770cd90bdbd510b315bcd045b380a.tar.bz2 | |
Added APIClient.authenticate()
Diffstat (limited to 'rest_framework/test.py')
| -rw-r--r-- | rest_framework/test.py | 39 |
1 files changed, 35 insertions, 4 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py index 8115fa0d..08de2297 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer @@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, - 'form': MultiPartRenderer + 'multipart': MultiPartRenderer } - default_format = 'form' + default_format = 'multipart' def _encode_data(self, data, format=None, content_type=None): """ @@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory): return self.generic('OPTIONS', path, data, content_type, **extra) -class APIClient(APIRequestFactory, DjangoClient): +class ForceAuthClientHandler(ClientHandler): + """ + A patched version of ClientHandler that can enforce authentication + on the outgoing requests. + """ + def __init__(self, *args, **kwargs): + self._force_auth_user = None + self._force_auth_token = None + super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + + def force_authenticate(self, user=None, token=None): + self._force_auth_user = user + self._force_auth_token = token + + def get_response(self, request): + # This is the simplest place we can hook into to patch the + # request object. + request._force_auth_user = self._force_auth_user + request._force_auth_token = self._force_auth_token + return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, enforce_csrf_checks=False, **defaults): + # Note that our super call skips Client.__init__ + # since we don't need to instantiate a regular ClientHandler + super(DjangoClient, self).__init__(**defaults) + self.handler = ForceAuthClientHandler(enforce_csrf_checks) + self.exc_info = None self._credentials = {} - super(APIClient, self).__init__(*args, **kwargs) def credentials(self, **kwargs): self._credentials = kwargs + def authenticate(self, user=None, token=None): + self.handler.force_authenticate(user, token) + def get(self, path, data={}, follow=False, **extra): extra.update(self._credentials) response = super(APIClient, self).get(path, data=data, **extra) |
