diff options
| author | Tom Christie | 2013-06-29 08:05:08 +0100 |
|---|---|---|
| committer | Tom Christie | 2013-06-29 08:05:08 +0100 |
| commit | 90bc07f3f160485001ea329e5f69f7e521d14ec9 (patch) | |
| tree | b872a8c913e24e1138900855d602f4aeee994aa1 | |
| parent | f585480ee10f4b5e61db4ac343b1d2af25d2de97 (diff) | |
| download | django-rest-framework-90bc07f3f160485001ea329e5f69f7e521d14ec9.tar.bz2 | |
Addeded 'APITestClient.credentials()'
| -rw-r--r-- | rest_framework/test.py | 29 | ||||
| -rw-r--r-- | rest_framework/tests/test_testing.py | 32 |
2 files changed, 61 insertions, 0 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py index 9fce2c08..8115fa0d 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,5 +1,8 @@ +# -- coding: utf-8 -- + # Note that we use `DjangoRequestFactory` and `DjangoClient` names in order # to make it harder for the user to import the wrong thing without realizing. +from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from rest_framework.compat import RequestFactory as DjangoRequestFactory @@ -72,31 +75,57 @@ class APIRequestFactory(DjangoRequestFactory): class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, *args, **kwargs): + self._credentials = {} + super(APIClient, self).__init__(*args, **kwargs) + + def credentials(self, **kwargs): + self._credentials = kwargs + + def get(self, path, data={}, follow=False, **extra): + extra.update(self._credentials) + response = super(APIClient, self).get(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def head(self, path, data={}, follow=False, **extra): + extra.update(self._credentials) + response = super(APIClient, self).head(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py new file mode 100644 index 00000000..71dacd38 --- /dev/null +++ b/rest_framework/tests/test_testing.py @@ -0,0 +1,32 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient + + +@api_view(['GET']) +def mirror(request): + return Response({ + 'auth': request.META.get('HTTP_AUTHORIZATION', b'') + }) + + +urlpatterns = patterns('', + url(r'^view/$', mirror), +) + + +class CheckTestClient(TestCase): + urls = 'rest_framework.tests.test_testing' + + def setUp(self): + self.client = APIClient() + + def test_credentials(self): + self.client.credentials(HTTP_AUTHORIZATION='example') + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') |
