diff options
| -rw-r--r-- | rest_framework/test.py | 46 | ||||
| -rw-r--r-- | tests/test_testing.py | 47 | 
2 files changed, 93 insertions, 0 deletions
| diff --git a/rest_framework/test.py b/rest_framework/test.py index 9b40353a..74d2c868 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient):          kwargs.update(self._credentials)          return super(APIClient, self).request(**kwargs) +    def get(self, path, data=None, follow=False, **extra): +        response = super(APIClient, self).get(path, data=data, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def post(self, path, data=None, format=None, content_type=None, +             follow=False, **extra): +        response = super(APIClient, self).post( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def put(self, path, data=None, format=None, content_type=None, +            follow=False, **extra): +        response = super(APIClient, self).put( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def patch(self, path, data=None, format=None, content_type=None, +              follow=False, **extra): +        response = super(APIClient, self).patch( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def delete(self, path, data=None, format=None, content_type=None, +               follow=False, **extra): +        response = super(APIClient, self).delete( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response + +    def options(self, path, data=None, format=None, content_type=None, +                follow=False, **extra): +        response = super(APIClient, self).options( +            path, data=data, format=format, content_type=content_type, **extra) +        if follow: +            response = self._handle_redirects(response, **extra) +        return response +      def logout(self):          self._credentials = {}          return super(APIClient, self).logout() diff --git a/tests/test_testing.py b/tests/test_testing.py index 9c472026..9fd5966e 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -5,6 +5,7 @@ from django.conf.urls import patterns, url  from io import BytesIO  from django.contrib.auth.models import User +from django.shortcuts import redirect  from django.test import TestCase  from rest_framework.decorators import api_view  from rest_framework.response import Response @@ -28,10 +29,16 @@ def session_view(request):      }) +@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) +def redirect_view(request): +    return redirect('/view/') + +  urlpatterns = patterns(      '',      url(r'^view/$', view),      url(r'^session-view/$', session_view), +    url(r'^redirect-view/$', redirect_view),  ) @@ -111,6 +118,46 @@ class TestAPITestClient(TestCase):          response = self.client.get('/view/')          self.assertEqual(response.data['auth'], b'') +    def test_follow_redirect(self): +        """ +        Follow redirect by setting follow argument. +        """ +        response = self.client.get('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.get('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.post('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.post('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.put('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.put('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.patch('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.patch('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.delete('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.delete('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) + +        response = self.client.options('/redirect-view/') +        self.assertEqual(response.status_code, 302) +        response = self.client.options('/redirect-view/', follow=True) +        self.assertIsNotNone(response.redirect_chain) +        self.assertEqual(response.status_code, 200) +  class TestAPIRequestFactory(TestCase):      def test_csrf_exempt_by_default(self): | 
