diff options
| author | Tom Christie | 2013-07-04 05:50:04 -0700 | 
|---|---|---|
| committer | Tom Christie | 2013-07-04 05:50:04 -0700 | 
| commit | 99794773cf6b865b5b860b35db31dea92968c605 (patch) | |
| tree | 157b09d6b19ee5583d6d32123b3671c1c75adbc9 /rest_framework | |
| parent | a890116ab31e57af3bd1382c1f17259fa368f988 (diff) | |
| parent | 7398464b397d37dbcfda13eb6142039fed3e9a19 (diff) | |
| download | django-rest-framework-99794773cf6b865b5b860b35db31dea92968c605.tar.bz2 | |
Merge pull request #962 from tomchristie/test-client
APIClient and APIRequestFactory
Diffstat (limited to 'rest_framework')
25 files changed, 452 insertions, 188 deletions
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 10298027..cf001a24 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -26,6 +26,12 @@ def get_authorization_header(request):      return auth +class CSRFCheck(CsrfViewMiddleware): +    def _reject(self, request, reason): +        # Return the failure reason instead of an HttpResponse +        return reason + +  class BaseAuthentication(object):      """      All authentication classes should extend BaseAuthentication. @@ -103,27 +109,27 @@ class SessionAuthentication(BaseAuthentication):          """          # Get the underlying HttpRequest object -        http_request = request._request -        user = getattr(http_request, 'user', None) +        request = request._request +        user = getattr(request, 'user', None)          # Unauthenticated, CSRF validation not required          if not user or not user.is_active:              return None -        # Enforce CSRF validation for session based authentication. -        class CSRFCheck(CsrfViewMiddleware): -            def _reject(self, request, reason): -                # Return the failure reason instead of an HttpResponse -                return reason +        self.enforce_csrf(request) + +        # CSRF passed with authenticated user +        return (user, None) -        reason = CSRFCheck().process_view(http_request, None, (), {}) +    def enforce_csrf(self, request): +        """ +        Enforce CSRF validation for session based authentication. +        """ +        reason = CSRFCheck().process_view(request, None, (), {})          if reason:              # CSRF failed, bail with explicit error message              raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) -        # CSRF passed with authenticated user -        return (user, None) -  class TokenAuthentication(BaseAuthentication):      """ diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cb122846..6f7447ad 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -8,6 +8,7 @@ from __future__ import unicode_literals  import django  from django.core.exceptions import ImproperlyConfigured +from django.conf import settings  # Try to import six from Django, fallback to included `six`.  try: @@ -83,7 +84,6 @@ def get_concrete_model(model_cls):  # Django 1.5 add support for custom auth user model  if django.VERSION >= (1, 5): -    from django.conf import settings      AUTH_USER_MODEL = settings.AUTH_USER_MODEL  else:      AUTH_USER_MODEL = 'auth.User' @@ -436,6 +436,42 @@ except ImportError:          return force_text(url) +# RequestFactory only provide `generic` from 1.5 onwards + +from django.test.client import RequestFactory as DjangoRequestFactory +from django.test.client import FakePayload +try: +    # In 1.5 the test client uses force_bytes +    from django.utils.encoding import force_bytes_or_smart_bytes +except ImportError: +    # In 1.3 and 1.4 the test client just uses smart_str +    from django.utils.encoding import smart_str as force_bytes_or_smart_bytes + + +class RequestFactory(DjangoRequestFactory): +    def generic(self, method, path, +            data='', content_type='application/octet-stream', **extra): +        parsed = urlparse.urlparse(path) +        data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) +        r = { +            'PATH_INFO':      self._get_path(parsed), +            'QUERY_STRING':   force_text(parsed[4]), +            'REQUEST_METHOD': str(method), +        } +        if data: +            r.update({ +                'CONTENT_LENGTH': len(data), +                'CONTENT_TYPE':   str(content_type), +                'wsgi.input':     FakePayload(data), +            }) +        elif django.VERSION <= (1, 4): +            # For 1.3 we need an empty WSGI payload +            r.update({ +                'wsgi.input': FakePayload('') +            }) +        r.update(extra) +        return self.request(**r) +  # Markdown is optional  try:      import markdown diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 8b2428ad..3a03ca33 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -14,6 +14,7 @@ from django import forms  from django.core.exceptions import ImproperlyConfigured  from django.http.multipartparser import parse_header  from django.template import RequestContext, loader, Template +from django.test.client import encode_multipart  from django.utils.xmlutils import SimplerXMLGenerator  from rest_framework.compat import StringIO  from rest_framework.compat import six @@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer):              response.status_code = status.HTTP_200_OK          return ret + + +class MultiPartRenderer(BaseRenderer): +    media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' +    format = 'multipart' +    charset = 'utf-8' +    BOUNDARY = 'BoUnDaRyStRiNg' + +    def render(self, data, accepted_media_type=None, renderer_context=None): +        return encode_multipart(self.BOUNDARY, data) diff --git a/rest_framework/request.py b/rest_framework/request.py index 0d88ebc7..919716f4 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -64,6 +64,20 @@ def clone_request(request, method):      return ret +class ForcedAuthentication(object): +    """ +    This authentication class is used if the test client or request factory +    forcibly authenticated the request. +    """ + +    def __init__(self, force_user, force_token): +        self.force_user = force_user +        self.force_token = force_token + +    def authenticate(self, request): +        return (self.force_user, self.force_token) + +  class Request(object):      """      Wrapper allowing to enhance a standard `HttpRequest` instance. @@ -98,6 +112,12 @@ class Request(object):          self.parser_context['request'] = self          self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET +        force_user = getattr(request, '_force_auth_user', None) +        force_token = getattr(request, '_force_auth_token', None) +        if (force_user is not None or force_token is not None): +            forced_auth = ForcedAuthentication(force_user, force_token) +            self.authenticators = (forced_auth,) +      def _default_negotiator(self):          return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/settings.py b/rest_framework/settings.py index beb511ac..8fd177d5 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -73,6 +73,13 @@ DEFAULTS = {      'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',      'UNAUTHENTICATED_TOKEN': None, +    # Testing +    'TEST_REQUEST_RENDERER_CLASSES': ( +        'rest_framework.renderers.MultiPartRenderer', +        'rest_framework.renderers.JSONRenderer' +    ), +    'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', +      # Browser enhancements      'FORM_METHOD_OVERRIDE': '_method',      'FORM_CONTENT_OVERRIDE': '_content', @@ -115,6 +122,7 @@ IMPORT_STRINGS = (      'DEFAULT_PAGINATION_SERIALIZER_CLASS',      'DEFAULT_FILTER_BACKENDS',      'FILTER_BACKEND', +    'TEST_REQUEST_RENDERER_CLASSES',      'UNAUTHENTICATED_USER',      'UNAUTHENTICATED_TOKEN',  ) diff --git a/rest_framework/test.py b/rest_framework/test.py new file mode 100644 index 00000000..29d017ee --- /dev/null +++ b/rest_framework/test.py @@ -0,0 +1,139 @@ +# -- coding: utf-8 -- + +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order +# to make it harder for the user to import the wrong thing without realizing. +from __future__ import unicode_literals +from django.conf import settings +from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler +from rest_framework.settings import api_settings +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six + + +def force_authenticate(request, user=None, token=None): +    request._force_auth_user = user +    request._force_auth_token = token + + +class APIRequestFactory(DjangoRequestFactory): +    renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES +    default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + +    def __init__(self, enforce_csrf_checks=False, **defaults): +        self.enforce_csrf_checks = enforce_csrf_checks +        self.renderer_classes = {} +        for cls in self.renderer_classes_list: +            self.renderer_classes[cls.format] = cls +        super(APIRequestFactory, self).__init__(**defaults) + +    def _encode_data(self, data, format=None, content_type=None): +        """ +        Encode the data returning a two tuple of (bytes, content_type) +        """ + +        if not data: +            return ('', None) + +        assert format is None or content_type is None, ( +            'You may not set both `format` and `content_type`.' +        ) + +        if content_type: +            # Content type specified explicitly, treat data as a raw bytestring +            ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + +        else: +            format = format or self.default_format + +            assert format in self.renderer_classes, ("Invalid format '{0}'. " +                "Available formats are {1}.  Set TEST_REQUEST_RENDERER_CLASSES " +                "to enable extra request formats.".format( +                    format, +                    ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) +                ) +            ) + +            # Use format and render the data into a bytestring +            renderer = self.renderer_classes[format]() +            ret = renderer.render(data) + +            # Determine the content-type header from the renderer +            content_type = "{0}; charset={1}".format( +                renderer.media_type, renderer.charset +            ) + +            # Coerce text to bytes if required. +            if isinstance(ret, six.text_type): +                ret = bytes(ret.encode(renderer.charset)) + +        return ret, content_type + +    def post(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('POST', path, data, content_type, **extra) + +    def put(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('PUT', path, data, content_type, **extra) + +    def patch(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('PATCH', path, data, content_type, **extra) + +    def delete(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('DELETE', path, data, content_type, **extra) + +    def options(self, path, data=None, format=None, content_type=None, **extra): +        data, content_type = self._encode_data(data, format, content_type) +        return self.generic('OPTIONS', path, data, content_type, **extra) + +    def request(self, **kwargs): +        request = super(APIRequestFactory, self).request(**kwargs) +        request._dont_enforce_csrf_checks = not self.enforce_csrf_checks +        return request + + +class ForceAuthClientHandler(ClientHandler): +    """ +    A patched version of ClientHandler that can enforce authentication +    on the outgoing requests. +    """ + +    def __init__(self, *args, **kwargs): +        self._force_user = None +        self._force_token = None +        super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + +    def get_response(self, request): +        # This is the simplest place we can hook into to patch the +        # request object. +        force_authenticate(request, self._force_user, self._force_token) +        return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): +    def __init__(self, enforce_csrf_checks=False, **defaults): +        super(APIClient, self).__init__(**defaults) +        self.handler = ForceAuthClientHandler(enforce_csrf_checks) +        self._credentials = {} + +    def credentials(self, **kwargs): +        """ +        Sets headers that will be used on every outgoing request. +        """ +        self._credentials = kwargs + +    def force_authenticate(self, user=None, token=None): +        """ +        Forcibly authenticates outgoing requests with the given +        user and/or token. +        """ +        self.handler._force_user = user +        self.handler._force_token = token + +    def request(self, **kwargs): +        # Ensure that any credentials set get added to every request. +        kwargs.update(self._credentials) +        return super(APIClient, self).request(**kwargs) diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index 6a50be06..a44813b6 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -1,7 +1,7 @@  from __future__ import unicode_literals  from django.contrib.auth.models import User  from django.http import HttpResponse -from django.test import Client, TestCase +from django.test import TestCase  from django.utils import unittest  from rest_framework import HTTP_HEADER_ENCODING  from rest_framework import exceptions @@ -21,14 +21,13 @@ from rest_framework.authtoken.models import Token  from rest_framework.compat import patterns, url, include  from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope  from rest_framework.compat import oauth, oauth_provider -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory, APIClient  from rest_framework.views import APIView -import json  import base64  import time  import datetime -factory = RequestFactory() +factory = APIRequestFactory()  class MockView(APIView): @@ -68,7 +67,7 @@ class BasicAuthTests(TestCase):      urls = 'rest_framework.tests.test_authentication'      def setUp(self): -        self.csrf_client = Client(enforce_csrf_checks=True) +        self.csrf_client = APIClient(enforce_csrf_checks=True)          self.username = 'john'          self.email = 'lennon@thebeatles.com'          self.password = 'password' @@ -87,7 +86,7 @@ class BasicAuthTests(TestCase):          credentials = ('%s:%s' % (self.username, self.password))          base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)          auth = 'Basic %s' % base64_credentials -        response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, status.HTTP_200_OK)      def test_post_form_failing_basic_auth(self): @@ -97,7 +96,7 @@ class BasicAuthTests(TestCase):      def test_post_json_failing_basic_auth(self):          """Ensure POSTing json over basic auth without correct credentials fails""" -        response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') +        response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')          self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)          self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') @@ -107,8 +106,8 @@ class SessionAuthTests(TestCase):      urls = 'rest_framework.tests.test_authentication'      def setUp(self): -        self.csrf_client = Client(enforce_csrf_checks=True) -        self.non_csrf_client = Client(enforce_csrf_checks=False) +        self.csrf_client = APIClient(enforce_csrf_checks=True) +        self.non_csrf_client = APIClient(enforce_csrf_checks=False)          self.username = 'john'          self.email = 'lennon@thebeatles.com'          self.password = 'password' @@ -154,7 +153,7 @@ class TokenAuthTests(TestCase):      urls = 'rest_framework.tests.test_authentication'      def setUp(self): -        self.csrf_client = Client(enforce_csrf_checks=True) +        self.csrf_client = APIClient(enforce_csrf_checks=True)          self.username = 'john'          self.email = 'lennon@thebeatles.com'          self.password = 'password' @@ -172,7 +171,7 @@ class TokenAuthTests(TestCase):      def test_post_json_passing_token_auth(self):          """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""          auth = "Token " + self.key -        response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) +        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, status.HTTP_200_OK)      def test_post_form_failing_token_auth(self): @@ -182,7 +181,7 @@ class TokenAuthTests(TestCase):      def test_post_json_failing_token_auth(self):          """Ensure POSTing json over token auth without correct credentials fails""" -        response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') +        response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')          self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)      def test_token_has_auto_assigned_key_if_none_provided(self): @@ -193,33 +192,33 @@ class TokenAuthTests(TestCase):      def test_token_login_json(self):          """Ensure token login view using JSON POST works.""" -        client = Client(enforce_csrf_checks=True) +        client = APIClient(enforce_csrf_checks=True)          response = client.post('/auth-token/', -                               json.dumps({'username': self.username, 'password': self.password}), 'application/json') +                               {'username': self.username, 'password': self.password}, format='json')          self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) +        self.assertEqual(response.data['token'], self.key)      def test_token_login_json_bad_creds(self):          """Ensure token login view using JSON POST fails if bad credentials are used.""" -        client = Client(enforce_csrf_checks=True) +        client = APIClient(enforce_csrf_checks=True)          response = client.post('/auth-token/', -                               json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') +                               {'username': self.username, 'password': "badpass"}, format='json')          self.assertEqual(response.status_code, 400)      def test_token_login_json_missing_fields(self):          """Ensure token login view using JSON POST fails if missing fields.""" -        client = Client(enforce_csrf_checks=True) +        client = APIClient(enforce_csrf_checks=True)          response = client.post('/auth-token/', -                               json.dumps({'username': self.username}), 'application/json') +                               {'username': self.username}, format='json')          self.assertEqual(response.status_code, 400)      def test_token_login_form(self):          """Ensure token login view using form POST works.""" -        client = Client(enforce_csrf_checks=True) +        client = APIClient(enforce_csrf_checks=True)          response = client.post('/auth-token/',                                 {'username': self.username, 'password': self.password})          self.assertEqual(response.status_code, status.HTTP_200_OK) -        self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) +        self.assertEqual(response.data['token'], self.key)  class IncorrectCredentialsTests(TestCase): @@ -256,7 +255,7 @@ class OAuthTests(TestCase):          self.consts = consts -        self.csrf_client = Client(enforce_csrf_checks=True) +        self.csrf_client = APIClient(enforce_csrf_checks=True)          self.username = 'john'          self.email = 'lennon@thebeatles.com'          self.password = 'password' @@ -470,12 +469,13 @@ class OAuthTests(TestCase):          response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)          self.assertEqual(response.status_code, 401) +  class OAuth2Tests(TestCase):      """OAuth 2.0 authentication"""      urls = 'rest_framework.tests.test_authentication'      def setUp(self): -        self.csrf_client = Client(enforce_csrf_checks=True) +        self.csrf_client = APIClient(enforce_csrf_checks=True)          self.username = 'john'          self.email = 'lennon@thebeatles.com'          self.password = 'password' diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3..195f0ba3 100644 --- a/rest_framework/tests/test_decorators.py +++ b/rest_framework/tests/test_decorators.py @@ -1,12 +1,13 @@  from __future__ import unicode_literals  from django.test import TestCase  from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated  from rest_framework.response import Response  from rest_framework.renderers import JSONRenderer -from rest_framework.parsers import JSONParser -from rest_framework.authentication import BasicAuthentication +from rest_framework.test import APIRequestFactory  from rest_framework.throttling import UserRateThrottle -from rest_framework.permissions import IsAuthenticated  from rest_framework.views import APIView  from rest_framework.decorators import (      api_view, @@ -17,13 +18,11 @@ from rest_framework.decorators import (      permission_classes,  ) -from rest_framework.tests.utils import RequestFactory -  class DecoratorTestCase(TestCase):      def setUp(self): -        self.factory = RequestFactory() +        self.factory = APIRequestFactory()      def _finalize_response(self, request, response, *args, **kwargs):          response.request = request diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index aaed6247..c9d9e7ff 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -4,13 +4,13 @@ from decimal import Decimal  from django.db import models  from django.core.urlresolvers import reverse  from django.test import TestCase -from django.test.client import RequestFactory  from django.utils import unittest  from rest_framework import generics, serializers, status, filters  from rest_framework.compat import django_filters, patterns, url +from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory()  class FilterableItem(models.Model): diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 37734195..1550880b 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -3,12 +3,11 @@ from django.db import models  from django.shortcuts import get_object_or_404  from django.test import TestCase  from rest_framework import generics, renderers, serializers, status -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel  from rest_framework.compat import six -import json -factory = RequestFactory() +factory = APIRequestFactory()  class RootView(generics.ListCreateAPIView): @@ -71,9 +70,8 @@ class TestRootView(TestCase):          """          POST requests to ListCreateAPIView should create a new object.          """ -        content = {'text': 'foobar'} -        request = factory.post('/', json.dumps(content), -                               content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.post('/', data, format='json')          with self.assertNumQueries(1):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -85,9 +83,8 @@ class TestRootView(TestCase):          """          PUT requests to ListCreateAPIView should not be allowed          """ -        content = {'text': 'foobar'} -        request = factory.put('/', json.dumps(content), -                              content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.put('/', data, format='json')          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -148,9 +145,8 @@ class TestRootView(TestCase):          """          POST requests to create a new object should not be able to set the id.          """ -        content = {'id': 999, 'text': 'foobar'} -        request = factory.post('/', json.dumps(content), -                               content_type='application/json') +        data = {'id': 999, 'text': 'foobar'} +        request = factory.post('/', data, format='json')          with self.assertNumQueries(1):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -189,9 +185,8 @@ class TestInstanceView(TestCase):          """          POST requests to RetrieveUpdateDestroyAPIView should not be allowed          """ -        content = {'text': 'foobar'} -        request = factory.post('/', json.dumps(content), -                               content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.post('/', data, format='json')          with self.assertNumQueries(0):              response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -201,9 +196,8 @@ class TestInstanceView(TestCase):          """          PUT requests to RetrieveUpdateDestroyAPIView should update an object.          """ -        content = {'text': 'foobar'} -        request = factory.put('/1', json.dumps(content), -                              content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.put('/1', data, format='json')          with self.assertNumQueries(2):              response = self.view(request, pk='1').render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -215,9 +209,8 @@ class TestInstanceView(TestCase):          """          PATCH requests to RetrieveUpdateDestroyAPIView should update an object.          """ -        content = {'text': 'foobar'} -        request = factory.patch('/1', json.dumps(content), -                              content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.patch('/1', data, format='json')          with self.assertNumQueries(2):              response = self.view(request, pk=1).render() @@ -293,9 +286,8 @@ class TestInstanceView(TestCase):          """          PUT requests to create a new object should not be able to set the id.          """ -        content = {'id': 999, 'text': 'foobar'} -        request = factory.put('/1', json.dumps(content), -                              content_type='application/json') +        data = {'id': 999, 'text': 'foobar'} +        request = factory.put('/1', data, format='json')          with self.assertNumQueries(2):              response = self.view(request, pk=1).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -309,9 +301,8 @@ class TestInstanceView(TestCase):          if it does not currently exist.          """          self.objects.get(id=1).delete() -        content = {'text': 'foobar'} -        request = factory.put('/1', json.dumps(content), -                              content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.put('/1', data, format='json')          with self.assertNumQueries(3):              response = self.view(request, pk=1).render()          self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -324,10 +315,9 @@ class TestInstanceView(TestCase):          PUT requests to RetrieveUpdateDestroyAPIView should create an object          at the requested url if it doesn't exist.          """ -        content = {'text': 'foobar'} +        data = {'text': 'foobar'}          # pk fields can not be created on demand, only the database can set the pk for a new object -        request = factory.put('/5', json.dumps(content), -                              content_type='application/json') +        request = factory.put('/5', data, format='json')          with self.assertNumQueries(3):              response = self.view(request, pk=5).render()          self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -339,9 +329,8 @@ class TestInstanceView(TestCase):          PUT requests to RetrieveUpdateDestroyAPIView should create an object          at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.          """ -        content = {'text': 'foobar'} -        request = factory.put('/test_slug', json.dumps(content), -                              content_type='application/json') +        data = {'text': 'foobar'} +        request = factory.put('/test_slug', data, format='json')          with self.assertNumQueries(2):              response = self.slug_based_view(request, slug='test_slug').render()          self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase):          https://github.com/tomchristie/django-rest-framework/issues/285          """ -        content = {'email': 'foobar@example.com', 'content': 'foobar'} -        request = factory.post('/', json.dumps(content), -                               content_type='application/json') +        data = {'email': 'foobar@example.com', 'content': 'foobar'} +        request = factory.post('/', data, format='json')          response = self.view(request).render()          self.assertEqual(response.status_code, status.HTTP_201_CREATED)          created = self.objects.get(id=1) diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 129600cb..61e613d7 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,12 +1,15 @@  from __future__ import unicode_literals  import json  from django.test import TestCase -from django.test.client import RequestFactory  from rest_framework import generics, status, serializers  from rest_framework.compat import patterns, url -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel +from rest_framework.test import APIRequestFactory +from rest_framework.tests.models import ( +    Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, +    Album, Photo, OptionalRelationModel +) -factory = RequestFactory() +factory = APIRequestFactory()  class BlogPostCommentSerializer(serializers.ModelSerializer): @@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):  class PhotoSerializer(serializers.Serializer):      description = serializers.CharField() -    album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') +    album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title')      def restore_object(self, attrs, instance=None):          return Photo(**attrs) diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py index 7f84827f..04b89eb6 100644 --- a/rest_framework/tests/test_negotiation.py +++ b/rest_framework/tests/test_negotiation.py @@ -1,12 +1,12 @@  from __future__ import unicode_literals  from django.test import TestCase -from django.test.client import RequestFactory  from rest_framework.negotiation import DefaultContentNegotiation  from rest_framework.request import Request  from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory()  class MockJSONRenderer(BaseRenderer): diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index e538a78e..85d4640e 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -4,13 +4,13 @@ from decimal import Decimal  from django.db import models  from django.core.paginator import Paginator  from django.test import TestCase -from django.test.client import RequestFactory  from django.utils import unittest  from rest_framework import generics, status, pagination, filters, serializers  from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory()  class FilterableItem(models.Model): @@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase):          self.page = paginator.page(1)      def test_custom_pagination_serializer(self): -        request = RequestFactory().get('/foobar') +        request = APIRequestFactory().get('/foobar')          serializer = CustomPaginationSerializer(              instance=self.page,              context={'request': request} diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py index 6caaf65b..e2cca380 100644 --- a/rest_framework/tests/test_permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission  from django.db import models  from django.test import TestCase  from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory  import base64 -import json -factory = RequestFactory() +factory = APIRequestFactory()  class BasicModel(models.Model): @@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase):          BasicModel(text='foo').save()      def test_has_create_permissions(self): -        request = factory.post('/', json.dumps({'text': 'foobar'}), -                               content_type='application/json', +        request = factory.post('/', {'text': 'foobar'}, format='json',                                 HTTP_AUTHORIZATION=self.permitted_credentials)          response = root_view(request, pk=1)          self.assertEqual(response.status_code, status.HTTP_201_CREATED)      def test_has_put_permissions(self): -        request = factory.put('/1', json.dumps({'text': 'foobar'}), -                              content_type='application/json', +        request = factory.put('/1', {'text': 'foobar'}, format='json',                                HTTP_AUTHORIZATION=self.permitted_credentials)          response = instance_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase):          self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)      def test_does_not_have_create_permissions(self): -        request = factory.post('/', json.dumps({'text': 'foobar'}), -                               content_type='application/json', +        request = factory.post('/', {'text': 'foobar'}, format='json',                                 HTTP_AUTHORIZATION=self.disallowed_credentials)          response = root_view(request, pk=1)          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)      def test_does_not_have_put_permissions(self): -        request = factory.put('/1', json.dumps({'text': 'foobar'}), -                              content_type='application/json', +        request = factory.put('/1', {'text': 'foobar'}, format='json',                                HTTP_AUTHORIZATION=self.disallowed_credentials)          response = instance_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase):      def test_has_put_as_create_permissions(self):          # User only has update permissions - should be able to update an entity. -        request = factory.put('/1', json.dumps({'text': 'foobar'}), -                              content_type='application/json', +        request = factory.put('/1', {'text': 'foobar'}, format='json',                                HTTP_AUTHORIZATION=self.updateonly_credentials)          response = instance_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK)          # But if PUTing to a new entity, permission should be denied. -        request = factory.put('/2', json.dumps({'text': 'foobar'}), -                              content_type='application/json', +        request = factory.put('/2', {'text': 'foobar'}, format='json',                                HTTP_AUTHORIZATION=self.updateonly_credentials)          response = instance_view(request, pk='2')          self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)      def test_options_permitted(self): -        request = factory.options('/', content_type='application/json', +        request = factory.options('/',                                 HTTP_AUTHORIZATION=self.permitted_credentials)          response = root_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertIn('actions', response.data)          self.assertEqual(list(response.data['actions'].keys()), ['POST']) -        request = factory.options('/1', content_type='application/json', +        request = factory.options('/1',                                 HTTP_AUTHORIZATION=self.permitted_credentials)          response = instance_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase):          self.assertEqual(list(response.data['actions'].keys()), ['PUT'])      def test_options_disallowed(self): -        request = factory.options('/', content_type='application/json', +        request = factory.options('/',                                 HTTP_AUTHORIZATION=self.disallowed_credentials)          response = root_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertNotIn('actions', response.data) -        request = factory.options('/1', content_type='application/json', +        request = factory.options('/1',                                 HTTP_AUTHORIZATION=self.disallowed_credentials)          response = instance_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertNotIn('actions', response.data)      def test_options_updateonly(self): -        request = factory.options('/', content_type='application/json', +        request = factory.options('/',                                 HTTP_AUTHORIZATION=self.updateonly_credentials)          response = root_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK)          self.assertNotIn('actions', response.data) -        request = factory.options('/1', content_type='application/json', +        request = factory.options('/1',                                 HTTP_AUTHORIZATION=self.updateonly_credentials)          response = instance_view(request, pk='1')          self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index 2ca7f4f2..3c4d39af 100644 --- a/rest_framework/tests/test_relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -1,15 +1,15 @@  from __future__ import unicode_literals  from django.test import TestCase -from django.test.client import RequestFactory  from rest_framework import serializers  from rest_framework.compat import patterns, url +from rest_framework.test import APIRequestFactory  from rest_framework.tests.models import (      BlogPost,      ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,      NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource  ) -factory = RequestFactory() +factory = APIRequestFactory()  request = factory.get('/')  # Just to ensure we have a request in the serializer context diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 95b59741..df6f4aa6 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,19 +4,17 @@ from __future__ import unicode_literals  from decimal import Decimal  from django.core.cache import cache  from django.test import TestCase -from django.test.client import RequestFactory  from django.utils import unittest  from django.utils.translation import ugettext_lazy as _  from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO  from rest_framework.response import Response  from rest_framework.views import APIView  from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \      XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer  from rest_framework.parsers import YAMLParser, XMLParser  from rest_framework.settings import api_settings -from rest_framework.compat import StringIO -from rest_framework.compat import six +from rest_framework.test import APIRequestFactory  import datetime  import pickle  import re @@ -121,7 +119,7 @@ class POSTDeniedView(APIView):  class DocumentingRendererTests(TestCase):      def test_only_permitted_forms_are_displayed(self):          view = POSTDeniedView.as_view() -        request = RequestFactory().get('/') +        request = APIRequestFactory().get('/')          response = view(request).render()          self.assertNotContains(response, '>POST<')          self.assertContains(response, '>PUT<') diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index a5c5e84c..969d8024 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -5,8 +5,7 @@ from __future__ import unicode_literals  from django.contrib.auth.models import User  from django.contrib.auth import authenticate, login, logout  from django.contrib.sessions.middleware import SessionMiddleware -from django.test import TestCase, Client -from django.test.client import RequestFactory +from django.test import TestCase  from rest_framework import status  from rest_framework.authentication import SessionAuthentication  from rest_framework.compat import patterns @@ -19,12 +18,13 @@ from rest_framework.parsers import (  from rest_framework.request import Request  from rest_framework.response import Response  from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory, APIClient  from rest_framework.views import APIView  from rest_framework.compat import six  import json -factory = RequestFactory() +factory = APIRequestFactory()  class PlainTextParser(BaseParser): @@ -116,16 +116,7 @@ class TestContentParsing(TestCase):          Ensure request.DATA returns content for PUT request with form content.          """          data = {'qwerty': 'uiop'} - -        from django import VERSION - -        if VERSION >= (1, 5): -            from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart -            request = Request(factory.put('/', encode_multipart(BOUNDARY, data), -                                  content_type=MULTIPART_CONTENT)) -        else: -            request = Request(factory.put('/', data)) - +        request = Request(factory.put('/', data))          request.parsers = (FormParser(), MultiPartParser())          self.assertEqual(list(request.DATA.items()), list(data.items())) @@ -257,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase):      urls = 'rest_framework.tests.test_request'      def setUp(self): -        self.csrf_client = Client(enforce_csrf_checks=True) +        self.csrf_client = APIClient(enforce_csrf_checks=True)          self.username = 'john'          self.email = 'lennon@thebeatles.com'          self.password = 'password' diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py index 93ef5637..690a30b1 100644 --- a/rest_framework/tests/test_reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -1,10 +1,10 @@  from __future__ import unicode_literals  from django.test import TestCase -from django.test.client import RequestFactory  from rest_framework.compat import patterns, url  from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory()  def null_view(request): diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index d375f4a8..5fcccb74 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -1,15 +1,15 @@  from __future__ import unicode_literals  from django.db import models  from django.test import TestCase -from django.test.client import RequestFactory  from django.core.exceptions import ImproperlyConfigured  from rest_framework import serializers, viewsets, permissions  from rest_framework.compat import include, patterns, url  from rest_framework.decorators import link, action  from rest_framework.response import Response  from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory()  urlpatterns = patterns('',) @@ -193,6 +193,7 @@ class TestActionKeywordArgs(TestCase):              {'permission_classes': [permissions.AllowAny]}          ) +  class TestActionAppliedToExistingRoute(TestCase):      """      Ensure `@action` decorator raises an except when applied diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py new file mode 100644 index 00000000..49d45fc2 --- /dev/null +++ b/rest_framework/tests/test_testing.py @@ -0,0 +1,115 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.contrib.auth.models import User +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate + + +@api_view(['GET', 'POST']) +def view(request): +    return Response({ +        'auth': request.META.get('HTTP_AUTHORIZATION', b''), +        'user': request.user.username +    }) + + +urlpatterns = patterns('', +    url(r'^view/$', view), +) + + +class TestAPITestClient(TestCase): +    urls = 'rest_framework.tests.test_testing' + +    def setUp(self): +        self.client = APIClient() + +    def test_credentials(self): +        """ +        Setting `.credentials()` adds the required headers to each request. +        """ +        self.client.credentials(HTTP_AUTHORIZATION='example') +        for _ in range(0, 3): +            response = self.client.get('/view/') +            self.assertEqual(response.data['auth'], 'example') + +    def test_force_authenticate(self): +        """ +        Setting `.force_authenticate()` forcibly authenticates each request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        self.client.force_authenticate(user) +        response = self.client.get('/view/') +        self.assertEqual(response.data['user'], 'example') + +    def test_csrf_exempt_by_default(self): +        """ +        By default, the test client is CSRF exempt. +        """ +        User.objects.create_user('example', 'example@example.com', 'password') +        self.client.login(username='example', password='password') +        response = self.client.post('/view/') +        self.assertEqual(response.status_code, 200) + +    def test_explicitly_enforce_csrf_checks(self): +        """ +        The test client can enforce CSRF checks. +        """ +        client = APIClient(enforce_csrf_checks=True) +        User.objects.create_user('example', 'example@example.com', 'password') +        client.login(username='example', password='password') +        response = client.post('/view/') +        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} +        self.assertEqual(response.status_code, 403) +        self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): +    def test_csrf_exempt_by_default(self): +        """ +        By default, the test client is CSRF exempt. +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        factory = APIRequestFactory() +        request = factory.post('/view/') +        request.user = user +        response = view(request) +        self.assertEqual(response.status_code, 200) + +    def test_explicitly_enforce_csrf_checks(self): +        """ +        The test client can enforce CSRF checks. +        """ +        user = User.objects.create_user('example', 'example@example.com', 'password') +        factory = APIRequestFactory(enforce_csrf_checks=True) +        request = factory.post('/view/') +        request.user = user +        response = view(request) +        expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} +        self.assertEqual(response.status_code, 403) +        self.assertEqual(response.data, expected) + +    def test_invalid_format(self): +        """ +        Attempting to use a format that is not configured will raise an +        assertion error. +        """ +        factory = APIRequestFactory() +        self.assertRaises(AssertionError, factory.post, +            path='/view/', data={'example': 1}, format='xml' +        ) + +    def test_force_authenticate(self): +        """ +        Setting `force_authenticate()` forcibly authenticates the request. +        """ +        user = User.objects.create_user('example', 'example@example.com') +        factory = APIRequestFactory() +        request = factory.get('/view') +        force_authenticate(request, user=user) +        response = view(request) +        self.assertEqual(response.data['user'], 'example') diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index d35d3709..19bc691a 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals  from django.test import TestCase  from django.contrib.auth.models import User  from django.core.cache import cache -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory  from rest_framework.views import APIView  from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle  from rest_framework.response import Response @@ -41,7 +41,7 @@ class ThrottlingTests(TestCase):          Reset the cache so that no throttles will be active          """          cache.clear() -        self.factory = RequestFactory() +        self.factory = APIRequestFactory()      def test_requests_are_throttled(self):          """ @@ -173,7 +173,7 @@ class ScopedRateThrottleTests(TestCase):                  return Response('y')          self.throttle_class = XYScopedRateThrottle -        self.factory = RequestFactory() +        self.factory = APIRequestFactory()          self.x_view = XView.as_view()          self.y_view = YView.as_view()          self.unscoped_view = UnscopedView.as_view() diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a96..8132ec4c 100644 --- a/rest_framework/tests/test_urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals  from collections import namedtuple  from django.core import urlresolvers  from django.test import TestCase -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory  from rest_framework.compat import patterns, url, include  from rest_framework.urlpatterns import format_suffix_patterns @@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase):      Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.      """      def _resolve_urlpatterns(self, urlpatterns, test_paths): -        factory = RequestFactory() +        factory = APIRequestFactory()          try:              urlpatterns = format_suffix_patterns(urlpatterns)          except Exception: diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index a6ec0e99..ebfdff9c 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals  from django.db import models  from django.test import TestCase  from rest_framework import generics, serializers, status -from rest_framework.tests.utils import RequestFactory -import json +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory()  # Regression for #666 @@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase):          validation on read only fields.          """          obj = ValidationModel.objects.create(blank_validated_field='') -        request = factory.put('/', json.dumps({}), -                              content_type='application/json') +        request = factory.put('/', {}, format='json')          view = UpdateValidationModel().as_view()          response = view(request, pk=obj.pk).render()          self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 2767d24c..c0bec5ae 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -1,17 +1,15 @@  from __future__ import unicode_literals  import copy -  from django.test import TestCase -from django.test.client import RequestFactory -  from rest_framework import status  from rest_framework.decorators import api_view  from rest_framework.response import Response  from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory  from rest_framework.views import APIView -factory = RequestFactory() +factory = APIRequestFactory()  class BasicView(APIView): diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py deleted file mode 100644 index 8c87917d..00000000 --- a/rest_framework/tests/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import unicode_literals -from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory -from django.test.client import MULTIPART_CONTENT -from rest_framework.compat import urlparse - - -class RequestFactory(_RequestFactory): - -    def __init__(self, **defaults): -        super(RequestFactory, self).__init__(**defaults) - -    def patch(self, path, data={}, content_type=MULTIPART_CONTENT, -            **extra): -        "Construct a PATCH request." - -        patch_data = self._encode_data(data, content_type) - -        parsed = urlparse.urlparse(path) -        r = { -            'CONTENT_LENGTH': len(patch_data), -            'CONTENT_TYPE':   content_type, -            'PATH_INFO':      self._get_path(parsed), -            'QUERY_STRING':   parsed[4], -            'REQUEST_METHOD': 'PATCH', -            'wsgi.input':     FakePayload(patch_data), -        } -        r.update(extra) -        return self.request(**r) - - -class Client(_Client, RequestFactory): -    def patch(self, path, data={}, content_type=MULTIPART_CONTENT, -              follow=False, **extra): -        """ -        Send a resource to the server using PATCH. -        """ -        response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) -        if follow: -            response = self._handle_redirects(response, **extra) -        return response  | 
