From f585480ee10f4b5e61db4ac343b1d2af25d2de97 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 28 Jun 2013 17:50:30 +0100 Subject: Added APIClient --- rest_framework/test.py | 81 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 13 deletions(-) (limited to 'rest_framework/test.py') diff --git a/rest_framework/test.py b/rest_framework/test.py index 92281caf..9fce2c08 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,39 +1,54 @@ -from rest_framework.compat import six, RequestFactory +# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# to make it harder for the user to import the wrong thing without realizing. +from django.conf import settings +from django.test.client import Client as DjangoClient +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer -class APIRequestFactory(RequestFactory): +class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, 'form': MultiPartRenderer } default_format = 'form' - def __init__(self, format=None, **defaults): - self.format = format or self.default_format - super(APIRequestFactory, self).__init__(**defaults) + def _encode_data(self, data, format=None, content_type=None): + """ + Encode the data returning a two tuple of (bytes, content_type) + """ - def _encode_data(self, data, format, content_type): if not data: return ('', None) - format = format or self.format + assert format is None or content_type is None, ( + 'You may not set both `format` and `content_type`.' + ) - if content_type is None and data is not None: + if content_type: + # Content type specified explicitly, treat data as a raw bytestring + ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + + else: + # Use format and render the data into a bytestring + format = format or self.default_format renderer = self.renderer_classes[format]() - data = renderer.render(data) - # Determine the content-type header + ret = renderer.render(data) + + # Determine the content-type header from the renderer if ';' in renderer.media_type: content_type = renderer.media_type else: content_type = "{0}; charset={1}".format( renderer.media_type, renderer.charset ) + # Coerce text to bytes if required. - if isinstance(data, six.text_type): - data = bytes(data.encode(renderer.charset)) + if isinstance(ret, six.text_type): + ret = bytes(ret.encode(renderer.charset)) - return data, content_type + return ret, content_type def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) @@ -46,3 +61,43 @@ class APIRequestFactory(RequestFactory): def patch(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('PATCH', path, data, content_type, **extra) + + def delete(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('DELETE', path, data, content_type, **extra) + + def options(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('OPTIONS', path, data, content_type, **extra) + + +class APIClient(APIRequestFactory, DjangoClient): + def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response -- cgit v1.2.3