aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2013-06-29 21:02:58 +0100
committerTom Christie2013-06-29 21:02:58 +0100
commit664f8c63655770cd90bdbd510b315bcd045b380a (patch)
tree2145a39de36701bc67cad67f2b303594a76d23e9
parent35022ca9213939a2f40c82facffa908a818efe0b (diff)
downloaddjango-rest-framework-664f8c63655770cd90bdbd510b315bcd045b380a.tar.bz2
Added APIClient.authenticate()
-rw-r--r--rest_framework/renderers.py2
-rw-r--r--rest_framework/request.py20
-rw-r--r--rest_framework/test.py39
-rw-r--r--rest_framework/tests/test_testing.py42
4 files changed, 95 insertions, 8 deletions
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index d7a7ef29..3a03ca33 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
class MultiPartRenderer(BaseRenderer):
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
- format = 'form'
+ format = 'multipart'
charset = 'utf-8'
BOUNDARY = 'BoUnDaRyStRiNg'
diff --git a/rest_framework/request.py b/rest_framework/request.py
index 0d88ebc7..919716f4 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -64,6 +64,20 @@ def clone_request(request, method):
return ret
+class ForcedAuthentication(object):
+ """
+ This authentication class is used if the test client or request factory
+ forcibly authenticated the request.
+ """
+
+ def __init__(self, force_user, force_token):
+ self.force_user = force_user
+ self.force_token = force_token
+
+ def authenticate(self, request):
+ return (self.force_user, self.force_token)
+
+
class Request(object):
"""
Wrapper allowing to enhance a standard `HttpRequest` instance.
@@ -98,6 +112,12 @@ class Request(object):
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
+ force_user = getattr(request, '_force_auth_user', None)
+ force_token = getattr(request, '_force_auth_token', None)
+ if (force_user is not None or force_token is not None):
+ forced_auth = ForcedAuthentication(force_user, force_token)
+ self.authenticators = (forced_auth,)
+
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
diff --git a/rest_framework/test.py b/rest_framework/test.py
index 8115fa0d..08de2297 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -5,6 +5,7 @@
from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
+from django.test.client import ClientHandler
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
from rest_framework.renderers import JSONRenderer, MultiPartRenderer
@@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer
class APIRequestFactory(DjangoRequestFactory):
renderer_classes = {
'json': JSONRenderer,
- 'form': MultiPartRenderer
+ 'multipart': MultiPartRenderer
}
- default_format = 'form'
+ default_format = 'multipart'
def _encode_data(self, data, format=None, content_type=None):
"""
@@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):
return self.generic('OPTIONS', path, data, content_type, **extra)
-class APIClient(APIRequestFactory, DjangoClient):
+class ForceAuthClientHandler(ClientHandler):
+ """
+ A patched version of ClientHandler that can enforce authentication
+ on the outgoing requests.
+ """
+
def __init__(self, *args, **kwargs):
+ self._force_auth_user = None
+ self._force_auth_token = None
+ super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
+
+ def force_authenticate(self, user=None, token=None):
+ self._force_auth_user = user
+ self._force_auth_token = token
+
+ def get_response(self, request):
+ # This is the simplest place we can hook into to patch the
+ # request object.
+ request._force_auth_user = self._force_auth_user
+ request._force_auth_token = self._force_auth_token
+ return super(ForceAuthClientHandler, self).get_response(request)
+
+
+class APIClient(APIRequestFactory, DjangoClient):
+ def __init__(self, enforce_csrf_checks=False, **defaults):
+ # Note that our super call skips Client.__init__
+ # since we don't need to instantiate a regular ClientHandler
+ super(DjangoClient, self).__init__(**defaults)
+ self.handler = ForceAuthClientHandler(enforce_csrf_checks)
+ self.exc_info = None
self._credentials = {}
- super(APIClient, self).__init__(*args, **kwargs)
def credentials(self, **kwargs):
self._credentials = kwargs
+ def authenticate(self, user=None, token=None):
+ self.handler.force_authenticate(user, token)
+
def get(self, path, data={}, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).get(path, data=data, **extra)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
index 71dacd38..a8398b9a 100644
--- a/rest_framework/tests/test_testing.py
+++ b/rest_framework/tests/test_testing.py
@@ -1,6 +1,7 @@
# -- coding: utf-8 --
from __future__ import unicode_literals
+from django.contrib.auth.models import User
from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
@@ -8,10 +9,11 @@ from rest_framework.response import Response
from rest_framework.test import APIClient
-@api_view(['GET'])
+@api_view(['GET', 'POST'])
def mirror(request):
return Response({
- 'auth': request.META.get('HTTP_AUTHORIZATION', b'')
+ 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
+ 'user': request.user.username
})
@@ -27,6 +29,40 @@ class CheckTestClient(TestCase):
self.client = APIClient()
def test_credentials(self):
+ """
+ Setting `.credentials()` adds the required headers to each request.
+ """
self.client.credentials(HTTP_AUTHORIZATION='example')
+ for _ in range(0, 3):
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')
+
+ def test_authenticate(self):
+ """
+ Setting `.authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.authenticate(user)
response = self.client.get('/view/')
- self.assertEqual(response.data['auth'], 'example')
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ User.objects.create_user('example', 'example@example.com', 'password')
+ self.client.login(username='example', password='password')
+ response = self.client.post('/view/')
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ client = APIClient(enforce_csrf_checks=True)
+ User.objects.create_user('example', 'example@example.com', 'password')
+ client.login(username='example', password='password')
+ response = client.post('/view/')
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)