From 7224b20d58ceee22abc987980ab646ab8cb2d8dc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 28 Jun 2013 17:17:39 +0100 Subject: Added APIRequestFactory --- rest_framework/test.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 rest_framework/test.py (limited to 'rest_framework/test.py') diff --git a/rest_framework/test.py b/rest_framework/test.py new file mode 100644 index 00000000..92281caf --- /dev/null +++ b/rest_framework/test.py @@ -0,0 +1,48 @@ +from rest_framework.compat import six, RequestFactory +from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +class APIRequestFactory(RequestFactory): + renderer_classes = { + 'json': JSONRenderer, + 'form': MultiPartRenderer + } + default_format = 'form' + + def __init__(self, format=None, **defaults): + self.format = format or self.default_format + super(APIRequestFactory, self).__init__(**defaults) + + def _encode_data(self, data, format, content_type): + if not data: + return ('', None) + + format = format or self.format + + if content_type is None and data is not None: + renderer = self.renderer_classes[format]() + data = renderer.render(data) + # Determine the content-type header + if ';' in renderer.media_type: + content_type = renderer.media_type + else: + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) + # Coerce text to bytes if required. + if isinstance(data, six.text_type): + data = bytes(data.encode(renderer.charset)) + + return data, content_type + + def post(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('POST', path, data, content_type, **extra) + + def put(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PUT', path, data, content_type, **extra) + + def patch(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PATCH', path, data, content_type, **extra) -- cgit v1.2.3