From 7224b20d58ceee22abc987980ab646ab8cb2d8dc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 28 Jun 2013 17:17:39 +0100 Subject: Added APIRequestFactory --- rest_framework/compat.py | 38 +++++++++++++- rest_framework/renderers.py | 11 ++++ rest_framework/response.py | 2 +- rest_framework/test.py | 48 +++++++++++++++++ rest_framework/tests/test_authentication.py | 6 +-- rest_framework/tests/test_decorators.py | 11 ++-- rest_framework/tests/test_filters.py | 4 +- rest_framework/tests/test_generics.py | 60 +++++++++------------- .../tests/test_hyperlinkedserializers.py | 11 ++-- rest_framework/tests/test_negotiation.py | 4 +- rest_framework/tests/test_pagination.py | 6 +-- rest_framework/tests/test_permissions.py | 35 +++++-------- rest_framework/tests/test_relations_hyperlink.py | 4 +- rest_framework/tests/test_renderers.py | 8 ++- rest_framework/tests/test_request.py | 15 ++---- rest_framework/tests/test_reverse.py | 4 +- rest_framework/tests/test_routers.py | 5 +- rest_framework/tests/test_throttling.py | 6 +-- rest_framework/tests/test_urlpatterns.py | 4 +- rest_framework/tests/test_validation.py | 8 ++- rest_framework/tests/test_views.py | 6 +-- 21 files changed, 180 insertions(+), 116 deletions(-) create mode 100644 rest_framework/test.py (limited to 'rest_framework') diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cb122846..6f7447ad 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -8,6 +8,7 @@ from __future__ import unicode_literals import django from django.core.exceptions import ImproperlyConfigured +from django.conf import settings # Try to import six from Django, fallback to included `six`. try: @@ -83,7 +84,6 @@ def get_concrete_model(model_cls): # Django 1.5 add support for custom auth user model if django.VERSION >= (1, 5): - from django.conf import settings AUTH_USER_MODEL = settings.AUTH_USER_MODEL else: AUTH_USER_MODEL = 'auth.User' @@ -436,6 +436,42 @@ except ImportError: return force_text(url) +# RequestFactory only provide `generic` from 1.5 onwards + +from django.test.client import RequestFactory as DjangoRequestFactory +from django.test.client import FakePayload +try: + # In 1.5 the test client uses force_bytes + from django.utils.encoding import force_bytes_or_smart_bytes +except ImportError: + # In 1.3 and 1.4 the test client just uses smart_str + from django.utils.encoding import smart_str as force_bytes_or_smart_bytes + + +class RequestFactory(DjangoRequestFactory): + def generic(self, method, path, + data='', content_type='application/octet-stream', **extra): + parsed = urlparse.urlparse(path) + data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + r = { + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': force_text(parsed[4]), + 'REQUEST_METHOD': str(method), + } + if data: + r.update({ + 'CONTENT_LENGTH': len(data), + 'CONTENT_TYPE': str(content_type), + 'wsgi.input': FakePayload(data), + }) + elif django.VERSION <= (1, 4): + # For 1.3 we need an empty WSGI payload + r.update({ + 'wsgi.input': FakePayload('') + }) + r.update(extra) + return self.request(**r) + # Markdown is optional try: import markdown diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 8b2428ad..d7a7ef29 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -14,6 +14,7 @@ from django import forms from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template +from django.test.client import encode_multipart from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO from rest_framework.compat import six @@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer): response.status_code = status.HTTP_200_OK return ret + + +class MultiPartRenderer(BaseRenderer): + media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' + format = 'form' + charset = 'utf-8' + BOUNDARY = 'BoUnDaRyStRiNg' + + def render(self, data, accepted_media_type=None, renderer_context=None): + return encode_multipart(self.BOUNDARY, data) diff --git a/rest_framework/response.py b/rest_framework/response.py index 5877c8a3..c4b2aaa6 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -50,7 +50,7 @@ class Response(SimpleTemplateResponse): charset = renderer.charset content_type = self.content_type - if content_type is None and charset is not None: + if content_type is None and charset is not None and ';' not in media_type: content_type = "{0}; charset={1}".format(media_type, charset) elif content_type is None: content_type = media_type diff --git a/rest_framework/test.py b/rest_framework/test.py new file mode 100644 index 00000000..92281caf --- /dev/null +++ b/rest_framework/test.py @@ -0,0 +1,48 @@ +from rest_framework.compat import six, RequestFactory +from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +class APIRequestFactory(RequestFactory): + renderer_classes = { + 'json': JSONRenderer, + 'form': MultiPartRenderer + } + default_format = 'form' + + def __init__(self, format=None, **defaults): + self.format = format or self.default_format + super(APIRequestFactory, self).__init__(**defaults) + + def _encode_data(self, data, format, content_type): + if not data: + return ('', None) + + format = format or self.format + + if content_type is None and data is not None: + renderer = self.renderer_classes[format]() + data = renderer.render(data) + # Determine the content-type header + if ';' in renderer.media_type: + content_type = renderer.media_type + else: + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) + # Coerce text to bytes if required. + if isinstance(data, six.text_type): + data = bytes(data.encode(renderer.charset)) + + return data, content_type + + def post(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('POST', path, data, content_type, **extra) + + def put(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PUT', path, data, content_type, **extra) + + def patch(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PATCH', path, data, content_type, **extra) diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index 6a50be06..f2c51c68 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -21,14 +21,14 @@ from rest_framework.authtoken.models import Token from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -import json import base64 import time import datetime +import json -factory = RequestFactory() +factory = APIRequestFactory() class MockView(APIView): diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3..195f0ba3 100644 --- a/rest_framework/tests/test_decorators.py +++ b/rest_framework/tests/test_decorators.py @@ -1,12 +1,13 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.renderers import JSONRenderer -from rest_framework.parsers import JSONParser -from rest_framework.authentication import BasicAuthentication +from rest_framework.test import APIRequestFactory from rest_framework.throttling import UserRateThrottle -from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView from rest_framework.decorators import ( api_view, @@ -17,13 +18,11 @@ from rest_framework.decorators import ( permission_classes, ) -from rest_framework.tests.utils import RequestFactory - class DecoratorTestCase(TestCase): def setUp(self): - self.factory = RequestFactory() + self.factory = APIRequestFactory() def _finalize_response(self, request, response, *args, **kwargs): response.request = request diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index aaed6247..c9d9e7ff 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 37734195..1550880b 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -3,12 +3,11 @@ from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, renderers, serializers, status -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six -import json -factory = RequestFactory() +factory = APIRequestFactory() class RootView(generics.ListCreateAPIView): @@ -71,9 +70,8 @@ class TestRootView(TestCase): """ POST requests to ListCreateAPIView should create a new object. """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -85,9 +83,8 @@ class TestRootView(TestCase): """ PUT requests to ListCreateAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.put('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -148,9 +145,8 @@ class TestRootView(TestCase): """ POST requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -189,9 +185,8 @@ class TestInstanceView(TestCase): """ POST requests to RetrieveUpdateDestroyAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -201,9 +196,8 @@ class TestInstanceView(TestCase): """ PUT requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -215,9 +209,8 @@ class TestInstanceView(TestCase): """ PATCH requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.patch('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.patch('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() @@ -293,9 +286,8 @@ class TestInstanceView(TestCase): """ PUT requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -309,9 +301,8 @@ class TestInstanceView(TestCase): if it does not currently exist. """ self.objects.get(id=1).delete() - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -324,10 +315,9 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if it doesn't exist. """ - content = {'text': 'foobar'} + data = {'text': 'foobar'} # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', json.dumps(content), - content_type='application/json') + request = factory.put('/5', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=5).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -339,9 +329,8 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. """ - content = {'text': 'foobar'} - request = factory.put('/test_slug', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/test_slug', data, format='json') with self.assertNumQueries(2): response = self.slug_based_view(request, slug='test_slug').render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase): https://github.com/tomchristie/django-rest-framework/issues/285 """ - content = {'email': 'foobar@example.com', 'content': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'email': 'foobar@example.com', 'content': 'foobar'} + request = factory.post('/', data, format='json') response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 129600cb..61e613d7 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals import json from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import generics, status, serializers from rest_framework.compat import patterns, url -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel +from rest_framework.test import APIRequestFactory +from rest_framework.tests.models import ( + Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, + Album, Photo, OptionalRelationModel +) -factory = RequestFactory() +factory = APIRequestFactory() class BlogPostCommentSerializer(serializers.ModelSerializer): @@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer): class PhotoSerializer(serializers.Serializer): description = serializers.CharField() - album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') + album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') def restore_object(self, attrs, instance=None): return Photo(**attrs) diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py index 7f84827f..04b89eb6 100644 --- a/rest_framework/tests/test_negotiation.py +++ b/rest_framework/tests/test_negotiation.py @@ -1,12 +1,12 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.request import Request from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() class MockJSONRenderer(BaseRenderer): diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index e538a78e..85d4640e 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.paginator import Paginator from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): @@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase): self.page = paginator.page(1) def test_custom_pagination_serializer(self): - request = RequestFactory().get('/foobar') + request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=self.page, context={'request': request} diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py index 6caaf65b..e2cca380 100644 --- a/rest_framework/tests/test_permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission from django.db import models from django.test import TestCase from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory import base64 -import json -factory = RequestFactory() +factory = APIRequestFactory() class BasicModel(models.Model): @@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase): BasicModel(text='foo').save() def test_has_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_has_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_does_not_have_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_does_not_have_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase): def test_has_put_as_create_permissions(self): # User only has update permissions - should be able to update an entity. - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) # But if PUTing to a new entity, permission should be denied. - request = factory.put('/2', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/2', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_options_permitted(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn('actions', response.data) self.assertEqual(list(response.data['actions'].keys()), ['POST']) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(list(response.data['actions'].keys()), ['PUT']) def test_options_disallowed(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) def test_options_updateonly(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.updateonly_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index 2ca7f4f2..3c4d39af 100644 --- a/rest_framework/tests/test_relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import ( BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) -factory = RequestFactory() +factory = APIRequestFactory() request = factory.get('/') # Just to ensure we have a request in the serializer context diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 95b59741..df6f4aa6 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,19 +4,17 @@ from __future__ import unicode_literals from decimal import Decimal from django.core.cache import cache from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings -from rest_framework.compat import StringIO -from rest_framework.compat import six +from rest_framework.test import APIRequestFactory import datetime import pickle import re @@ -121,7 +119,7 @@ class POSTDeniedView(APIView): class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): view = POSTDeniedView.as_view() - request = RequestFactory().get('/') + request = APIRequestFactory().get('/') response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index a5c5e84c..8d64d79f 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -6,7 +6,6 @@ from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware from django.test import TestCase, Client -from django.test.client import RequestFactory from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -19,12 +18,13 @@ from rest_framework.parsers import ( from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from rest_framework.compat import six import json -factory = RequestFactory() +factory = APIRequestFactory() class PlainTextParser(BaseParser): @@ -116,16 +116,7 @@ class TestContentParsing(TestCase): Ensure request.DATA returns content for PUT request with form content. """ data = {'qwerty': 'uiop'} - - from django import VERSION - - if VERSION >= (1, 5): - from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart - request = Request(factory.put('/', encode_multipart(BOUNDARY, data), - content_type=MULTIPART_CONTENT)) - else: - request = Request(factory.put('/', data)) - + request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.DATA.items()), list(data.items())) diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py index 93ef5637..690a30b1 100644 --- a/rest_framework/tests/test_reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.compat import patterns, url from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() def null_view(request): diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index d375f4a8..5fcccb74 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase -from django.test.client import RequestFactory from django.core.exceptions import ImproperlyConfigured from rest_framework import serializers, viewsets, permissions from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action from rest_framework.response import Response from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() urlpatterns = patterns('',) @@ -193,6 +193,7 @@ class TestActionKeywordArgs(TestCase): {'permission_classes': [permissions.AllowAny]} ) + class TestActionAppliedToExistingRoute(TestCase): """ Ensure `@action` decorator raises an except when applied diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index d35d3709..19bc691a 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -41,7 +41,7 @@ class ThrottlingTests(TestCase): Reset the cache so that no throttles will be active """ cache.clear() - self.factory = RequestFactory() + self.factory = APIRequestFactory() def test_requests_are_throttled(self): """ @@ -173,7 +173,7 @@ class ScopedRateThrottleTests(TestCase): return Response('y') self.throttle_class = XYScopedRateThrottle - self.factory = RequestFactory() + self.factory = APIRequestFactory() self.x_view = XView.as_view() self.y_view = YView.as_view() self.unscoped_view = UnscopedView.as_view() diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a96..8132ec4c 100644 --- a/rest_framework/tests/test_urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from collections import namedtuple from django.core import urlresolvers from django.test import TestCase -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.compat import patterns, url, include from rest_framework.urlpatterns import format_suffix_patterns @@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase): Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. """ def _resolve_urlpatterns(self, urlpatterns, test_paths): - factory = RequestFactory() + factory = APIRequestFactory() try: urlpatterns = format_suffix_patterns(urlpatterns) except Exception: diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index a6ec0e99..ebfdff9c 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status -from rest_framework.tests.utils import RequestFactory -import json +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() # Regression for #666 @@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase): validation on read only fields. """ obj = ValidationModel.objects.create(blank_validated_field='') - request = factory.put('/', json.dumps({}), - content_type='application/json') + request = factory.put('/', {}, format='json') view = UpdateValidationModel().as_view() response = view(request, pk=obj.pk).render() self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 2767d24c..c0bec5ae 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -1,17 +1,15 @@ from __future__ import unicode_literals import copy - from django.test import TestCase -from django.test.client import RequestFactory - from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -factory = RequestFactory() +factory = APIRequestFactory() class BasicView(APIView): -- cgit v1.2.3 From f585480ee10f4b5e61db4ac343b1d2af25d2de97 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 28 Jun 2013 17:50:30 +0100 Subject: Added APIClient --- rest_framework/test.py | 81 ++++++++++++++++++++++++----- rest_framework/tests/test_authentication.py | 44 ++++++++-------- rest_framework/tests/test_request.py | 6 +-- 3 files changed, 93 insertions(+), 38 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index 92281caf..9fce2c08 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,39 +1,54 @@ -from rest_framework.compat import six, RequestFactory +# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# to make it harder for the user to import the wrong thing without realizing. +from django.conf import settings +from django.test.client import Client as DjangoClient +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer -class APIRequestFactory(RequestFactory): +class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, 'form': MultiPartRenderer } default_format = 'form' - def __init__(self, format=None, **defaults): - self.format = format or self.default_format - super(APIRequestFactory, self).__init__(**defaults) + def _encode_data(self, data, format=None, content_type=None): + """ + Encode the data returning a two tuple of (bytes, content_type) + """ - def _encode_data(self, data, format, content_type): if not data: return ('', None) - format = format or self.format + assert format is None or content_type is None, ( + 'You may not set both `format` and `content_type`.' + ) - if content_type is None and data is not None: + if content_type: + # Content type specified explicitly, treat data as a raw bytestring + ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + + else: + # Use format and render the data into a bytestring + format = format or self.default_format renderer = self.renderer_classes[format]() - data = renderer.render(data) - # Determine the content-type header + ret = renderer.render(data) + + # Determine the content-type header from the renderer if ';' in renderer.media_type: content_type = renderer.media_type else: content_type = "{0}; charset={1}".format( renderer.media_type, renderer.charset ) + # Coerce text to bytes if required. - if isinstance(data, six.text_type): - data = bytes(data.encode(renderer.charset)) + if isinstance(ret, six.text_type): + ret = bytes(ret.encode(renderer.charset)) - return data, content_type + return ret, content_type def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) @@ -46,3 +61,43 @@ class APIRequestFactory(RequestFactory): def patch(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('PATCH', path, data, content_type, **extra) + + def delete(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('DELETE', path, data, content_type, **extra) + + def options(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('OPTIONS', path, data, content_type, **extra) + + +class APIClient(APIRequestFactory, DjangoClient): + def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index f2c51c68..a44813b6 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse -from django.test import Client, TestCase +from django.test import TestCase from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions @@ -21,12 +21,11 @@ from rest_framework.authtoken.models import Token from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider -from rest_framework.test import APIRequestFactory +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView import base64 import time import datetime -import json factory = APIRequestFactory() @@ -68,7 +67,7 @@ class BasicAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -87,7 +86,7 @@ class BasicAuthTests(TestCase): credentials = ('%s:%s' % (self.username, self.password)) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) auth = 'Basic %s' % base64_credentials - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): @@ -97,7 +96,7 @@ class BasicAuthTests(TestCase): def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') @@ -107,8 +106,8 @@ class SessionAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) - self.non_csrf_client = Client(enforce_csrf_checks=False) + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.non_csrf_client = APIClient(enforce_csrf_checks=False) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -154,7 +153,7 @@ class TokenAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -172,7 +171,7 @@ class TokenAuthTests(TestCase): def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): @@ -182,7 +181,7 @@ class TokenAuthTests(TestCase): def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): @@ -193,33 +192,33 @@ class TokenAuthTests(TestCase): def test_token_login_json(self): """Ensure token login view using JSON POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': self.password}), 'application/json') + {'username': self.username, 'password': self.password}, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) def test_token_login_json_bad_creds(self): """Ensure token login view using JSON POST fails if bad credentials are used.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') + {'username': self.username, 'password': "badpass"}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_json_missing_fields(self): """Ensure token login view using JSON POST fails if missing fields.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username}), 'application/json') + {'username': self.username}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_form(self): """Ensure token login view using form POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) class IncorrectCredentialsTests(TestCase): @@ -256,7 +255,7 @@ class OAuthTests(TestCase): self.consts = consts - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -470,12 +469,13 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) + class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index 8d64d79f..969d8024 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware -from django.test import TestCase, Client +from django.test import TestCase from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -18,7 +18,7 @@ from rest_framework.parsers import ( from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView from rest_framework.compat import six import json @@ -248,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase): urls = 'rest_framework.tests.test_request' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' -- cgit v1.2.3 From 90bc07f3f160485001ea329e5f69f7e521d14ec9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 08:05:08 +0100 Subject: Addeded 'APITestClient.credentials()' --- rest_framework/test.py | 29 +++++++++++++++++++++++++++++ rest_framework/tests/test_testing.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 rest_framework/tests/test_testing.py (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index 9fce2c08..8115fa0d 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,5 +1,8 @@ +# -- coding: utf-8 -- + # Note that we use `DjangoRequestFactory` and `DjangoClient` names in order # to make it harder for the user to import the wrong thing without realizing. +from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from rest_framework.compat import RequestFactory as DjangoRequestFactory @@ -72,31 +75,57 @@ class APIRequestFactory(DjangoRequestFactory): class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, *args, **kwargs): + self._credentials = {} + super(APIClient, self).__init__(*args, **kwargs) + + def credentials(self, **kwargs): + self._credentials = kwargs + + def get(self, path, data={}, follow=False, **extra): + extra.update(self._credentials) + response = super(APIClient, self).get(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def head(self, path, data={}, follow=False, **extra): + extra.update(self._credentials) + response = super(APIClient, self).head(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py new file mode 100644 index 00000000..71dacd38 --- /dev/null +++ b/rest_framework/tests/test_testing.py @@ -0,0 +1,32 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient + + +@api_view(['GET']) +def mirror(request): + return Response({ + 'auth': request.META.get('HTTP_AUTHORIZATION', b'') + }) + + +urlpatterns = patterns('', + url(r'^view/$', mirror), +) + + +class CheckTestClient(TestCase): + urls = 'rest_framework.tests.test_testing' + + def setUp(self): + self.client = APIClient() + + def test_credentials(self): + self.client.credentials(HTTP_AUTHORIZATION='example') + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') -- cgit v1.2.3 From f7db06953bd8ad7f5e0211f49a04e8d5bb634380 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 08:06:11 +0100 Subject: Remove unneeded tests.utils, superseeded by APIRequestFactory, APIClient --- rest_framework/tests/utils.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) delete mode 100644 rest_framework/tests/utils.py (limited to 'rest_framework') diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py deleted file mode 100644 index 8c87917d..00000000 --- a/rest_framework/tests/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import unicode_literals -from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory -from django.test.client import MULTIPART_CONTENT -from rest_framework.compat import urlparse - - -class RequestFactory(_RequestFactory): - - def __init__(self, **defaults): - super(RequestFactory, self).__init__(**defaults) - - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - **extra): - "Construct a PATCH request." - - patch_data = self._encode_data(data, content_type) - - parsed = urlparse.urlparse(path) - r = { - 'CONTENT_LENGTH': len(patch_data), - 'CONTENT_TYPE': content_type, - 'PATH_INFO': self._get_path(parsed), - 'QUERY_STRING': parsed[4], - 'REQUEST_METHOD': 'PATCH', - 'wsgi.input': FakePayload(patch_data), - } - r.update(extra) - return self.request(**r) - - -class Client(_Client, RequestFactory): - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - follow=False, **extra): - """ - Send a resource to the server using PATCH. - """ - response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response -- cgit v1.2.3 From 35022ca9213939a2f40c82facffa908a818efe0b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 08:14:05 +0100 Subject: Refactor SessionAuthentication slightly --- rest_framework/authentication.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 10298027..b42162dd 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -26,6 +26,12 @@ def get_authorization_header(request): return auth +class CSRFCheck(CsrfViewMiddleware): + def _reject(self, request, reason): + # Return the failure reason instead of an HttpResponse + return reason + + class BaseAuthentication(object): """ All authentication classes should extend BaseAuthentication. @@ -110,20 +116,20 @@ class SessionAuthentication(BaseAuthentication): if not user or not user.is_active: return None - # Enforce CSRF validation for session based authentication. - class CSRFCheck(CsrfViewMiddleware): - def _reject(self, request, reason): - # Return the failure reason instead of an HttpResponse - return reason + self.enforce_csrf(http_request) + + # CSRF passed with authenticated user + return (user, None) - reason = CSRFCheck().process_view(http_request, None, (), {}) + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication. + """ + reason = CSRFCheck().process_view(request, None, (), {}) if reason: # CSRF failed, bail with explicit error message raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) - # CSRF passed with authenticated user - return (user, None) - class TokenAuthentication(BaseAuthentication): """ -- cgit v1.2.3 From 664f8c63655770cd90bdbd510b315bcd045b380a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 21:02:58 +0100 Subject: Added APIClient.authenticate() --- rest_framework/renderers.py | 2 +- rest_framework/request.py | 20 +++++++++++++++++ rest_framework/test.py | 39 +++++++++++++++++++++++++++++---- rest_framework/tests/test_testing.py | 42 +++++++++++++++++++++++++++++++++--- 4 files changed, 95 insertions(+), 8 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index d7a7ef29..3a03ca33 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer): class MultiPartRenderer(BaseRenderer): media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' - format = 'form' + format = 'multipart' charset = 'utf-8' BOUNDARY = 'BoUnDaRyStRiNg' diff --git a/rest_framework/request.py b/rest_framework/request.py index 0d88ebc7..919716f4 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -64,6 +64,20 @@ def clone_request(request, method): return ret +class ForcedAuthentication(object): + """ + This authentication class is used if the test client or request factory + forcibly authenticated the request. + """ + + def __init__(self, force_user, force_token): + self.force_user = force_user + self.force_token = force_token + + def authenticate(self, request): + return (self.force_user, self.force_token) + + class Request(object): """ Wrapper allowing to enhance a standard `HttpRequest` instance. @@ -98,6 +112,12 @@ class Request(object): self.parser_context['request'] = self self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET + force_user = getattr(request, '_force_auth_user', None) + force_token = getattr(request, '_force_auth_token', None) + if (force_user is not None or force_token is not None): + forced_auth = ForcedAuthentication(force_user, force_token) + self.authenticators = (forced_auth,) + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/test.py b/rest_framework/test.py index 8115fa0d..08de2297 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer @@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, - 'form': MultiPartRenderer + 'multipart': MultiPartRenderer } - default_format = 'form' + default_format = 'multipart' def _encode_data(self, data, format=None, content_type=None): """ @@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory): return self.generic('OPTIONS', path, data, content_type, **extra) -class APIClient(APIRequestFactory, DjangoClient): +class ForceAuthClientHandler(ClientHandler): + """ + A patched version of ClientHandler that can enforce authentication + on the outgoing requests. + """ + def __init__(self, *args, **kwargs): + self._force_auth_user = None + self._force_auth_token = None + super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + + def force_authenticate(self, user=None, token=None): + self._force_auth_user = user + self._force_auth_token = token + + def get_response(self, request): + # This is the simplest place we can hook into to patch the + # request object. + request._force_auth_user = self._force_auth_user + request._force_auth_token = self._force_auth_token + return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, enforce_csrf_checks=False, **defaults): + # Note that our super call skips Client.__init__ + # since we don't need to instantiate a regular ClientHandler + super(DjangoClient, self).__init__(**defaults) + self.handler = ForceAuthClientHandler(enforce_csrf_checks) + self.exc_info = None self._credentials = {} - super(APIClient, self).__init__(*args, **kwargs) def credentials(self, **kwargs): self._credentials = kwargs + def authenticate(self, user=None, token=None): + self.handler.force_authenticate(user, token) + def get(self, path, data={}, follow=False, **extra): extra.update(self._credentials) response = super(APIClient, self).get(path, data=data, **extra) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 71dacd38..a8398b9a 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -1,6 +1,7 @@ # -- coding: utf-8 -- from __future__ import unicode_literals +from django.contrib.auth.models import User from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.decorators import api_view @@ -8,10 +9,11 @@ from rest_framework.response import Response from rest_framework.test import APIClient -@api_view(['GET']) +@api_view(['GET', 'POST']) def mirror(request): return Response({ - 'auth': request.META.get('HTTP_AUTHORIZATION', b'') + 'auth': request.META.get('HTTP_AUTHORIZATION', b''), + 'user': request.user.username }) @@ -27,6 +29,40 @@ class CheckTestClient(TestCase): self.client = APIClient() def test_credentials(self): + """ + Setting `.credentials()` adds the required headers to each request. + """ self.client.credentials(HTTP_AUTHORIZATION='example') + for _ in range(0, 3): + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') + + def test_authenticate(self): + """ + Setting `.authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.authenticate(user) response = self.client.get('/view/') - self.assertEqual(response.data['auth'], 'example') + self.assertEqual(response.data['user'], 'example') + + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + User.objects.create_user('example', 'example@example.com', 'password') + self.client.login(username='example', password='password') + response = self.client.post('/view/') + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + client = APIClient(enforce_csrf_checks=True) + User.objects.create_user('example', 'example@example.com', 'password') + client.login(username='example', password='password') + response = client.post('/view/') + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) -- cgit v1.2.3 From ab799ccc3ee473de61ec35c6f745c6952752c522 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 21:34:47 +0100 Subject: Simplify APIClient implementation --- rest_framework/authentication.py | 6 ++-- rest_framework/test.py | 68 +++++++++------------------------------- 2 files changed, 17 insertions(+), 57 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index b42162dd..cf001a24 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -109,14 +109,14 @@ class SessionAuthentication(BaseAuthentication): """ # Get the underlying HttpRequest object - http_request = request._request - user = getattr(http_request, 'user', None) + request = request._request + user = getattr(request, 'user', None) # Unauthenticated, CSRF validation not required if not user or not user.is_active: return None - self.enforce_csrf(http_request) + self.enforce_csrf(request) # CSRF passed with authenticated user return (user, None) diff --git a/rest_framework/test.py b/rest_framework/test.py index 08de2297..2e9cfe09 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -86,10 +86,6 @@ class ForceAuthClientHandler(ClientHandler): self._force_auth_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) - def force_authenticate(self, user=None, token=None): - self._force_auth_user = user - self._force_auth_token = token - def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. @@ -108,56 +104,20 @@ class APIClient(APIRequestFactory, DjangoClient): self._credentials = {} def credentials(self, **kwargs): + """ + Sets headers that will be used on every outgoing request. + """ self._credentials = kwargs def authenticate(self, user=None, token=None): - self.handler.force_authenticate(user, token) - - def get(self, path, data={}, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).get(path, data=data, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def head(self, path, data={}, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).head(path, data=data, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response + """ + Forcibly authenticates outgoing requests with the given + user and/or token. + """ + self.handler._force_auth_user = user + self.handler._force_auth_token = token + + def request(self, **request): + # Ensure that any credentials set get added to every request. + request.update(self._credentials) + return super(APIClient, self).request(**request) -- cgit v1.2.3 From c9485c783a555516e41068996258f4c5e383523b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 22:53:15 +0100 Subject: Rename to force_authenticate --- rest_framework/test.py | 2 +- rest_framework/tests/test_testing.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index 2e9cfe09..2f658a56 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -109,7 +109,7 @@ class APIClient(APIRequestFactory, DjangoClient): """ self._credentials = kwargs - def authenticate(self, user=None, token=None): + def force_authenticate(self, user=None, token=None): """ Forcibly authenticates outgoing requests with the given user and/or token. diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index a8398b9a..3706f38c 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -37,12 +37,12 @@ class CheckTestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['auth'], 'example') - def test_authenticate(self): + def test_force_authenticate(self): """ - Setting `.authenticate()` forcibly authenticates each request. + Setting `.force_authenticate()` forcibly authenticates each request. """ user = User.objects.create_user('example', 'example@example.com') - self.client.authenticate(user) + self.client.force_authenticate(user) response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') -- cgit v1.2.3 From 0a722de171b0e80ac26d8c77b8051a4170bdb4c6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Jul 2013 13:59:05 +0100 Subject: Complete testing docs --- rest_framework/response.py | 2 +- rest_framework/settings.py | 8 +++++ rest_framework/test.py | 70 ++++++++++++++++++++++-------------- rest_framework/tests/test_testing.py | 55 +++++++++++++++++++++++++--- 4 files changed, 103 insertions(+), 32 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/response.py b/rest_framework/response.py index c4b2aaa6..5877c8a3 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -50,7 +50,7 @@ class Response(SimpleTemplateResponse): charset = renderer.charset content_type = self.content_type - if content_type is None and charset is not None and ';' not in media_type: + if content_type is None and charset is not None: content_type = "{0}; charset={1}".format(media_type, charset) elif content_type is None: content_type = media_type diff --git a/rest_framework/settings.py b/rest_framework/settings.py index beb511ac..8fd177d5 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -73,6 +73,13 @@ DEFAULTS = { 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # Testing + 'TEST_REQUEST_RENDERER_CLASSES': ( + 'rest_framework.renderers.MultiPartRenderer', + 'rest_framework.renderers.JSONRenderer' + ), + 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', + # Browser enhancements 'FORM_METHOD_OVERRIDE': '_method', 'FORM_CONTENT_OVERRIDE': '_content', @@ -115,6 +122,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', + 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/rest_framework/test.py b/rest_framework/test.py index 2f658a56..29d017ee 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,22 +1,31 @@ # -- coding: utf-8 -- -# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler +from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six -from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +def force_authenticate(request, user=None, token=None): + request._force_auth_user = user + request._force_auth_token = token class APIRequestFactory(DjangoRequestFactory): - renderer_classes = { - 'json': JSONRenderer, - 'multipart': MultiPartRenderer - } - default_format = 'multipart' + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super(APIRequestFactory, self).__init__(**defaults) def _encode_data(self, data, format=None, content_type=None): """ @@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory): ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) else: - # Use format and render the data into a bytestring format = format or self.default_format + + assert format in self.renderer_classes, ("Invalid format '{0}'. " + "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " + "to enable extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) + ) + ) + + # Use format and render the data into a bytestring renderer = self.renderer_classes[format]() ret = renderer.render(data) # Determine the content-type header from the renderer - if ';' in renderer.media_type: - content_type = renderer.media_type - else: - content_type = "{0}; charset={1}".format( - renderer.media_type, renderer.charset - ) + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) # Coerce text to bytes if required. if isinstance(ret, six.text_type): @@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory): data, content_type = self._encode_data(data, format, content_type) return self.generic('OPTIONS', path, data, content_type, **extra) + def request(self, **kwargs): + request = super(APIRequestFactory, self).request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + class ForceAuthClientHandler(ClientHandler): """ @@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler): """ def __init__(self, *args, **kwargs): - self._force_auth_user = None - self._force_auth_token = None + self._force_user = None + self._force_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. - request._force_auth_user = self._force_auth_user - request._force_auth_token = self._force_auth_token + force_authenticate(request, self._force_user, self._force_token) return super(ForceAuthClientHandler, self).get_response(request) class APIClient(APIRequestFactory, DjangoClient): def __init__(self, enforce_csrf_checks=False, **defaults): - # Note that our super call skips Client.__init__ - # since we don't need to instantiate a regular ClientHandler - super(DjangoClient, self).__init__(**defaults) + super(APIClient, self).__init__(**defaults) self.handler = ForceAuthClientHandler(enforce_csrf_checks) - self.exc_info = None self._credentials = {} def credentials(self, **kwargs): @@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient): Forcibly authenticates outgoing requests with the given user and/or token. """ - self.handler._force_auth_user = user - self.handler._force_auth_token = token + self.handler._force_user = user + self.handler._force_token = token - def request(self, **request): + def request(self, **kwargs): # Ensure that any credentials set get added to every request. - request.update(self._credentials) - return super(APIClient, self).request(**request) + kwargs.update(self._credentials) + return super(APIClient, self).request(**kwargs) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 3706f38c..49d45fc2 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -6,11 +6,11 @@ from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.decorators import api_view from rest_framework.response import Response -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate @api_view(['GET', 'POST']) -def mirror(request): +def view(request): return Response({ 'auth': request.META.get('HTTP_AUTHORIZATION', b''), 'user': request.user.username @@ -18,11 +18,11 @@ def mirror(request): urlpatterns = patterns('', - url(r'^view/$', mirror), + url(r'^view/$', view), ) -class CheckTestClient(TestCase): +class TestAPITestClient(TestCase): urls = 'rest_framework.tests.test_testing' def setUp(self): @@ -66,3 +66,50 @@ class CheckTestClient(TestCase): expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} self.assertEqual(response.status_code, 403) self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory() + request = factory.post('/view/') + request.user = user + response = view(request) + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory(enforce_csrf_checks=True) + request = factory.post('/view/') + request.user = user + response = view(request) + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + def test_invalid_format(self): + """ + Attempting to use a format that is not configured will raise an + assertion error. + """ + factory = APIRequestFactory() + self.assertRaises(AssertionError, factory.post, + path='/view/', data={'example': 1}, format='xml' + ) + + def test_force_authenticate(self): + """ + Setting `force_authenticate()` forcibly authenticates the request. + """ + user = User.objects.create_user('example', 'example@example.com') + factory = APIRequestFactory() + request = factory.get('/view') + force_authenticate(request, user=user) + response = view(request) + self.assertEqual(response.data['user'], 'example') -- cgit v1.2.3 From 7d43f41e4aa50c4258ec1d7b63dd62a01440fa9d Mon Sep 17 00:00:00 2001 From: Andy Freeland Date: Thu, 4 Jul 2013 01:51:24 -0400 Subject: Remove 'Hold down "Control" ...' message from help_text When getting the help_text from a field where `many=True`, Django appends 'Hold down "Control", or "Command" on a Mac, to select more than one.' to the help_text. This makes some sense in Django's ModelForms, but no sense in the API. --- rest_framework/fields.py | 15 ++++++++++++++- rest_framework/tests/models.py | 2 +- rest_framework/tests/test_serializer.py | 12 ++++++++++++ 3 files changed, 27 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 35848b4c..1a0ad3b9 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -100,6 +100,19 @@ def humanize_strptime(format_string): return format_string +def strip_multiple_choice_msg(help_text): + """ + Remove the 'Hold down "control" ...' message that is enforced in select + multiple fields. + + See https://code.djangoproject.com/ticket/9321 + """ + multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') + multiple_choice_msg = unicode(multiple_choice_msg) + + return help_text.replace(multiple_choice_msg, '') + + class Field(object): read_only = True creation_counter = 0 @@ -122,7 +135,7 @@ class Field(object): self.label = smart_text(label) if help_text is not None: - self.help_text = smart_text(help_text) + self.help_text = strip_multiple_choice_msg(smart_text(help_text)) def initialize(self, parent, field_name): """ diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index e2d4eacd..1598ecd9 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -52,7 +52,7 @@ class CallableDefaultValueModel(RESTFrameworkModel): class ManyToManyModel(RESTFrameworkModel): - rel = models.ManyToManyField(Anchor) + rel = models.ManyToManyField(Anchor, help_text='Some help text.') class ReadOnlyManyToManyModel(RESTFrameworkModel): diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 8b87a084..6c18f15c 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1376,6 +1376,18 @@ class FieldLabelTest(TestCase): self.assertEqual('Label', relations.HyperlinkedRelatedField(view_name='fake', label='Label', help_text='Help', many=True).label) +# Test for issue #961 + +class ManyFieldHelpTextTest(TestCase): + def test_help_text_no_hold_down_control_msg(self): + """ + Validate that help_text doesn't contain the 'Hold down "Control" ...' + message that Django appends to choice fields. + """ + rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text) + self.assertEqual('Some help text.', unicode(rel_field.help_text)) + + class AttributeMappingOnAutogeneratedFieldsTests(TestCase): def setUp(self): -- cgit v1.2.3 From 8f79caf9d1bd4a3de8371c61f24dcf513454f06b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 5 Jul 2013 09:07:18 +0100 Subject: Use 'force_text', not 'unicode', for compat across python version --- rest_framework/fields.py | 6 +++--- rest_framework/tests/test_serializer.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1a0ad3b9..6e5ee470 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -102,13 +102,13 @@ def humanize_strptime(format_string): def strip_multiple_choice_msg(help_text): """ - Remove the 'Hold down "control" ...' message that is enforced in select - multiple fields. + Remove the 'Hold down "control" ...' message that is Django enforces in + select multiple fields on ModelForms. (Required for 1.5 and earlier) See https://code.djangoproject.com/ticket/9321 """ multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.') - multiple_choice_msg = unicode(multiple_choice_msg) + multiple_choice_msg = force_text(multiple_choice_msg) return help_text.replace(multiple_choice_msg, '') diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 6c18f15c..38acc354 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -1385,7 +1385,7 @@ class ManyFieldHelpTextTest(TestCase): message that Django appends to choice fields. """ rel_field = fields.Field(help_text=ManyToManyModel._meta.get_field('rel').help_text) - self.assertEqual('Some help text.', unicode(rel_field.help_text)) + self.assertEqual('Some help text.', rel_field.help_text) class AttributeMappingOnAutogeneratedFieldsTests(TestCase): -- cgit v1.2.3 From db863be10cca5d43e37cace88fd2a500f6ee96f8 Mon Sep 17 00:00:00 2001 From: Gertjan Oude Lohuis Date: Tue, 9 Jul 2013 12:19:13 +0200 Subject: Add an ModelAdmin for easy management of Tokens --- rest_framework/authtoken/admin.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 rest_framework/authtoken/admin.py (limited to 'rest_framework') diff --git a/rest_framework/authtoken/admin.py b/rest_framework/authtoken/admin.py new file mode 100644 index 00000000..0d948160 --- /dev/null +++ b/rest_framework/authtoken/admin.py @@ -0,0 +1,11 @@ +from django.contrib import admin +from .models import Token + + +class TokenAdmin(admin.ModelAdmin): + list_display = ('key', 'user', 'created') + fields = ('user',) + ordering = ('-created',) + + +admin.site.register(Token, TokenAdmin) -- cgit v1.2.3 From 3032a06c9bd8cc98387b14f7feceb1f5d76041fd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 9 Jul 2013 13:12:28 +0100 Subject: Use absolute import style --- rest_framework/authtoken/admin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/authtoken/admin.py b/rest_framework/authtoken/admin.py index 0d948160..ec28eb1c 100644 --- a/rest_framework/authtoken/admin.py +++ b/rest_framework/authtoken/admin.py @@ -1,5 +1,5 @@ from django.contrib import admin -from .models import Token +from rest_framework.authtoken.models import Token class TokenAdmin(admin.ModelAdmin): -- cgit v1.2.3 From ae63c49777f4d5b766b85a4b28f6328bd6f9516a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Jul 2013 11:38:38 +0100 Subject: Added test case classes --- rest_framework/test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index 29d017ee..ed436976 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -6,6 +6,7 @@ from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler +from django.test import testcases from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six @@ -137,3 +138,19 @@ class APIClient(APIRequestFactory, DjangoClient): # Ensure that any credentials set get added to every request. kwargs.update(self._credentials) return super(APIClient, self).request(**kwargs) + + +class APISimpleTestCase(testcases.SimpleTestCase): + client_class = APIClient + + +class APITransactionTestCase(testcases.TransactionTestCase): + client_class = APIClient + + +class APITestCase(testcases.TestCase): + client_class = APIClient + + +class APILiveServerTestCase(testcases.LiveServerTestCase): + client_class = APIClient -- cgit v1.2.3 From 82145e2b06e402c9740ee970c74456a59683667a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Jul 2013 21:54:13 +0100 Subject: Only include APISimpleTestCase and APILiveServerTestCase from django 1.4 onwards --- rest_framework/test.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/test.py b/rest_framework/test.py index ed436976..a18f5a29 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -3,6 +3,7 @@ # Note that we import as `DjangoRequestFactory` and `DjangoClient` in order # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import django from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler @@ -140,10 +141,6 @@ class APIClient(APIRequestFactory, DjangoClient): return super(APIClient, self).request(**kwargs) -class APISimpleTestCase(testcases.SimpleTestCase): - client_class = APIClient - - class APITransactionTestCase(testcases.TransactionTestCase): client_class = APIClient @@ -152,5 +149,9 @@ class APITestCase(testcases.TestCase): client_class = APIClient -class APILiveServerTestCase(testcases.LiveServerTestCase): - client_class = APIClient +if django.VERSION >= (1, 4): + class APISimpleTestCase(testcases.SimpleTestCase): + client_class = APIClient + + class APILiveServerTestCase(testcases.LiveServerTestCase): + client_class = APIClient -- cgit v1.2.3 From 2e18fbe373b3216b451ac163ec822e04b1e76120 Mon Sep 17 00:00:00 2001 From: Pavel Zinovkin Date: Sun, 21 Jul 2013 17:03:58 +0400 Subject: Updated EmailField error message. This one already available in django translations. https://github.com/django/django/blob/master/django/conf/locale/ru/LC_MESSAGES/django.po#L343--- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 6e5ee470..f9931887 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -512,7 +512,7 @@ class EmailField(CharField): form_field_class = forms.EmailField default_error_messages = { - 'invalid': _('Enter a valid e-mail address.'), + 'invalid': _('Enter a valid email address.'), } default_validators = [validators.validate_email] -- cgit v1.2.3 From b6d6feaa0210bef225f2dad0d5765b32ecfc8cd0 Mon Sep 17 00:00:00 2001 From: Pavel Zinovkin Date: Sun, 21 Jul 2013 22:43:19 +0400 Subject: Fixed test --- rest_framework/tests/test_serializer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_serializer.py b/rest_framework/tests/test_serializer.py index 38acc354..c2497660 100644 --- a/rest_framework/tests/test_serializer.py +++ b/rest_framework/tests/test_serializer.py @@ -494,7 +494,7 @@ class CustomValidationTests(TestCase): } serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) self.assertFalse(serializer.is_valid()) - self.assertEqual(serializer.errors, {'email': ['Enter a valid e-mail address.']}) + self.assertEqual(serializer.errors, {'email': ['Enter a valid email address.']}) class PositiveIntegerAsChoiceTests(TestCase): -- cgit v1.2.3 From 103fed966751680d4ac3dc8125b6807e34bb436a Mon Sep 17 00:00:00 2001 From: Kevin Brown Date: Fri, 26 Jul 2013 10:59:51 -0400 Subject: Fixed reversed arguments in assertion --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 023f7ccf..682a99a4 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -690,7 +690,7 @@ class ModelSerializer(Serializer): assert field_name in ret, \ "Noexistant field '%s' specified in `read_only_fields` " \ "on serializer '%s'." % \ - (self.__class__.__name__, field_name) + (field_name, self.__class__.__name__) ret[field_name].read_only = True return ret -- cgit v1.2.3 From edd696292ca50cb9c10b50b0124a79135d007ea4 Mon Sep 17 00:00:00 2001 From: Dan Stephenson Date: Sat, 10 Aug 2013 01:12:36 +0100 Subject: Spelling correction on read_only_fields err msg --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 682a99a4..25825df4 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -683,7 +683,7 @@ class ModelSerializer(Serializer): # in the `read_only_fields` option for field_name in self.opts.read_only_fields: assert field_name not in self.base_fields.keys(), \ - "field '%s' on serializer '%s' specfied in " \ + "field '%s' on serializer '%s' specified in " \ "`read_only_fields`, but also added " \ "as an explict field. Remove it from `read_only_fields`." % \ (field_name, self.__class__.__name__) -- cgit v1.2.3 From bbdcbe945248c5448540a1ab746b8a4b8bc8f95b Mon Sep 17 00:00:00 2001 From: Dan Stephenson Date: Sat, 10 Aug 2013 01:22:47 +0100 Subject: Spelling correction on read_only_fields err msg just spotted extras--- rest_framework/serializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 682a99a4..9e85cb46 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -685,10 +685,10 @@ class ModelSerializer(Serializer): assert field_name not in self.base_fields.keys(), \ "field '%s' on serializer '%s' specfied in " \ "`read_only_fields`, but also added " \ - "as an explict field. Remove it from `read_only_fields`." % \ + "as an explicit field. Remove it from `read_only_fields`." % \ (field_name, self.__class__.__name__) assert field_name in ret, \ - "Noexistant field '%s' specified in `read_only_fields` " \ + "Non-existant field '%s' specified in `read_only_fields` " \ "on serializer '%s'." % \ (field_name, self.__class__.__name__) ret[field_name].read_only = True -- cgit v1.2.3 From cd5f1bb229f82cb1bf00c322fd5b688cf0638e97 Mon Sep 17 00:00:00 2001 From: Yuri Prezument Date: Mon, 12 Aug 2013 15:41:48 +0300 Subject: Fix PATCH button title in template --- rest_framework/templates/rest_framework/base.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 9d939e73..51f9c291 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -196,7 +196,7 @@ {% endif %} {% if raw_data_patch_form %} - + {% endif %} -- cgit v1.2.3 From 1d8a80f5cc41f157403771439f15c9b7d615a43b Mon Sep 17 00:00:00 2001 From: Jeremy Satterfield Date: Tue, 13 Aug 2013 15:31:58 -0500 Subject: don't set X-Throttle-Wait-Second header if throttle wait is None --- rest_framework/tests/test_throttling.py | 33 ++++++++++++++++++++++++++++++++- rest_framework/views.py | 2 +- 2 files changed, 33 insertions(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index 19bc691a..41bff692 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -7,7 +7,7 @@ from django.contrib.auth.models import User from django.core.cache import cache from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle +from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -21,6 +21,14 @@ class User3MinRateThrottle(UserRateThrottle): scope = 'minutes' +class NonTimeThrottle(BaseThrottle): + def allow_request(self, request, view): + if not hasattr(self.__class__, 'called'): + self.__class__.called = True + return True + return False + + class MockView(APIView): throttle_classes = (User3SecRateThrottle,) @@ -35,6 +43,13 @@ class MockView_MinuteThrottling(APIView): return Response('foo') +class MockView_NonTimeThrottling(APIView): + throttle_classes = (NonTimeThrottle,) + + def get(self, request): + return Response('foo') + + class ThrottlingTests(TestCase): def setUp(self): """ @@ -140,6 +155,22 @@ class ThrottlingTests(TestCase): (80, None) )) + def test_non_time_throttle(self): + """ + Ensure for second based throttles. + """ + request = self.factory.get('/') + + self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called')) + + response = MockView_NonTimeThrottling.as_view()(request) + self.assertFalse('X-Throttle-Wait-Seconds' in response) + + self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called) + + response = MockView_NonTimeThrottling.as_view()(request) + self.assertFalse('X-Throttle-Wait-Seconds' in response) + class ScopedRateThrottleTests(TestCase): """ diff --git a/rest_framework/views.py b/rest_framework/views.py index 37bba7f0..d51233a9 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -269,7 +269,7 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, exceptions.Throttled): + if isinstance(exc, exceptions.Throttled) and exc.wait is not None: # Throttle wait header self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait -- cgit v1.2.3 From 2f03870ae12479cbfdce68db1a30bad6fa15b311 Mon Sep 17 00:00:00 2001 From: JT Date: Tue, 13 Aug 2013 18:48:49 -0500 Subject: Fix for "No module named compat" --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f9931887..add9d224 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -924,7 +924,7 @@ class ImageField(FileField): if f is None: return None - from compat import Image + from rest_framework.compat import Image assert Image is not None, 'PIL must be installed for ImageField support' # We need to get a file object for PIL. We might have a path or we might -- cgit v1.2.3 From 486f4c8047b18cc753e1969079b58ca3da6ab43f Mon Sep 17 00:00:00 2001 From: Chad Barrington Date: Wed, 14 Aug 2013 15:18:44 -0500 Subject: Update filters.py Here's a little cleanup for ya, brah!--- rest_framework/filters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index c058bc71..fbfdd099 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -134,7 +134,7 @@ class OrderingFilter(BaseFilterBackend): ordering = self.remove_invalid_fields(queryset, ordering) if not ordering: - # Use 'ordering' attribtue by default + # Use 'ordering' attribute by default ordering = self.get_default_ordering(view) if ordering: -- cgit v1.2.3 From 99083baf25d3145d034336b2569b79488cf1662b Mon Sep 17 00:00:00 2001 From: Chad Barrington Date: Wed, 14 Aug 2013 16:01:25 -0500 Subject: Update filters.py Fixed cut n pasted get_ordering docstring for ya bro.--- rest_framework/filters.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/filters.py b/rest_framework/filters.py index fbfdd099..4079e1bd 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -109,8 +109,7 @@ class OrderingFilter(BaseFilterBackend): def get_ordering(self, request): """ - Search terms are set by a ?search=... query parameter, - and may be comma and/or whitespace delimited. + Ordering is set by a comma delimited ?ordering=... query parameter. """ params = request.QUERY_PARAMS.get(self.ordering_param) if params: -- cgit v1.2.3 From d07dae6e79d53fdcfc7d88e085fe7e6da97c9c74 Mon Sep 17 00:00:00 2001 From: Christopher Paolini Date: Thu, 15 Aug 2013 12:41:52 -0400 Subject: Ability to override name/description of view Added settings and additions to formatting.py --- rest_framework/settings.py | 4 ++++ rest_framework/utils/formatting.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8fd177d5..b65e42cf 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -69,6 +69,10 @@ DEFAULTS = { 'PAGINATE_BY': None, 'PAGINATE_BY_PARAM': None, + # View configuration + 'VIEW_NAME_FUNCTION': None, + 'VIEW_DESCRIPTION_FUNCTION': None, + # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 4bec8387..c908ce66 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -7,6 +7,7 @@ from django.utils.html import escape from django.utils.safestring import mark_safe from rest_framework.compat import apply_markdown, smart_text import re +from rest_framework.settings import api_settings def _remove_trailing_string(content, trailing): @@ -48,7 +49,14 @@ def _camelcase_to_spaces(content): def get_view_name(cls, suffix=None): """ Return a formatted name for an `APIView` class or `@api_view` function. + If a VIEW_NAME_FUNCTION is set in settings the value of that will be used + if that value is "falsy" then the normal formatting will be used. """ + if api_settings.VIEW_NAME_FUNCTION: + name = api_settings.VIEW_NAME_FUNCTION(cls, suffix) + if name: + return name + name = cls.__name__ name = _remove_trailing_string(name, 'View') name = _remove_trailing_string(name, 'ViewSet') @@ -61,7 +69,14 @@ def get_view_name(cls, suffix=None): def get_view_description(cls, html=False): """ Return a description for an `APIView` class or `@api_view` function. + If a VIEW_DESCRIPTION_FUNCTION is set in settings the value of that will be used + if that value is "falsy" then the normal formatting will be used. """ + if api_settings.VIEW_DESCRIPTION_FUNCTION: + description = api_settings.VIEW_DESCRIPTION_FUNCTION(cls) + if description: + return markup_description(description) + description = cls.__doc__ or '' description = _remove_leading_indent(smart_text(description)) if html: -- cgit v1.2.3 From f34b9ff0498ca054bfe65b4da99b01d48ff052b8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 15 Aug 2013 21:36:45 +0100 Subject: AnonRateThrottle should always allow authenticated users. Closes #994 --- rest_framework/throttling.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'rest_framework') diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index f6bb1cc8..65b45593 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -96,6 +96,9 @@ class SimpleRateThrottle(BaseThrottle): return True self.key = self.get_cache_key(request, view) + if self.key is None: + return True + self.history = cache.get(self.key, []) self.now = self.timer() -- cgit v1.2.3 From f6f69dc71d4db7492b5feecc69627dff0031e2b9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 16 Aug 2013 14:03:20 +0100 Subject: Version 2.3.7 --- rest_framework/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'rest_framework') diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 776618ac..087808e0 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -1,4 +1,4 @@ -__version__ = '2.3.6' +__version__ = '2.3.7' VERSION = __version__ # synonym -- cgit v1.2.3 From a95984e4d4b034c196d40f74fbdc6345e6d4124c Mon Sep 17 00:00:00 2001 From: Christopher Paolini Date: Fri, 16 Aug 2013 13:23:04 -0400 Subject: Settings now have default functions Updated the setting to have a default function. --- rest_framework/settings.py | 6 +++-- rest_framework/utils/formatting.py | 46 +++++++++++++++++--------------------- 2 files changed, 24 insertions(+), 28 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/settings.py b/rest_framework/settings.py index b65e42cf..0b2bdb62 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -70,8 +70,8 @@ DEFAULTS = { 'PAGINATE_BY_PARAM': None, # View configuration - 'VIEW_NAME_FUNCTION': None, - 'VIEW_DESCRIPTION_FUNCTION': None, + 'VIEW_NAME_FUNCTION': 'rest_framework.utils.formatting.view_name', + 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.utils.formatting.view_description', # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', @@ -129,6 +129,8 @@ IMPORT_STRINGS = ( 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', + 'VIEW_NAME_FUNCTION', + 'VIEW_DESCRIPTION_FUNCTION' ) diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index c908ce66..4f59f659 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -49,39 +49,15 @@ def _camelcase_to_spaces(content): def get_view_name(cls, suffix=None): """ Return a formatted name for an `APIView` class or `@api_view` function. - If a VIEW_NAME_FUNCTION is set in settings the value of that will be used - if that value is "falsy" then the normal formatting will be used. """ - if api_settings.VIEW_NAME_FUNCTION: - name = api_settings.VIEW_NAME_FUNCTION(cls, suffix) - if name: - return name - - name = cls.__name__ - name = _remove_trailing_string(name, 'View') - name = _remove_trailing_string(name, 'ViewSet') - name = _camelcase_to_spaces(name) - if suffix: - name += ' ' + suffix - return name + return api_settings.VIEW_NAME_FUNCTION(cls, suffix) def get_view_description(cls, html=False): """ Return a description for an `APIView` class or `@api_view` function. - If a VIEW_DESCRIPTION_FUNCTION is set in settings the value of that will be used - if that value is "falsy" then the normal formatting will be used. """ - if api_settings.VIEW_DESCRIPTION_FUNCTION: - description = api_settings.VIEW_DESCRIPTION_FUNCTION(cls) - if description: - return markup_description(description) - - description = cls.__doc__ or '' - description = _remove_leading_indent(smart_text(description)) - if html: - return markup_description(description) - return description + return api_settings.VIEW_DESCRIPTION_FUNCTION(cls) def markup_description(description): @@ -93,3 +69,21 @@ def markup_description(description): else: description = escape(description).replace('\n', '
') return mark_safe(description) + + +def view_name(cls, suffix=None): + name = cls.__name__ + name = _remove_trailing_string(name, 'View') + name = _remove_trailing_string(name, 'ViewSet') + name = _camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + + return name + +def view_description(cls, html=False): + description = cls.__doc__ or '' + description = _remove_leading_indent(smart_text(description)) + if html: + return markup_description(description) + return description \ No newline at end of file -- cgit v1.2.3 From e6662d434f0214d21d38e4388a40fd63e1f9dcc6 Mon Sep 17 00:00:00 2001 From: Christopher Paolini Date: Sat, 17 Aug 2013 17:44:51 -0400 Subject: Improved view/description function setting Now supports each View having its own name and description function and overriding the global default. --- rest_framework/renderers.py | 5 ++--- rest_framework/utils/breadcrumbs.py | 5 ++--- rest_framework/utils/formatting.py | 15 --------------- rest_framework/views.py | 27 ++++++++++++++++++++------- 4 files changed, 24 insertions(+), 28 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 3a03ca33..1006e26c 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -24,7 +24,6 @@ from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders from rest_framework.utils.breadcrumbs import get_breadcrumbs -from rest_framework.utils.formatting import get_view_name, get_view_description from rest_framework import exceptions, parsers, status, VERSION @@ -498,10 +497,10 @@ class BrowsableAPIRenderer(BaseRenderer): return GenericContentForm() def get_name(self, view): - return get_view_name(view.__class__, getattr(view, 'suffix', None)) + return view.get_view_name() def get_description(self, view): - return get_view_description(view.__class__, html=True) + return view.get_view_description(html=True) def get_breadcrumbs(self, request): return get_breadcrumbs(request.path) diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index d51374b0..0384faba 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals from django.core.urlresolvers import resolve, get_script_prefix -from rest_framework.utils.formatting import get_view_name def get_breadcrumbs(url): @@ -29,8 +28,8 @@ def get_breadcrumbs(url): # Don't list the same view twice in a row. # Probably an optional trailing slash. if not seen or seen[-1] != view: - suffix = getattr(view, 'suffix', None) - name = get_view_name(view.cls, suffix) + instance = view.cls() + name = instance.get_view_name() breadcrumbs_list.insert(0, (name, prefix + url)) seen.append(view) diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 4f59f659..5780301a 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -45,21 +45,6 @@ def _camelcase_to_spaces(content): content = re.sub(camelcase_boundry, ' \\1', content).strip() return ' '.join(content.split('_')).title() - -def get_view_name(cls, suffix=None): - """ - Return a formatted name for an `APIView` class or `@api_view` function. - """ - return api_settings.VIEW_NAME_FUNCTION(cls, suffix) - - -def get_view_description(cls, html=False): - """ - Return a description for an `APIView` class or `@api_view` function. - """ - return api_settings.VIEW_DESCRIPTION_FUNCTION(cls) - - def markup_description(description): """ Apply HTML markup to the given description. diff --git a/rest_framework/views.py b/rest_framework/views.py index d51233a9..4553714a 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -12,7 +12,6 @@ from rest_framework.compat import View, HttpResponseBase from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings -from rest_framework.utils.formatting import get_view_name, get_view_description class APIView(View): @@ -25,6 +24,9 @@ class APIView(View): permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS + view_name_function = api_settings.VIEW_NAME_FUNCTION + view_description_function = api_settings.VIEW_DESCRIPTION_FUNCTION + @classmethod def as_view(cls, **initkwargs): """ @@ -157,6 +159,21 @@ class APIView(View): self._negotiator = self.content_negotiation_class() return self._negotiator + def get_view_name(self): + """ + Get the view name + """ + # This is used by ViewSets to disambiguate instance vs list views + view_name_suffix = getattr(self, 'suffix', None) + + return self.view_name_function(self.__class__, view_name_suffix) + + def get_view_description(self, html=False): + """ + Get the view description + """ + return self.view_description_function(self.__class__, html) + # API policy implementation methods def perform_content_negotiation(self, request, force=False): @@ -342,16 +359,12 @@ class APIView(View): Return a dictionary of metadata about the view. Used to return responses for OPTIONS requests. """ - - # This is used by ViewSets to disambiguate instance vs list views - view_name_suffix = getattr(self, 'suffix', None) - # By default we can't provide any form-like information, however the # generic views override this implementation and add additional # information for POST and PUT methods, based on the serializer. ret = SortedDict() - ret['name'] = get_view_name(self.__class__, view_name_suffix) - ret['description'] = get_view_description(self.__class__) + ret['name'] = self.get_view_name() + ret['description'] = self.get_view_description() ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes] return ret -- cgit v1.2.3 From 11d7c1838a1146728528d762cdf6bf329321c1d1 Mon Sep 17 00:00:00 2001 From: Christopher Paolini Date: Sat, 17 Aug 2013 17:52:08 -0400 Subject: Updated default view name/description functions Forgot to update the default view name/description functions to the new setup. --- rest_framework/utils/formatting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 5780301a..89a89252 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -56,8 +56,8 @@ def markup_description(description): return mark_safe(description) -def view_name(cls, suffix=None): - name = cls.__name__ +def view_name(instance, view, suffix=None): + name = view.__name__ name = _remove_trailing_string(name, 'View') name = _remove_trailing_string(name, 'ViewSet') name = _camelcase_to_spaces(name) @@ -66,8 +66,8 @@ def view_name(cls, suffix=None): return name -def view_description(cls, html=False): - description = cls.__doc__ or '' +def view_description(instance, view, html=False): + description = view.__doc__ or '' description = _remove_leading_indent(smart_text(description)) if html: return markup_description(description) -- cgit v1.2.3 From 5a374955b14e1f1fdc85f0fa68284dbf6b83ab5f Mon Sep 17 00:00:00 2001 From: Christopher Paolini Date: Sun, 18 Aug 2013 00:29:05 -0400 Subject: Updated tests for view name and description Updated the tests to use the default view_name and view_description functions in the formatter through the default in settings. --- rest_framework/tests/test_description.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/tests/test_description.py b/rest_framework/tests/test_description.py index 8019f5ec..4c03c1de 100644 --- a/rest_framework/tests/test_description.py +++ b/rest_framework/tests/test_description.py @@ -6,7 +6,6 @@ from rest_framework.compat import apply_markdown, smart_text from rest_framework.views import APIView from rest_framework.tests.description import ViewWithNonASCIICharactersInDocstring from rest_framework.tests.description import UTF8_TEST_DOCSTRING -from rest_framework.utils.formatting import get_view_name, get_view_description # We check that docstrings get nicely un-indented. DESCRIPTION = """an example docstring @@ -58,7 +57,7 @@ class TestViewNamesAndDescriptions(TestCase): """ class MockView(APIView): pass - self.assertEqual(get_view_name(MockView), 'Mock') + self.assertEqual(MockView().get_view_name(), 'Mock') def test_view_description_uses_docstring(self): """Ensure view descriptions are based on the docstring.""" @@ -78,7 +77,7 @@ class TestViewNamesAndDescriptions(TestCase): # hash style header #""" - self.assertEqual(get_view_description(MockView), DESCRIPTION) + self.assertEqual(MockView().get_view_description(), DESCRIPTION) def test_view_description_supports_unicode(self): """ @@ -86,7 +85,7 @@ class TestViewNamesAndDescriptions(TestCase): """ self.assertEqual( - get_view_description(ViewWithNonASCIICharactersInDocstring), + ViewWithNonASCIICharactersInDocstring().get_view_description(), smart_text(UTF8_TEST_DOCSTRING) ) @@ -97,7 +96,7 @@ class TestViewNamesAndDescriptions(TestCase): """ class MockView(APIView): pass - self.assertEqual(get_view_description(MockView), '') + self.assertEqual(MockView().get_view_description(), '') def test_markdown(self): """ -- cgit v1.2.3 From 89b0a539c389477cfd7df7df461868b85f618d95 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 19 Aug 2013 08:24:27 +0100 Subject: Move view name/description functions into public space --- rest_framework/settings.py | 4 ++-- rest_framework/utils/formatting.py | 36 +++++++++++------------------------- rest_framework/views.py | 21 ++++++++++++++++++++- 3 files changed, 33 insertions(+), 28 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 0b2bdb62..7d25e513 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -70,8 +70,8 @@ DEFAULTS = { 'PAGINATE_BY_PARAM': None, # View configuration - 'VIEW_NAME_FUNCTION': 'rest_framework.utils.formatting.view_name', - 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.utils.formatting.view_description', + 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', + 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', # Authentication 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index 89a89252..4b59ba84 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -5,12 +5,13 @@ from __future__ import unicode_literals from django.utils.html import escape from django.utils.safestring import mark_safe -from rest_framework.compat import apply_markdown, smart_text -import re +from rest_framework.compat import apply_markdown from rest_framework.settings import api_settings +from textwrap import dedent +import re -def _remove_trailing_string(content, trailing): +def remove_trailing_string(content, trailing): """ Strip trailing component `trailing` from `content` if it exists. Used when generating names from view classes. @@ -20,10 +21,14 @@ def _remove_trailing_string(content, trailing): return content -def _remove_leading_indent(content): +def dedent(content): """ Remove leading indent from a block of text. Used when generating descriptions from docstrings. + + Note that python's `textwrap.dedent` doesn't quite cut it, + as it fails to dedent multiline docstrings that include + unindented text on the initial line. """ whitespace_counts = [len(line) - len(line.lstrip(' ')) for line in content.splitlines()[1:] if line.lstrip()] @@ -32,11 +37,10 @@ def _remove_leading_indent(content): if whitespace_counts: whitespace_pattern = '^' + (' ' * min(whitespace_counts)) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) - content = content.strip('\n') - return content + return content.strip() -def _camelcase_to_spaces(content): +def camelcase_to_spaces(content): """ Translate 'CamelCaseNames' to 'Camel Case Names'. Used when generating names from view classes. @@ -54,21 +58,3 @@ def markup_description(description): else: description = escape(description).replace('\n', '
') return mark_safe(description) - - -def view_name(instance, view, suffix=None): - name = view.__name__ - name = _remove_trailing_string(name, 'View') - name = _remove_trailing_string(name, 'ViewSet') - name = _camelcase_to_spaces(name) - if suffix: - name += ' ' + suffix - - return name - -def view_description(instance, view, html=False): - description = view.__doc__ or '' - description = _remove_leading_indent(smart_text(description)) - if html: - return markup_description(description) - return description \ No newline at end of file diff --git a/rest_framework/views.py b/rest_framework/views.py index 4553714a..431e21f9 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -8,10 +8,29 @@ from django.http import Http404 from django.utils.datastructures import SortedDict from django.views.decorators.csrf import csrf_exempt from rest_framework import status, exceptions -from rest_framework.compat import View, HttpResponseBase +from rest_framework.compat import smart_text, HttpResponseBase, View from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.utils import formatting + + +def get_view_name(instance, view, suffix=None): + name = view.__name__ + name = formatting.remove_trailing_string(name, 'View') + name = formatting.remove_trailing_string(name, 'ViewSet') + name = formatting.camelcase_to_spaces(name) + if suffix: + name += ' ' + suffix + + return name + +def get_view_description(instance, view, html=False): + description = view.__doc__ or '' + description = formatting.dedent(smart_text(description)) + if html: + return formatting.markup_description(description) + return description class APIView(View): -- cgit v1.2.3 From 512067062419b736b65ca27bdb5663d863c775dd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 19 Aug 2013 08:45:53 +0100 Subject: Document customizable view names/descriptions --- rest_framework/views.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) (limited to 'rest_framework') diff --git a/rest_framework/views.py b/rest_framework/views.py index 431e21f9..727a9f95 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -15,8 +15,8 @@ from rest_framework.settings import api_settings from rest_framework.utils import formatting -def get_view_name(instance, view, suffix=None): - name = view.__name__ +def get_view_name(cls, suffix=None): + name = cls.__name__ name = formatting.remove_trailing_string(name, 'View') name = formatting.remove_trailing_string(name, 'ViewSet') name = formatting.camelcase_to_spaces(name) @@ -25,8 +25,8 @@ def get_view_name(instance, view, suffix=None): return name -def get_view_description(instance, view, html=False): - description = view.__doc__ or '' +def get_view_description(cls, html=False): + description = cls.__doc__ or '' description = formatting.dedent(smart_text(description)) if html: return formatting.markup_description(description) @@ -43,9 +43,6 @@ class APIView(View): permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES content_negotiation_class = api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS - view_name_function = api_settings.VIEW_NAME_FUNCTION - view_description_function = api_settings.VIEW_DESCRIPTION_FUNCTION - @classmethod def as_view(cls, **initkwargs): """ @@ -131,6 +128,22 @@ class APIView(View): 'request': getattr(self, 'request', None) } + def get_view_name(self): + """ + Return the view name, as used in OPTIONS responses and in the + browsable API. + """ + func = api_settings.VIEW_NAME_FUNCTION + return func(self.__class__, getattr(self, 'suffix', None)) + + def get_view_description(self, html=False): + """ + Return some descriptive text for the view, as used in OPTIONS responses + and in the browsable API. + """ + func = api_settings.VIEW_DESCRIPTION_FUNCTION + return func(self.__class__, html) + # API policy instantiation methods def get_format_suffix(self, **kwargs): @@ -178,21 +191,6 @@ class APIView(View): self._negotiator = self.content_negotiation_class() return self._negotiator - def get_view_name(self): - """ - Get the view name - """ - # This is used by ViewSets to disambiguate instance vs list views - view_name_suffix = getattr(self, 'suffix', None) - - return self.view_name_function(self.__class__, view_name_suffix) - - def get_view_description(self, html=False): - """ - Get the view description - """ - return self.view_description_function(self.__class__, html) - # API policy implementation methods def perform_content_negotiation(self, request, force=False): -- cgit v1.2.3