aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/test.py')
-rw-r--r--rest_framework/test.py84
1 files changed, 78 insertions, 6 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py
index a18f5a29..a83d082a 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -8,9 +8,11 @@ from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test import testcases
+from django.utils import six
+from django.utils.http import urlencode
from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory
-from rest_framework.compat import force_bytes_or_smart_bytes, six
+from rest_framework.compat import force_bytes_or_smart_bytes
def force_authenticate(request, user=None, token=None):
@@ -34,8 +36,8 @@ class APIRequestFactory(DjangoRequestFactory):
Encode the data returning a two tuple of (bytes, content_type)
"""
- if not data:
- return ('', None)
+ if data is None:
+ return ('', content_type)
assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
@@ -48,9 +50,10 @@ class APIRequestFactory(DjangoRequestFactory):
else:
format = format or self.default_format
- assert format in self.renderer_classes, ("Invalid format '{0}'. "
- "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES "
- "to enable extra request formats.".format(
+ assert format in self.renderer_classes, (
+ "Invalid format '{0}'. Available formats are {1}. "
+ "Set TEST_REQUEST_RENDERER_CLASSES to enable "
+ "extra request formats.".format(
format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
)
@@ -71,6 +74,17 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type
+ def get(self, path, data=None, **extra):
+ r = {
+ 'QUERY_STRING': urlencode(data or {}, doseq=True),
+ }
+ # Fix to support old behavior where you have the arguments in the url
+ # See #1461
+ if not data and '?' in path:
+ r['QUERY_STRING'] = path.split('?')[1]
+ r.update(extra)
+ return self.generic('GET', path, **r)
+
def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra)
@@ -134,12 +148,70 @@ class APIClient(APIRequestFactory, DjangoClient):
"""
self.handler._force_user = user
self.handler._force_token = token
+ if user is None:
+ self.logout() # Also clear any possible session info if required
def request(self, **kwargs):
# Ensure that any credentials set get added to every request.
kwargs.update(self._credentials)
return super(APIClient, self).request(**kwargs)
+ def get(self, path, data=None, follow=False, **extra):
+ response = super(APIClient, self).get(path, data=data, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def post(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ response = super(APIClient, self).post(
+ path, data=data, format=format, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def put(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ response = super(APIClient, self).put(
+ path, data=data, format=format, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def patch(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ response = super(APIClient, self).patch(
+ path, data=data, format=format, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def delete(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ response = super(APIClient, self).delete(
+ path, data=data, format=format, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def options(self, path, data=None, format=None, content_type=None,
+ follow=False, **extra):
+ response = super(APIClient, self).options(
+ path, data=data, format=format, content_type=content_type, **extra)
+ if follow:
+ response = self._handle_redirects(response, **extra)
+ return response
+
+ def logout(self):
+ self._credentials = {}
+
+ # Also clear any `force_authenticate`
+ self.handler._force_user = None
+ self.handler._force_token = None
+
+ if self.session:
+ super(APIClient, self).logout()
+
class APITransactionTestCase(testcases.TransactionTestCase):
client_class = APIClient