diff options
Diffstat (limited to 'rest_framework/test.py')
| -rw-r--r-- | rest_framework/test.py | 84 | 
1 files changed, 78 insertions, 6 deletions
| diff --git a/rest_framework/test.py b/rest_framework/test.py index a18f5a29..a83d082a 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -8,9 +8,11 @@ from django.conf import settings  from django.test.client import Client as DjangoClient  from django.test.client import ClientHandler  from django.test import testcases +from django.utils import six +from django.utils.http import urlencode  from rest_framework.settings import api_settings  from rest_framework.compat import RequestFactory as DjangoRequestFactory -from rest_framework.compat import force_bytes_or_smart_bytes, six +from rest_framework.compat import force_bytes_or_smart_bytes  def force_authenticate(request, user=None, token=None): @@ -34,8 +36,8 @@ class APIRequestFactory(DjangoRequestFactory):          Encode the data returning a two tuple of (bytes, content_type)          """ -        if not data: -            return ('', None) +        if data is None: +            return ('', content_type)          assert format is None or content_type is None, (              'You may not set both `format` and `content_type`.' @@ -48,9 +50,10 @@ class APIRequestFactory(DjangoRequestFactory):          else:              format = format or self.default_format -            assert format in self.renderer_classes, ("Invalid format '{0}'. " -                "Available formats are {1}.  Set TEST_REQUEST_RENDERER_CLASSES " -                "to enable extra request formats.".format( +            assert format in self.renderer_classes, ( +                "Invalid format '{0}'. Available formats are {1}. " +                "Set TEST_REQUEST_RENDERER_CLASSES to enable " +                "extra request formats.".format(                      format,                      ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])                  ) @@ -71,6 +74,17 @@ class APIRequestFactory(DjangoRequestFactory):          return ret, content_type +    def get(self, path, data=None, **extra): +        r = { +            'QUERY_STRING': urlencode(data or {}, doseq=True), +        } +        # Fix to support old behavior where you have the arguments in the url +        # See #1461 +        if not data and '?' in path: +            r['QUERY_STRING'] = path.split('?')[1] +        r.update(extra) +        return self.generic('GET', path, **r) +      def post(self, path, data=None, format=None, content_type=None, **extra):          data, content_type = self._encode_data(data, format, content_type)          return self.generic('POST', path, data, content_type, **extra) @@ -134,12 +148,70 @@ class APIClient(APIRequestFactory, DjangoClient):          """          self.handler._force_user = user          self.handler._force_token = token +        if user is None: +            self.logout()  # Also clear any possible session info if required      def request(self, **kwargs):          # Ensure that any credentials set get added to every request.          kwargs.update(self._credentials)          return super(APIClient, self).request(**kwargs) +    def get(self, path, data=None, follow=False, **extra): +        response = super(APIClient, self).get(path, data=data, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def post(self, path, data=None, format=None, content_type=None, +             follow=False, **extra): +        response = super(APIClient, self).post( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def put(self, path, data=None, format=None, content_type=None, +            follow=False, **extra): +        response = super(APIClient, self).put( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def patch(self, path, data=None, format=None, content_type=None, +              follow=False, **extra): +        response = super(APIClient, self).patch( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def delete(self, path, data=None, format=None, content_type=None, +               follow=False, **extra): +        response = super(APIClient, self).delete( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def options(self, path, data=None, format=None, content_type=None, +                follow=False, **extra): +        response = super(APIClient, self).options( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def logout(self): +        self._credentials = {} + +        # Also clear any `force_authenticate` +        self.handler._force_user = None +        self.handler._force_token = None + +        if self.session: +            super(APIClient, self).logout() +  class APITransactionTestCase(testcases.TransactionTestCase):      client_class = APIClient | 
