diff options
| author | Tom Christie | 2013-06-28 17:50:30 +0100 | 
|---|---|---|
| committer | Tom Christie | 2013-06-28 17:50:30 +0100 | 
| commit | f585480ee10f4b5e61db4ac343b1d2af25d2de97 (patch) | |
| tree | 01f3e7e1b77ab42fe09334e9629840c518de2773 /rest_framework/test.py | |
| parent | 7224b20d58ceee22abc987980ab646ab8cb2d8dc (diff) | |
| download | django-rest-framework-f585480ee10f4b5e61db4ac343b1d2af25d2de97.tar.bz2 | |
Added APIClient
Diffstat (limited to 'rest_framework/test.py')
| -rw-r--r-- | rest_framework/test.py | 81 | 
1 files changed, 68 insertions, 13 deletions
| diff --git a/rest_framework/test.py b/rest_framework/test.py index 92281caf..9fce2c08 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,39 +1,54 @@ -from rest_framework.compat import six, RequestFactory +# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# to make it harder for the user to import the wrong thing without realizing. +from django.conf import settings +from django.test.client import Client as DjangoClient +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six  from rest_framework.renderers import JSONRenderer, MultiPartRenderer -class APIRequestFactory(RequestFactory): +class APIRequestFactory(DjangoRequestFactory):      renderer_classes = {          'json': JSONRenderer,          'form': MultiPartRenderer      }      default_format = 'form' -    def __init__(self, format=None, **defaults): -        self.format = format or self.default_format -        super(APIRequestFactory, self).__init__(**defaults) +    def _encode_data(self, data, format=None, content_type=None): +        """ +        Encode the data returning a two tuple of (bytes, content_type) +        """ -    def _encode_data(self, data, format, content_type):          if not data:              return ('', None) -        format = format or self.format +        assert format is None or content_type is None, ( +            'You may not set both `format` and `content_type`.' +        ) -        if content_type is None and data is not None: +        if content_type: +            # Content type specified explicitly, treat data as a raw bytestring +            ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + +        else: +            # Use format and render the data into a bytestring +            format = format or self.default_format              renderer = self.renderer_classes[format]() -            data = renderer.render(data) -            # Determine the content-type header +            ret = renderer.render(data) + +            # Determine the content-type header from the renderer              if ';' in renderer.media_type:                  content_type = renderer.media_type              else:                  content_type = "{0}; charset={1}".format(                      renderer.media_type, renderer.charset                  ) +              # Coerce text to bytes if required. -            if isinstance(data, six.text_type): -                data = bytes(data.encode(renderer.charset)) +            if isinstance(ret, six.text_type): +                ret = bytes(ret.encode(renderer.charset)) -        return data, content_type +        return ret, content_type      def post(self, path, data=None, format=None, content_type=None, **extra):          data, content_type = self._encode_data(data, format, content_type) @@ -46,3 +61,43 @@ class APIRequestFactory(RequestFactory):      def patch(self, path, data=None, format=None, content_type=None, **extra):          data, content_type = self._encode_data(data, format, content_type)          return self.generic('PATCH', path, data, content_type, **extra) + +    def delete(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('DELETE', path, data, content_type, **extra) + +    def options(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('OPTIONS', path, data, content_type, **extra) + + +class APIClient(APIRequestFactory, DjangoClient): +    def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): +        response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): +        response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): +        response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): +        response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): +        response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response | 
