diff options
| -rw-r--r-- | rest_framework/renderers.py | 2 | ||||
| -rw-r--r-- | rest_framework/request.py | 20 | ||||
| -rw-r--r-- | rest_framework/test.py | 39 | ||||
| -rw-r--r-- | rest_framework/tests/test_testing.py | 42 |
4 files changed, 95 insertions, 8 deletions
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index d7a7ef29..3a03ca33 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer): class MultiPartRenderer(BaseRenderer): media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' - format = 'form' + format = 'multipart' charset = 'utf-8' BOUNDARY = 'BoUnDaRyStRiNg' diff --git a/rest_framework/request.py b/rest_framework/request.py index 0d88ebc7..919716f4 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -64,6 +64,20 @@ def clone_request(request, method): return ret +class ForcedAuthentication(object): + """ + This authentication class is used if the test client or request factory + forcibly authenticated the request. + """ + + def __init__(self, force_user, force_token): + self.force_user = force_user + self.force_token = force_token + + def authenticate(self, request): + return (self.force_user, self.force_token) + + class Request(object): """ Wrapper allowing to enhance a standard `HttpRequest` instance. @@ -98,6 +112,12 @@ class Request(object): self.parser_context['request'] = self self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET + force_user = getattr(request, '_force_auth_user', None) + force_token = getattr(request, '_force_auth_token', None) + if (force_user is not None or force_token is not None): + forced_auth = ForcedAuthentication(force_user, force_token) + self.authenticators = (forced_auth,) + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/test.py b/rest_framework/test.py index 8115fa0d..08de2297 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer @@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, - 'form': MultiPartRenderer + 'multipart': MultiPartRenderer } - default_format = 'form' + default_format = 'multipart' def _encode_data(self, data, format=None, content_type=None): """ @@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory): return self.generic('OPTIONS', path, data, content_type, **extra) -class APIClient(APIRequestFactory, DjangoClient): +class ForceAuthClientHandler(ClientHandler): + """ + A patched version of ClientHandler that can enforce authentication + on the outgoing requests. + """ + def __init__(self, *args, **kwargs): + self._force_auth_user = None + self._force_auth_token = None + super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + + def force_authenticate(self, user=None, token=None): + self._force_auth_user = user + self._force_auth_token = token + + def get_response(self, request): + # This is the simplest place we can hook into to patch the + # request object. + request._force_auth_user = self._force_auth_user + request._force_auth_token = self._force_auth_token + return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, enforce_csrf_checks=False, **defaults): + # Note that our super call skips Client.__init__ + # since we don't need to instantiate a regular ClientHandler + super(DjangoClient, self).__init__(**defaults) + self.handler = ForceAuthClientHandler(enforce_csrf_checks) + self.exc_info = None self._credentials = {} - super(APIClient, self).__init__(*args, **kwargs) def credentials(self, **kwargs): self._credentials = kwargs + def authenticate(self, user=None, token=None): + self.handler.force_authenticate(user, token) + def get(self, path, data={}, follow=False, **extra): extra.update(self._credentials) response = super(APIClient, self).get(path, data=data, **extra) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 71dacd38..a8398b9a 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -1,6 +1,7 @@ # -- coding: utf-8 -- from __future__ import unicode_literals +from django.contrib.auth.models import User from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.decorators import api_view @@ -8,10 +9,11 @@ from rest_framework.response import Response from rest_framework.test import APIClient -@api_view(['GET']) +@api_view(['GET', 'POST']) def mirror(request): return Response({ - 'auth': request.META.get('HTTP_AUTHORIZATION', b'') + 'auth': request.META.get('HTTP_AUTHORIZATION', b''), + 'user': request.user.username }) @@ -27,6 +29,40 @@ class CheckTestClient(TestCase): self.client = APIClient() def test_credentials(self): + """ + Setting `.credentials()` adds the required headers to each request. + """ self.client.credentials(HTTP_AUTHORIZATION='example') + for _ in range(0, 3): + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') + + def test_authenticate(self): + """ + Setting `.authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.authenticate(user) response = self.client.get('/view/') - self.assertEqual(response.data['auth'], 'example') + self.assertEqual(response.data['user'], 'example') + + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + User.objects.create_user('example', 'example@example.com', 'password') + self.client.login(username='example', password='password') + response = self.client.post('/view/') + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + client = APIClient(enforce_csrf_checks=True) + User.objects.create_user('example', 'example@example.com', 'password') + client.login(username='example', password='password') + response = client.post('/view/') + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) |
