diff options
| author | Tom Christie | 2014-10-31 16:05:17 +0000 |
|---|---|---|
| committer | Tom Christie | 2014-10-31 16:05:17 +0000 |
| commit | 5e1ed0aa9578be360261d5ba8b89aec959e948c8 (patch) | |
| tree | d99facdd71a151db7c8ecb9ba679b966bd823440 | |
| parent | 0b864acd98e92425ebc148c9867b9ef0ea18a824 (diff) | |
| parent | 2dfe75c23a041493bc83514d8e9e9268b79072d9 (diff) | |
| download | django-rest-framework-5e1ed0aa9578be360261d5ba8b89aec959e948c8.tar.bz2 | |
Merge pull request #1922 from JonesChi/fix_follow
Fix follow does not work on get of APIRequestFactory
| -rw-r--r-- | rest_framework/test.py | 46 | ||||
| -rw-r--r-- | tests/test_testing.py | 47 |
2 files changed, 93 insertions, 0 deletions
diff --git a/rest_framework/test.py b/rest_framework/test.py index 9b40353a..74d2c868 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -156,6 +156,52 @@ class APIClient(APIRequestFactory, DjangoClient): kwargs.update(self._credentials) return super(APIClient, self).request(**kwargs) + def get(self, path, data=None, follow=False, **extra): + response = super(APIClient, self).get(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def post(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).post( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def put(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).put( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def patch(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).patch( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def delete(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).delete( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def options(self, path, data=None, format=None, content_type=None, + follow=False, **extra): + response = super(APIClient, self).options( + path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + def logout(self): self._credentials = {} return super(APIClient, self).logout() diff --git a/tests/test_testing.py b/tests/test_testing.py index 9c472026..9fd5966e 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -5,6 +5,7 @@ from django.conf.urls import patterns, url from io import BytesIO from django.contrib.auth.models import User +from django.shortcuts import redirect from django.test import TestCase from rest_framework.decorators import api_view from rest_framework.response import Response @@ -28,10 +29,16 @@ def session_view(request): }) +@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) +def redirect_view(request): + return redirect('/view/') + + urlpatterns = patterns( '', url(r'^view/$', view), url(r'^session-view/$', session_view), + url(r'^redirect-view/$', redirect_view), ) @@ -111,6 +118,46 @@ class TestAPITestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['auth'], b'') + def test_follow_redirect(self): + """ + Follow redirect by setting follow argument. + """ + response = self.client.get('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.get('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.post('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.post('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.put('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.put('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.patch('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.patch('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.delete('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.delete('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + + response = self.client.options('/redirect-view/') + self.assertEqual(response.status_code, 302) + response = self.client.options('/redirect-view/', follow=True) + self.assertIsNotNone(response.redirect_chain) + self.assertEqual(response.status_code, 200) + class TestAPIRequestFactory(TestCase): def test_csrf_exempt_by_default(self): |
