diff options
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/test_authentication.py | 46 | ||||
| -rw-r--r-- | rest_framework/tests/test_decorators.py | 11 | ||||
| -rw-r--r-- | rest_framework/tests/test_filters.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_generics.py | 60 | ||||
| -rw-r--r-- | rest_framework/tests/test_hyperlinkedserializers.py | 11 | ||||
| -rw-r--r-- | rest_framework/tests/test_negotiation.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_pagination.py | 6 | ||||
| -rw-r--r-- | rest_framework/tests/test_permissions.py | 35 | ||||
| -rw-r--r-- | rest_framework/tests/test_relations_hyperlink.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_renderers.py | 8 | ||||
| -rw-r--r-- | rest_framework/tests/test_request.py | 19 | ||||
| -rw-r--r-- | rest_framework/tests/test_reverse.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_routers.py | 5 | ||||
| -rw-r--r-- | rest_framework/tests/test_testing.py | 115 | ||||
| -rw-r--r-- | rest_framework/tests/test_throttling.py | 6 | ||||
| -rw-r--r-- | rest_framework/tests/test_urlpatterns.py | 4 | ||||
| -rw-r--r-- | rest_framework/tests/test_validation.py | 8 | ||||
| -rw-r--r-- | rest_framework/tests/test_views.py | 6 | ||||
| -rw-r--r-- | rest_framework/tests/utils.py | 40 |
19 files changed, 220 insertions, 176 deletions
diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index 6a50be06..a44813b6 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse -from django.test import Client, TestCase +from django.test import TestCase from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions @@ -21,14 +21,13 @@ from rest_framework.authtoken.models import Token from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView -import json import base64 import time import datetime -factory = RequestFactory() +factory = APIRequestFactory() class MockView(APIView): @@ -68,7 +67,7 @@ class BasicAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -87,7 +86,7 @@ class BasicAuthTests(TestCase): credentials = ('%s:%s' % (self.username, self.password)) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) auth = 'Basic %s' % base64_credentials - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): @@ -97,7 +96,7 @@ class BasicAuthTests(TestCase): def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') @@ -107,8 +106,8 @@ class SessionAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) - self.non_csrf_client = Client(enforce_csrf_checks=False) + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.non_csrf_client = APIClient(enforce_csrf_checks=False) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -154,7 +153,7 @@ class TokenAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -172,7 +171,7 @@ class TokenAuthTests(TestCase): def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): @@ -182,7 +181,7 @@ class TokenAuthTests(TestCase): def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): @@ -193,33 +192,33 @@ class TokenAuthTests(TestCase): def test_token_login_json(self): """Ensure token login view using JSON POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': self.password}), 'application/json') + {'username': self.username, 'password': self.password}, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) def test_token_login_json_bad_creds(self): """Ensure token login view using JSON POST fails if bad credentials are used.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') + {'username': self.username, 'password': "badpass"}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_json_missing_fields(self): """Ensure token login view using JSON POST fails if missing fields.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username}), 'application/json') + {'username': self.username}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_form(self): """Ensure token login view using form POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) class IncorrectCredentialsTests(TestCase): @@ -256,7 +255,7 @@ class OAuthTests(TestCase): self.consts = consts - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -470,12 +469,13 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) + class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3..195f0ba3 100644 --- a/rest_framework/tests/test_decorators.py +++ b/rest_framework/tests/test_decorators.py @@ -1,12 +1,13 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.renderers import JSONRenderer -from rest_framework.parsers import JSONParser -from rest_framework.authentication import BasicAuthentication +from rest_framework.test import APIRequestFactory from rest_framework.throttling import UserRateThrottle -from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView from rest_framework.decorators import ( api_view, @@ -17,13 +18,11 @@ from rest_framework.decorators import ( permission_classes, ) -from rest_framework.tests.utils import RequestFactory - class DecoratorTestCase(TestCase): def setUp(self): - self.factory = RequestFactory() + self.factory = APIRequestFactory() def _finalize_response(self, request, response, *args, **kwargs): response.request = request diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index aaed6247..c9d9e7ff 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 37734195..1550880b 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -3,12 +3,11 @@ from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, renderers, serializers, status -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six -import json -factory = RequestFactory() +factory = APIRequestFactory() class RootView(generics.ListCreateAPIView): @@ -71,9 +70,8 @@ class TestRootView(TestCase): """ POST requests to ListCreateAPIView should create a new object. """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -85,9 +83,8 @@ class TestRootView(TestCase): """ PUT requests to ListCreateAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.put('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -148,9 +145,8 @@ class TestRootView(TestCase): """ POST requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -189,9 +185,8 @@ class TestInstanceView(TestCase): """ POST requests to RetrieveUpdateDestroyAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -201,9 +196,8 @@ class TestInstanceView(TestCase): """ PUT requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -215,9 +209,8 @@ class TestInstanceView(TestCase): """ PATCH requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.patch('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.patch('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() @@ -293,9 +286,8 @@ class TestInstanceView(TestCase): """ PUT requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -309,9 +301,8 @@ class TestInstanceView(TestCase): if it does not currently exist. """ self.objects.get(id=1).delete() - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -324,10 +315,9 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if it doesn't exist. """ - content = {'text': 'foobar'} + data = {'text': 'foobar'} # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', json.dumps(content), - content_type='application/json') + request = factory.put('/5', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=5).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -339,9 +329,8 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. """ - content = {'text': 'foobar'} - request = factory.put('/test_slug', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/test_slug', data, format='json') with self.assertNumQueries(2): response = self.slug_based_view(request, slug='test_slug').render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase): https://github.com/tomchristie/django-rest-framework/issues/285 """ - content = {'email': 'foobar@example.com', 'content': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'email': 'foobar@example.com', 'content': 'foobar'} + request = factory.post('/', data, format='json') response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 129600cb..61e613d7 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals import json from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import generics, status, serializers from rest_framework.compat import patterns, url -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel +from rest_framework.test import APIRequestFactory +from rest_framework.tests.models import ( + Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, + Album, Photo, OptionalRelationModel +) -factory = RequestFactory() +factory = APIRequestFactory() class BlogPostCommentSerializer(serializers.ModelSerializer): @@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer): class PhotoSerializer(serializers.Serializer): description = serializers.CharField() - album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') + album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') def restore_object(self, attrs, instance=None): return Photo(**attrs) diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py index 7f84827f..04b89eb6 100644 --- a/rest_framework/tests/test_negotiation.py +++ b/rest_framework/tests/test_negotiation.py @@ -1,12 +1,12 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.request import Request from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() class MockJSONRenderer(BaseRenderer): diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index e538a78e..85d4640e 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.paginator import Paginator from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): @@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase): self.page = paginator.page(1) def test_custom_pagination_serializer(self): - request = RequestFactory().get('/foobar') + request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=self.page, context={'request': request} diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py index 6caaf65b..e2cca380 100644 --- a/rest_framework/tests/test_permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission from django.db import models from django.test import TestCase from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory import base64 -import json -factory = RequestFactory() +factory = APIRequestFactory() class BasicModel(models.Model): @@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase): BasicModel(text='foo').save() def test_has_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_has_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_does_not_have_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_does_not_have_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase): def test_has_put_as_create_permissions(self): # User only has update permissions - should be able to update an entity. - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) # But if PUTing to a new entity, permission should be denied. - request = factory.put('/2', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/2', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_options_permitted(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn('actions', response.data) self.assertEqual(list(response.data['actions'].keys()), ['POST']) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(list(response.data['actions'].keys()), ['PUT']) def test_options_disallowed(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) def test_options_updateonly(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.updateonly_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index 2ca7f4f2..3c4d39af 100644 --- a/rest_framework/tests/test_relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import ( BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) -factory = RequestFactory() +factory = APIRequestFactory() request = factory.get('/') # Just to ensure we have a request in the serializer context diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 95b59741..df6f4aa6 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,19 +4,17 @@ from __future__ import unicode_literals from decimal import Decimal from django.core.cache import cache from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings -from rest_framework.compat import StringIO -from rest_framework.compat import six +from rest_framework.test import APIRequestFactory import datetime import pickle import re @@ -121,7 +119,7 @@ class POSTDeniedView(APIView): class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): view = POSTDeniedView.as_view() - request = RequestFactory().get('/') + request = APIRequestFactory().get('/') response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index a5c5e84c..969d8024 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -5,8 +5,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware -from django.test import TestCase, Client -from django.test.client import RequestFactory +from django.test import TestCase from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -19,12 +18,13 @@ from rest_framework.parsers import ( from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView from rest_framework.compat import six import json -factory = RequestFactory() +factory = APIRequestFactory() class PlainTextParser(BaseParser): @@ -116,16 +116,7 @@ class TestContentParsing(TestCase): Ensure request.DATA returns content for PUT request with form content. """ data = {'qwerty': 'uiop'} - - from django import VERSION - - if VERSION >= (1, 5): - from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart - request = Request(factory.put('/', encode_multipart(BOUNDARY, data), - content_type=MULTIPART_CONTENT)) - else: - request = Request(factory.put('/', data)) - + request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.DATA.items()), list(data.items())) @@ -257,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase): urls = 'rest_framework.tests.test_request' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py index 93ef5637..690a30b1 100644 --- a/rest_framework/tests/test_reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.compat import patterns, url from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() def null_view(request): diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index d375f4a8..5fcccb74 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase -from django.test.client import RequestFactory from django.core.exceptions import ImproperlyConfigured from rest_framework import serializers, viewsets, permissions from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action from rest_framework.response import Response from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() urlpatterns = patterns('',) @@ -193,6 +193,7 @@ class TestActionKeywordArgs(TestCase): {'permission_classes': [permissions.AllowAny]} ) + class TestActionAppliedToExistingRoute(TestCase): """ Ensure `@action` decorator raises an except when applied diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py new file mode 100644 index 00000000..49d45fc2 --- /dev/null +++ b/rest_framework/tests/test_testing.py @@ -0,0 +1,115 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.contrib.auth.models import User +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate + + +@api_view(['GET', 'POST']) +def view(request): + return Response({ + 'auth': request.META.get('HTTP_AUTHORIZATION', b''), + 'user': request.user.username + }) + + +urlpatterns = patterns('', + url(r'^view/$', view), +) + + +class TestAPITestClient(TestCase): + urls = 'rest_framework.tests.test_testing' + + def setUp(self): + self.client = APIClient() + + def test_credentials(self): + """ + Setting `.credentials()` adds the required headers to each request. + """ + self.client.credentials(HTTP_AUTHORIZATION='example') + for _ in range(0, 3): + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') + + def test_force_authenticate(self): + """ + Setting `.force_authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.force_authenticate(user) + response = self.client.get('/view/') + self.assertEqual(response.data['user'], 'example') + + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + User.objects.create_user('example', 'example@example.com', 'password') + self.client.login(username='example', password='password') + response = self.client.post('/view/') + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + client = APIClient(enforce_csrf_checks=True) + User.objects.create_user('example', 'example@example.com', 'password') + client.login(username='example', password='password') + response = client.post('/view/') + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory() + request = factory.post('/view/') + request.user = user + response = view(request) + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory(enforce_csrf_checks=True) + request = factory.post('/view/') + request.user = user + response = view(request) + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + def test_invalid_format(self): + """ + Attempting to use a format that is not configured will raise an + assertion error. + """ + factory = APIRequestFactory() + self.assertRaises(AssertionError, factory.post, + path='/view/', data={'example': 1}, format='xml' + ) + + def test_force_authenticate(self): + """ + Setting `force_authenticate()` forcibly authenticates the request. + """ + user = User.objects.create_user('example', 'example@example.com') + factory = APIRequestFactory() + request = factory.get('/view') + force_authenticate(request, user=user) + response = view(request) + self.assertEqual(response.data['user'], 'example') diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index d35d3709..19bc691a 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -41,7 +41,7 @@ class ThrottlingTests(TestCase): Reset the cache so that no throttles will be active """ cache.clear() - self.factory = RequestFactory() + self.factory = APIRequestFactory() def test_requests_are_throttled(self): """ @@ -173,7 +173,7 @@ class ScopedRateThrottleTests(TestCase): return Response('y') self.throttle_class = XYScopedRateThrottle - self.factory = RequestFactory() + self.factory = APIRequestFactory() self.x_view = XView.as_view() self.y_view = YView.as_view() self.unscoped_view = UnscopedView.as_view() diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a96..8132ec4c 100644 --- a/rest_framework/tests/test_urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from collections import namedtuple from django.core import urlresolvers from django.test import TestCase -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.compat import patterns, url, include from rest_framework.urlpatterns import format_suffix_patterns @@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase): Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. """ def _resolve_urlpatterns(self, urlpatterns, test_paths): - factory = RequestFactory() + factory = APIRequestFactory() try: urlpatterns = format_suffix_patterns(urlpatterns) except Exception: diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index a6ec0e99..ebfdff9c 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status -from rest_framework.tests.utils import RequestFactory -import json +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() # Regression for #666 @@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase): validation on read only fields. """ obj = ValidationModel.objects.create(blank_validated_field='') - request = factory.put('/', json.dumps({}), - content_type='application/json') + request = factory.put('/', {}, format='json') view = UpdateValidationModel().as_view() response = view(request, pk=obj.pk).render() self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 2767d24c..c0bec5ae 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -1,17 +1,15 @@ from __future__ import unicode_literals import copy - from django.test import TestCase -from django.test.client import RequestFactory - from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -factory = RequestFactory() +factory = APIRequestFactory() class BasicView(APIView): diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py deleted file mode 100644 index 8c87917d..00000000 --- a/rest_framework/tests/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import unicode_literals -from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory -from django.test.client import MULTIPART_CONTENT -from rest_framework.compat import urlparse - - -class RequestFactory(_RequestFactory): - - def __init__(self, **defaults): - super(RequestFactory, self).__init__(**defaults) - - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - **extra): - "Construct a PATCH request." - - patch_data = self._encode_data(data, content_type) - - parsed = urlparse.urlparse(path) - r = { - 'CONTENT_LENGTH': len(patch_data), - 'CONTENT_TYPE': content_type, - 'PATH_INFO': self._get_path(parsed), - 'QUERY_STRING': parsed[4], - 'REQUEST_METHOD': 'PATCH', - 'wsgi.input': FakePayload(patch_data), - } - r.update(extra) - return self.request(**r) - - -class Client(_Client, RequestFactory): - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - follow=False, **extra): - """ - Send a resource to the server using PATCH. - """ - response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response |
