aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTom Christie2013-06-29 08:05:08 +0100
committerTom Christie2013-06-29 08:05:08 +0100
commit90bc07f3f160485001ea329e5f69f7e521d14ec9 (patch)
treeb872a8c913e24e1138900855d602f4aeee994aa1
parentf585480ee10f4b5e61db4ac343b1d2af25d2de97 (diff)
downloaddjango-rest-framework-90bc07f3f160485001ea329e5f69f7e521d14ec9.tar.bz2
Addeded 'APITestClient.credentials()'
-rw-r--r--rest_framework/test.py29
-rw-r--r--rest_framework/tests/test_testing.py32
2 files changed, 61 insertions, 0 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py
index 9fce2c08..8115fa0d 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -1,5 +1,8 @@
+# -- coding: utf-8 --
+
# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order
# to make it harder for the user to import the wrong thing without realizing.
+from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
from rest_framework.compat import RequestFactory as DjangoRequestFactory
@@ -72,31 +75,57 @@ class APIRequestFactory(DjangoRequestFactory):
class APIClient(APIRequestFactory, DjangoClient):
+ def __init__(self, *args, **kwargs):
+ self._credentials = {}
+ super(APIClient, self).__init__(*args, **kwargs)
+
+ def credentials(self, **kwargs):
+ self._credentials = kwargs
+
+ def get(self, path, data={}, follow=False, **extra):
+ extra.update(self._credentials)
+ response = super(APIClient, self).get(path, data=data, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def head(self, path, data={}, follow=False, **extra):
+ extra.update(self._credentials)
+ response = super(APIClient, self).head(path, data=data, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
def post(self, path, data=None, format=None, content_type=None, follow=False, **extra):
+ extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def put(self, path, data=None, format=None, content_type=None, follow=False, **extra):
+ extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra):
+ extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra):
+ extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def options(self, path, data=None, format=None, content_type=None, follow=False, **extra):
+ extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py
new file mode 100644
index 00000000..71dacd38
--- /dev/null
+++ b/rest_framework/tests/test_testing.py
@@ -0,0 +1,32 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.compat import patterns, url
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.test import APIClient
+
+
+@api_view(['GET'])
+def mirror(request):
+ return Response({
+ 'auth': request.META.get('HTTP_AUTHORIZATION', b'')
+ })
+
+
+urlpatterns = patterns('',
+ url(r'^view/$', mirror),
+)
+
+
+class CheckTestClient(TestCase):
+ urls = 'rest_framework.tests.test_testing'
+
+ def setUp(self):
+ self.client = APIClient()
+
+ def test_credentials(self):
+ self.client.credentials(HTTP_AUTHORIZATION='example')
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')