aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/__init__.py0
-rw-r--r--rest_framework/tests/authentication.py154
-rw-r--r--rest_framework/tests/breadcrumbs.py72
-rw-r--r--rest_framework/tests/decorators.py128
-rw-r--r--rest_framework/tests/description.py113
-rw-r--r--rest_framework/tests/files.py34
-rw-r--r--rest_framework/tests/genericrelations.py33
-rw-r--r--rest_framework/tests/generics.py303
-rw-r--r--rest_framework/tests/htmlrenderer.py50
-rw-r--r--rest_framework/tests/hyperlinkedserializers.py168
-rw-r--r--rest_framework/tests/models.py128
-rw-r--r--rest_framework/tests/modelviews.py90
-rw-r--r--rest_framework/tests/negotiation.py37
-rw-r--r--rest_framework/tests/pagination.py87
-rw-r--r--rest_framework/tests/parsers.py212
-rw-r--r--rest_framework/tests/renderers.py418
-rw-r--r--rest_framework/tests/request.py278
-rw-r--r--rest_framework/tests/response.py182
-rw-r--r--rest_framework/tests/reverse.py26
-rw-r--r--rest_framework/tests/serializer.py500
-rw-r--r--rest_framework/tests/status.py12
-rw-r--r--rest_framework/tests/testcases.py63
-rw-r--r--rest_framework/tests/tests.py13
-rw-r--r--rest_framework/tests/throttling.py144
-rw-r--r--rest_framework/tests/validators.py329
-rw-r--r--rest_framework/tests/views.py97
26 files changed, 3671 insertions, 0 deletions
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/rest_framework/tests/__init__.py
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
new file mode 100644
index 00000000..8ab4c4e4
--- /dev/null
+++ b/rest_framework/tests/authentication.py
@@ -0,0 +1,154 @@
+from django.conf.urls.defaults import patterns
+from django.contrib.auth.models import User
+from django.test import Client, TestCase
+
+from django.utils import simplejson as json
+from django.http import HttpResponse
+
+from rest_framework.views import APIView
+from rest_framework import permissions
+
+from rest_framework.authtoken.models import Token
+from rest_framework.authentication import TokenAuthentication
+
+import base64
+
+
+class MockView(APIView):
+ permission_classes = (permissions.IsAuthenticated,)
+
+ def post(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def put(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+MockView.authentication_classes += (TokenAuthentication,)
+
+urlpatterns = patterns('',
+ (r'^$', MockView.as_view()),
+)
+
+
+class BasicAuthTests(TestCase):
+ """Basic authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_post_form_passing_basic_auth(self):
+ """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
+ auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
+ response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ def test_post_json_passing_basic_auth(self):
+ """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
+ auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
+ response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ def test_post_form_failing_basic_auth(self):
+ """Ensure POSTing form over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/', {'example': 'example'})
+ self.assertEqual(response.status_code, 403)
+
+ def test_post_json_failing_basic_auth(self):
+ """Ensure POSTing json over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 403)
+
+
+class SessionAuthTests(TestCase):
+ """User session authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.non_csrf_client = Client(enforce_csrf_checks=False)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def tearDown(self):
+ self.csrf_client.logout()
+
+ def test_post_form_session_auth_failing_csrf(self):
+ """
+ Ensure POSTing form over session authentication without CSRF token fails.
+ """
+ self.csrf_client.login(username=self.username, password=self.password)
+ response = self.csrf_client.post('/', {'example': 'example'})
+ self.assertEqual(response.status_code, 403)
+
+ def test_post_form_session_auth_passing(self):
+ """
+ Ensure POSTing form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.post('/', {'example': 'example'})
+ self.assertEqual(response.status_code, 200)
+
+ def test_put_form_session_auth_passing(self):
+ """
+ Ensure PUTting form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.put('/', {'example': 'example'})
+ self.assertEqual(response.status_code, 200)
+
+ def test_post_form_session_auth_failing(self):
+ """
+ Ensure POSTing form over session authentication without logged in user fails.
+ """
+ response = self.csrf_client.post('/', {'example': 'example'})
+ self.assertEqual(response.status_code, 403)
+
+
+class TokenAuthTests(TestCase):
+ """Token authentication"""
+ urls = 'rest_framework.tests.authentication'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.key = 'abcd1234'
+ self.token = Token.objects.create(key=self.key, user=self.user)
+
+ def test_post_form_passing_token_auth(self):
+ """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
+ auth = "Token " + self.key
+ response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ def test_post_json_passing_token_auth(self):
+ """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
+ auth = "Token " + self.key
+ response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, 200)
+
+ def test_post_form_failing_token_auth(self):
+ """Ensure POSTing form over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/', {'example': 'example'})
+ self.assertEqual(response.status_code, 403)
+
+ def test_post_json_failing_token_auth(self):
+ """Ensure POSTing json over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 403)
+
+ def test_token_has_auto_assigned_key_if_none_provided(self):
+ """Ensure creating a token with no key will auto-assign a key"""
+ self.token.delete()
+ token = Token.objects.create(user=self.user)
+ self.assertTrue(bool(token.key))
diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py
new file mode 100644
index 00000000..647ab96d
--- /dev/null
+++ b/rest_framework/tests/breadcrumbs.py
@@ -0,0 +1,72 @@
+from django.conf.urls.defaults import patterns, url
+from django.test import TestCase
+from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.views import APIView
+
+
+class Root(APIView):
+ pass
+
+
+class ResourceRoot(APIView):
+ pass
+
+
+class ResourceInstance(APIView):
+ pass
+
+
+class NestedResourceRoot(APIView):
+ pass
+
+
+class NestedResourceInstance(APIView):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^$', Root.as_view()),
+ url(r'^resource/$', ResourceRoot.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()),
+)
+
+
+class BreadcrumbTests(TestCase):
+ """Tests the breadcrumb functionality used by the HTML renderer."""
+
+ urls = 'rest_framework.tests.breadcrumbs'
+
+ def test_root_breadcrumbs(self):
+ url = '/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
+
+ def test_resource_root_breadcrumbs(self):
+ url = '/resource/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/')])
+
+ def test_resource_instance_breadcrumbs(self):
+ url = '/resource/123'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123')])
+
+ def test_nested_resource_breadcrumbs(self):
+ url = '/resource/123/'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/')])
+
+ def test_nested_resource_instance_breadcrumbs(self):
+ url = '/resource/123/abc'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/'),
+ ('Nested Resource Instance', '/resource/123/abc')])
+
+ def test_broken_url_breadcrumbs_handled_gracefully(self):
+ url = '/foobar'
+ self.assertEqual(get_breadcrumbs(url), [('Root', '/')])
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
new file mode 100644
index 00000000..41864d71
--- /dev/null
+++ b/rest_framework/tests/decorators.py
@@ -0,0 +1,128 @@
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.response import Response
+from django.test.client import RequestFactory
+from rest_framework.renderers import JSONRenderer
+from rest_framework.parsers import JSONParser
+from rest_framework.authentication import BasicAuthentication
+from rest_framework.throttling import UserRateThrottle
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.views import APIView
+from rest_framework.decorators import (
+ api_view,
+ renderer_classes,
+ parser_classes,
+ authentication_classes,
+ throttle_classes,
+ permission_classes,
+)
+
+
+class DecoratorTestCase(TestCase):
+
+ def setUp(self):
+ self.factory = RequestFactory()
+
+ def _finalize_response(self, request, response, *args, **kwargs):
+ response.request = request
+ return APIView.finalize_response(self, request, response, *args, **kwargs)
+
+ def test_wrap_view(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ self.assertTrue(isinstance(view.cls_instance, APIView))
+
+ def test_calling_method(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
+ def test_calling_put_method(self):
+
+ @api_view(['GET', 'PUT'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.put('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, 405)
+
+ def test_renderer_classes(self):
+
+ @api_view(['GET'])
+ @renderer_classes([JSONRenderer])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer))
+
+ def test_parser_classes(self):
+
+ @api_view(['GET'])
+ @parser_classes([JSONParser])
+ def view(request):
+ self.assertEqual(len(request.parsers), 1)
+ self.assertTrue(isinstance(request.parsers[0],
+ JSONParser))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_authentication_classes(self):
+
+ @api_view(['GET'])
+ @authentication_classes([BasicAuthentication])
+ def view(request):
+ self.assertEqual(len(request.authenticators), 1)
+ self.assertTrue(isinstance(request.authenticators[0],
+ BasicAuthentication))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_permission_classes(self):
+
+ @api_view(['GET'])
+ @permission_classes([IsAuthenticated])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_throttle_classes(self):
+ class OncePerDayUserThrottle(UserRateThrottle):
+ rate = '1/day'
+
+ @api_view(['GET'])
+ @throttle_classes([OncePerDayUserThrottle])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+
+ response = view(request)
+ self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py
new file mode 100644
index 00000000..d958b840
--- /dev/null
+++ b/rest_framework/tests/description.py
@@ -0,0 +1,113 @@
+from django.test import TestCase
+from rest_framework.views import APIView
+from rest_framework.compat import apply_markdown
+
+# We check that docstrings get nicely un-indented.
+DESCRIPTION = """an example docstring
+====================
+
+* list
+* list
+
+another header
+--------------
+
+ code block
+
+indented
+
+# hash style header #"""
+
+# If markdown is installed we also test it's working
+# (and that our wrapped forces '=' to h2 and '-' to h3)
+
+# We support markdown < 2.1 and markdown >= 2.1
+MARKED_DOWN_lt_21 = """<h2>an example docstring</h2>
+<ul>
+<li>list</li>
+<li>list</li>
+</ul>
+<h3>another header</h3>
+<pre><code>code block
+</code></pre>
+<p>indented</p>
+<h2 id="hash_style_header">hash style header</h2>"""
+
+MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
+<ul>
+<li>list</li>
+<li>list</li>
+</ul>
+<h3 id="another-header">another header</h3>
+<pre><code>code block
+</code></pre>
+<p>indented</p>
+<h2 id="hash-style-header">hash style header</h2>"""
+
+
+class TestViewNamesAndDescriptions(TestCase):
+ def test_resource_name_uses_classname_by_default(self):
+ """Ensure Resource names are based on the classname by default."""
+ class MockView(APIView):
+ pass
+ self.assertEquals(MockView().get_name(), 'Mock')
+
+ def test_resource_name_can_be_set_explicitly(self):
+ """Ensure Resource names can be set using the 'get_name' method."""
+ example = 'Some Other Name'
+ class MockView(APIView):
+ def get_name(self):
+ return example
+ self.assertEquals(MockView().get_name(), example)
+
+ def test_resource_description_uses_docstring_by_default(self):
+ """Ensure Resource names are based on the docstring by default."""
+ class MockView(APIView):
+ """an example docstring
+ ====================
+
+ * list
+ * list
+
+ another header
+ --------------
+
+ code block
+
+ indented
+
+ # hash style header #"""
+
+ self.assertEquals(MockView().get_description(), DESCRIPTION)
+
+ def test_resource_description_can_be_set_explicitly(self):
+ """Ensure Resource descriptions can be set using the 'get_description' method."""
+ example = 'Some other description'
+
+ class MockView(APIView):
+ """docstring"""
+ def get_description(self):
+ return example
+ self.assertEquals(MockView().get_description(), example)
+
+ def test_resource_description_does_not_require_docstring(self):
+ """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method."""
+ example = 'Some other description'
+
+ class MockView(APIView):
+ def get_description(self):
+ return example
+ self.assertEquals(MockView().get_description(), example)
+
+ def test_resource_description_can_be_empty(self):
+ """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string."""
+ class MockView(APIView):
+ pass
+ self.assertEquals(MockView().get_description(), '')
+
+ def test_markdown(self):
+ """Ensure markdown to HTML works as expected"""
+ if apply_markdown:
+ gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
+ lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
+ self.assertTrue(gte_21_match or lt_21_match)
diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py
new file mode 100644
index 00000000..61d7f7b1
--- /dev/null
+++ b/rest_framework/tests/files.py
@@ -0,0 +1,34 @@
+# from django.test import TestCase
+# from django import forms
+
+# from django.test.client import RequestFactory
+# from rest_framework.views import View
+# from rest_framework.response import Response
+
+# import StringIO
+
+
+# class UploadFilesTests(TestCase):
+# """Check uploading of files"""
+# def setUp(self):
+# self.factory = RequestFactory()
+
+# def test_upload_file(self):
+
+# class FileForm(forms.Form):
+# file = forms.FileField()
+
+# class MockView(View):
+# permissions = ()
+# form = FileForm
+
+# def post(self, request, *args, **kwargs):
+# return Response({'FILE_NAME': self.CONTENT['file'].name,
+# 'FILE_CONTENT': self.CONTENT['file'].read()})
+
+# file = StringIO.StringIO('stuff')
+# file.name = 'stuff.txt'
+# request = self.factory.post('/', {'file': file})
+# view = MockView.as_view()
+# response = view(request)
+# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"})
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
new file mode 100644
index 00000000..1d7e33bc
--- /dev/null
+++ b/rest_framework/tests/genericrelations.py
@@ -0,0 +1,33 @@
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import *
+
+
+class TestGenericRelations(TestCase):
+ def setUp(self):
+ bookmark = Bookmark(url='https://www.djangoproject.com/')
+ bookmark.save()
+ django = Tag(tag_name='django')
+ django.save()
+ python = Tag(tag_name='python')
+ python.save()
+ t1 = TaggedItem(content_object=bookmark, tag=django)
+ t1.save()
+ t2 = TaggedItem(content_object=bookmark, tag=python)
+ t2.save()
+ self.bookmark = bookmark
+
+ def test_reverse_generic_relation(self):
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = serializers.ManyRelatedField(source='tags')
+
+ class Meta:
+ model = Bookmark
+ exclude = ('id',)
+
+ serializer = BookmarkSerializer(instance=self.bookmark)
+ expected = {
+ 'tags': [u'django', u'python'],
+ 'url': u'https://www.djangoproject.com/'
+ }
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py
new file mode 100644
index 00000000..d45ea976
--- /dev/null
+++ b/rest_framework/tests/generics.py
@@ -0,0 +1,303 @@
+from django.test import TestCase
+from django.test.client import RequestFactory
+from django.utils import simplejson as json
+from rest_framework import generics, serializers, status
+from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel
+
+
+factory = RequestFactory()
+
+
+class RootView(generics.ListCreateAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+
+
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.Field() # read only
+
+ class Meta:
+ model = SlugBasedModel
+ exclude = ('id',)
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ model = SlugBasedModel
+ serializer_class = SlugSerializer
+
+
+class TestRootView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_root_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ def test_post_root_view(self):
+ """
+ POST requests to ListCreateAPIView should create a new object.
+ """
+ content = {'text': 'foobar'}
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEquals(created.text, 'foobar')
+
+ def test_put_root_view(self):
+ """
+ PUT requests to ListCreateAPIView should not be allowed
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."})
+
+ def test_delete_root_view(self):
+ """
+ DELETE requests to ListCreateAPIView should not be allowed
+ """
+ request = factory.delete('/')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."})
+
+ def test_options_root_view(self):
+ """
+ OPTIONS requests to ListCreateAPIView should return metadata
+ """
+ request = factory.options('/')
+ response = self.view(request).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Root',
+ 'description': 'Example description for OPTIONS.'
+ }
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, expected)
+
+ def test_post_cannot_set_id(self):
+ """
+ POST requests to create a new object should not be able to set the id.
+ """
+ content = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEquals(created.text, 'foobar')
+
+
+class TestInstanceView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
+
+ def test_get_instance_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data[0])
+
+ def test_post_instance_view(self):
+ """
+ POST requests to RetrieveUpdateDestroyAPIView should not be allowed
+ """
+ content = {'text': 'foobar'}
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."})
+
+ def test_put_instance_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/1', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
+ def test_delete_instance_view(self):
+ """
+ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
+ """
+ request = factory.delete('/1')
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_204_NO_CONTENT)
+ self.assertEquals(response.content, '')
+ ids = [obj.id for obj in self.objects.all()]
+ self.assertEquals(ids, [2, 3])
+
+ def test_options_instance_view(self):
+ """
+ OPTIONS requests to RetrieveUpdateDestroyAPIView should return metadata
+ """
+ request = factory.options('/')
+ response = self.view(request).render()
+ expected = {
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'name': 'Instance',
+ 'description': 'Example description for OPTIONS.'
+ }
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, expected)
+
+ def test_put_cannot_set_id(self):
+ """
+ PUT requests to create a new object should not be able to set the id.
+ """
+ content = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
+ def test_put_to_deleted_instance(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ if it does not currently exist.
+ """
+ self.objects.get(id=1).delete()
+ content = {'text': 'foobar'}
+ request = factory.put('/1', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEquals(updated.text, 'foobar')
+
+ def test_put_as_create_on_id_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if it doesn't exist.
+ """
+ content = {'text': 'foobar'}
+ # pk fields can not be created on demand, only the database can set th pk for a new object
+ request = factory.put('/5', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request, pk=5).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ new_obj = self.objects.get(pk=5)
+ self.assertEquals(new_obj.text, 'foobar')
+
+ def test_put_as_create_on_slug_based_url(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should create an object
+ at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
+ """
+ content = {'text': 'foobar'}
+ request = factory.put('/test_slug', json.dumps(content),
+ content_type='application/json')
+ response = self.slug_based_view(request, slug='test_slug').render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'})
+ new_obj = SlugBasedModel.objects.get(slug='test_slug')
+ self.assertEquals(new_obj.text, 'foobar')
+
+
+# Regression test for #285
+
+class CommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Comment
+ exclude = ('created',)
+
+
+class CommentView(generics.ListCreateAPIView):
+ serializer_class = CommentSerializer
+ model = Comment
+
+
+class TestCreateModelWithAutoNowAddField(TestCase):
+ def setUp(self):
+ self.objects = Comment.objects
+ self.view = CommentView.as_view()
+
+ def test_create_model_with_auto_now_add_field(self):
+ """
+ Regression test for #285
+
+ https://github.com/tomchristie/django-rest-framework/issues/285
+ """
+ content = {'email': 'foobar@example.com', 'content': 'foobar'}
+ request = factory.post('/', json.dumps(content),
+ content_type='application/json')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_201_CREATED)
+ created = self.objects.get(id=1)
+ self.assertEquals(created.content, 'foobar')
diff --git a/rest_framework/tests/htmlrenderer.py b/rest_framework/tests/htmlrenderer.py
new file mode 100644
index 00000000..10d7e31d
--- /dev/null
+++ b/rest_framework/tests/htmlrenderer.py
@@ -0,0 +1,50 @@
+from django.conf.urls.defaults import patterns, url
+from django.test import TestCase
+from django.template import TemplateDoesNotExist, Template
+import django.template.loader
+from rest_framework.decorators import api_view, renderer_classes
+from rest_framework.renderers import TemplateHTMLRenderer
+from rest_framework.response import Response
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def example(request):
+ """
+ A view that can returns an HTML representation.
+ """
+ data = {'object': 'foobar'}
+ return Response(data, template_name='example.html')
+
+
+urlpatterns = patterns('',
+ url(r'^$', example),
+)
+
+
+class TemplateHTMLRendererTests(TestCase):
+ urls = 'rest_framework.tests.htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == 'example.html':
+ return Template("example: {{ object }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_simple_html_view(self):
+ response = self.client.get('/')
+ self.assertContains(response, "example: foobar")
+ self.assertEquals(response['Content-Type'], 'text/html')
diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py
new file mode 100644
index 00000000..92c3691e
--- /dev/null
+++ b/rest_framework/tests/hyperlinkedserializers.py
@@ -0,0 +1,168 @@
+from django.conf.urls.defaults import patterns, url
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import generics, status, serializers
+from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment
+
+factory = RequestFactory()
+
+
+class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+ blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail', queryset=BlogPost.objects.all())
+
+ def restore_object(self, attrs, instance=None):
+ return BlogPostComment(**attrs)
+
+
+class BasicList(generics.ListCreateAPIView):
+ model = BasicModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
+ model = BasicModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class AnchorDetail(generics.RetrieveAPIView):
+ model = Anchor
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyList(generics.ListAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class ManyToManyDetail(generics.RetrieveAPIView):
+ model = ManyToManyModel
+ model_serializer_class = serializers.HyperlinkedModelSerializer
+
+
+class BlogPostCommentListCreate(generics.ListCreateAPIView):
+ model = BlogPostComment
+ model_serializer_class = BlogPostCommentSerializer
+
+
+class BlogPostDetail(generics.RetrieveAPIView):
+ model = BlogPost
+
+urlpatterns = patterns('',
+ url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
+ url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
+ url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),
+ url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'),
+ url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'),
+ url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'),
+ url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list')
+)
+
+
+class TestBasicHyperlinkedView(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.list_view = BasicList.as_view()
+ self.detail_view = BasicDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/')
+ response = self.list_view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/basic/1')
+ response = self.detail_view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data[0])
+
+
+class TestManyToManyHyperlinkedView(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create 3 BasicModel intances.
+ """
+ items = ['foo', 'bar', 'baz']
+ anchors = []
+ for item in items:
+ anchor = Anchor(text=item)
+ anchor.save()
+ anchors.append(anchor)
+
+ manytomany = ManyToManyModel()
+ manytomany.save()
+ manytomany.rel.add(*anchors)
+
+ self.data = [{
+ 'url': 'http://testserver/manytomany/1/',
+ 'rel': [
+ 'http://testserver/anchor/1/',
+ 'http://testserver/anchor/2/',
+ 'http://testserver/anchor/3/',
+ ]
+ }]
+ self.list_view = ManyToManyList.as_view()
+ self.detail_view = ManyToManyDetail.as_view()
+
+ def test_get_list_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/')
+ response = self.list_view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data)
+
+ def test_get_detail_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/manytomany/1/')
+ response = self.detail_view(request, pk=1).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data, self.data[0])
+
+
+class TestCreateWithForeignKeys(TestCase):
+ urls = 'rest_framework.tests.hyperlinkedserializers'
+
+ def setUp(self):
+ """
+ Create a blog post
+ """
+ self.post = BlogPost.objects.create(title="Test post")
+ self.create_view = BlogPostCommentListCreate.as_view()
+
+ def test_create_comment(self):
+
+ data = {
+ 'text': 'A test comment',
+ 'blog_post_url': 'http://testserver/posts/1/'
+ }
+
+ request = factory.post('/comments/', data=data)
+ response = self.create_view(request).render()
+ self.assertEqual(response.status_code, 201)
+ self.assertEqual(self.post.blogpostcomment_set.count(), 1)
+ self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment')
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
new file mode 100644
index 00000000..415e4d06
--- /dev/null
+++ b/rest_framework/tests/models.py
@@ -0,0 +1,128 @@
+from django.db import models
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericForeignKey, GenericRelation
+
+# from django.contrib.auth.models import Group
+
+
+# class CustomUser(models.Model):
+# """
+# A custom user model, which uses a 'through' table for the foreign key
+# """
+# username = models.CharField(max_length=255, unique=True)
+# groups = models.ManyToManyField(
+# to=Group, blank=True, null=True, through='UserGroupMap'
+# )
+
+# @models.permalink
+# def get_absolute_url(self):
+# return ('custom_user', (), {
+# 'pk': self.id
+# })
+
+
+# class UserGroupMap(models.Model):
+# user = models.ForeignKey(to=CustomUser)
+# group = models.ForeignKey(to=Group)
+
+# @models.permalink
+# def get_absolute_url(self):
+# return ('user_group_map', (), {
+# 'pk': self.id
+# })
+
+def foobar():
+ return 'foobar'
+
+
+class RESTFrameworkModel(models.Model):
+ """
+ Base for test models that sets app_label, so they play nicely.
+ """
+ class Meta:
+ app_label = 'tests'
+ abstract = True
+
+
+class Anchor(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+
+
+class BasicModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+
+
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
+class DefaultValueModel(RESTFrameworkModel):
+ text = models.CharField(default='foobar', max_length=100)
+
+
+class CallableDefaultValueModel(RESTFrameworkModel):
+ text = models.CharField(default=foobar, max_length=100)
+
+
+class ManyToManyModel(RESTFrameworkModel):
+ rel = models.ManyToManyField(Anchor)
+
+
+class ReadOnlyManyToManyModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, default='anchor')
+ rel = models.ManyToManyField(Anchor)
+
+# Models to test generic relations
+
+
+class Tag(RESTFrameworkModel):
+ tag_name = models.SlugField()
+
+
+class TaggedItem(RESTFrameworkModel):
+ tag = models.ForeignKey(Tag, related_name='items')
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ content_object = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag.tag_name
+
+
+class Bookmark(RESTFrameworkModel):
+ url = models.URLField()
+ tags = GenericRelation(TaggedItem)
+
+
+# Model for regression test for #285
+
+class Comment(RESTFrameworkModel):
+ email = models.EmailField()
+ content = models.CharField(max_length=200)
+ created = models.DateTimeField(auto_now_add=True)
+
+
+class ActionItem(RESTFrameworkModel):
+ title = models.CharField(max_length=200)
+ done = models.BooleanField(default=False)
+
+
+# Models for reverse relations
+class BlogPost(RESTFrameworkModel):
+ title = models.CharField(max_length=100)
+
+
+class BlogPostComment(RESTFrameworkModel):
+ text = models.TextField()
+ blog_post = models.ForeignKey(BlogPost)
+
+
+class Person(RESTFrameworkModel):
+ name = models.CharField(max_length=10)
+ age = models.IntegerField(null=True, blank=True)
+
+
+# Model for issue #324
+class BlankFieldModel(RESTFrameworkModel):
+ title = models.CharField(max_length=100, blank=True)
diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py
new file mode 100644
index 00000000..1f8468e8
--- /dev/null
+++ b/rest_framework/tests/modelviews.py
@@ -0,0 +1,90 @@
+# from django.conf.urls.defaults import patterns, url
+# from django.forms import ModelForm
+# from django.contrib.auth.models import Group, User
+# from rest_framework.resources import ModelResource
+# from rest_framework.views import ListOrCreateModelView, InstanceModelView
+# from rest_framework.tests.models import CustomUser
+# from rest_framework.tests.testcases import TestModelsTestCase
+
+
+# class GroupResource(ModelResource):
+# model = Group
+
+
+# class UserForm(ModelForm):
+# class Meta:
+# model = User
+# exclude = ('last_login', 'date_joined')
+
+
+# class UserResource(ModelResource):
+# model = User
+# form = UserForm
+
+
+# class CustomUserResource(ModelResource):
+# model = CustomUser
+
+# urlpatterns = patterns('',
+# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
+# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
+# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
+# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
+# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
+# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
+# )
+
+
+# class ModelViewTests(TestModelsTestCase):
+# """Test the model views rest_framework provides"""
+# urls = 'rest_framework.tests.modelviews'
+
+# def test_creation(self):
+# """Ensure that a model object can be created"""
+# self.assertEqual(0, Group.objects.count())
+
+# response = self.client.post('/groups/', {'name': 'foo'})
+
+# self.assertEqual(response.status_code, 201)
+# self.assertEqual(1, Group.objects.count())
+# self.assertEqual('foo', Group.objects.all()[0].name)
+
+# def test_creation_with_m2m_relation(self):
+# """Ensure that a model object with a m2m relation can be created"""
+# group = Group(name='foo')
+# group.save()
+# self.assertEqual(0, User.objects.count())
+
+# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
+
+# self.assertEqual(response.status_code, 201)
+# self.assertEqual(1, User.objects.count())
+
+# user = User.objects.all()[0]
+# self.assertEqual('bar', user.username)
+# self.assertEqual('baz', user.password)
+# self.assertEqual(1, user.groups.count())
+
+# group = user.groups.all()[0]
+# self.assertEqual('foo', group.name)
+
+# def test_creation_with_m2m_relation_through(self):
+# """
+# Ensure that a model object with a m2m relation can be created where that
+# relation uses a through table
+# """
+# group = Group(name='foo')
+# group.save()
+# self.assertEqual(0, User.objects.count())
+
+# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
+
+# self.assertEqual(response.status_code, 201)
+# self.assertEqual(1, CustomUser.objects.count())
+
+# user = CustomUser.objects.all()[0]
+# self.assertEqual('bar', user.username)
+# self.assertEqual(1, user.groups.count())
+
+# group = user.groups.all()[0]
+# self.assertEqual('foo', group.name)
diff --git a/rest_framework/tests/negotiation.py b/rest_framework/tests/negotiation.py
new file mode 100644
index 00000000..e06354ea
--- /dev/null
+++ b/rest_framework/tests/negotiation.py
@@ -0,0 +1,37 @@
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework.negotiation import DefaultContentNegotiation
+
+factory = RequestFactory()
+
+
+class MockJSONRenderer(object):
+ media_type = 'application/json'
+
+
+class MockHTMLRenderer(object):
+ media_type = 'text/html'
+
+
+class TestAcceptedMediaType(TestCase):
+ def setUp(self):
+ self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
+ self.negotiator = DefaultContentNegotiation()
+
+ def select_renderer(self, request):
+ return self.negotiator.select_renderer(request, self.renderers)
+
+ def test_client_without_accept_use_renderer(self):
+ request = factory.get('/')
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEquals(accepted_media_type, 'application/json')
+
+ def test_client_underspecifies_accept_use_renderer(self):
+ request = factory.get('/', HTTP_ACCEPT='*/*')
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEquals(accepted_media_type, 'application/json')
+
+ def test_client_overspecifies_accept_use_client(self):
+ request = factory.get('/', HTTP_ACCEPT='application/json; indent=8')
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEquals(accepted_media_type, 'application/json; indent=8')
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
new file mode 100644
index 00000000..a939c9ef
--- /dev/null
+++ b/rest_framework/tests/pagination.py
@@ -0,0 +1,87 @@
+from django.core.paginator import Paginator
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import generics, status, pagination
+from rest_framework.tests.models import BasicModel
+
+factory = RequestFactory()
+
+
+class RootView(generics.ListCreateAPIView):
+ """
+ Example description for OPTIONS.
+ """
+ model = BasicModel
+ paginate_by = 10
+
+
+class IntegrationTestPagination(TestCase):
+ """
+ Integration tests for paginated list views.
+ """
+
+ def setUp(self):
+ """
+ Create 26 BasicModel intances.
+ """
+ for char in 'abcdefghijklmnopqrstuvwxyz':
+ BasicModel(text=char * 3).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_paginated_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ request = factory.get('/')
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[:10])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[10:20])
+ self.assertNotEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+ request = factory.get(response.data['next'])
+ response = self.view(request).render()
+ self.assertEquals(response.status_code, status.HTTP_200_OK)
+ self.assertEquals(response.data['count'], 26)
+ self.assertEquals(response.data['results'], self.data[20:])
+ self.assertEquals(response.data['next'], None)
+ self.assertNotEquals(response.data['previous'], None)
+
+
+class UnitTestPagination(TestCase):
+ """
+ Unit tests for pagination of primative objects.
+ """
+
+ def setUp(self):
+ self.objects = [char * 3 for char in 'abcdefghijklmnopqrstuvwxyz']
+ paginator = Paginator(self.objects, 10)
+ self.first_page = paginator.page(1)
+ self.last_page = paginator.page(3)
+
+ def test_native_pagination(self):
+ serializer = pagination.PaginationSerializer(instance=self.first_page)
+ self.assertEquals(serializer.data['count'], 26)
+ self.assertEquals(serializer.data['next'], '?page=2')
+ self.assertEquals(serializer.data['previous'], None)
+ self.assertEquals(serializer.data['results'], self.objects[:10])
+
+ serializer = pagination.PaginationSerializer(instance=self.last_page)
+ self.assertEquals(serializer.data['count'], 26)
+ self.assertEquals(serializer.data['next'], None)
+ self.assertEquals(serializer.data['previous'], '?page=2')
+ self.assertEquals(serializer.data['results'], self.objects[20:])
diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py
new file mode 100644
index 00000000..8ab8a52f
--- /dev/null
+++ b/rest_framework/tests/parsers.py
@@ -0,0 +1,212 @@
+# """
+# ..
+# >>> from rest_framework.parsers import FormParser
+# >>> from django.test.client import RequestFactory
+# >>> from rest_framework.views import View
+# >>> from StringIO import StringIO
+# >>> from urllib import urlencode
+# >>> req = RequestFactory().get('/')
+# >>> some_view = View()
+# >>> some_view.request = req # Make as if this request had been dispatched
+#
+# FormParser
+# ============
+#
+# Data flatening
+# ----------------
+#
+# Here is some example data, which would eventually be sent along with a post request :
+#
+# >>> inpt = urlencode([
+# ... ('key1', 'bla1'),
+# ... ('key2', 'blo1'), ('key2', 'blo2'),
+# ... ])
+#
+# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter :
+#
+# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt))
+# >>> data == {'key1': 'bla1', 'key2': 'blo1'}
+# True
+#
+# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` :
+#
+# >>> class MyFormParser(FormParser):
+# ...
+# ... def is_a_list(self, key, val_list):
+# ... return len(val_list) > 1
+#
+# This new parser only flattens the lists of parameters that contain a single value.
+#
+# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
+# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']}
+# True
+#
+# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`.
+#
+# Submitting an empty list
+# --------------------------
+#
+# When submitting an empty select multiple, like this one ::
+#
+# <select multiple="multiple" name="key2"></select>
+#
+# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty ::
+#
+# <select multiple="multiple" name="key2"><option value="_empty"></select>
+#
+# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data :
+#
+# >>> inpt = urlencode([
+# ... ('key1', 'blo1'), ('key1', '_empty'),
+# ... ('key2', '_empty'),
+# ... ])
+#
+# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists.
+#
+# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
+# >>> data == {'key1': 'blo1'}
+# True
+#
+# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it.
+#
+# >>> class MyFormParser(FormParser):
+# ...
+# ... def is_a_list(self, key, val_list):
+# ... return key == 'key2'
+# ...
+# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt))
+# >>> data == {'key1': 'blo1', 'key2': []}
+# True
+#
+# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`.
+# """
+# import httplib, mimetypes
+# from tempfile import TemporaryFile
+# from django.test import TestCase
+# from django.test.client import RequestFactory
+# from rest_framework.parsers import MultiPartParser
+# from rest_framework.views import View
+# from StringIO import StringIO
+#
+# def encode_multipart_formdata(fields, files):
+# """For testing multipart parser.
+# fields is a sequence of (name, value) elements for regular form fields.
+# files is a sequence of (name, filename, value) elements for data to be uploaded as files
+# Return (content_type, body)."""
+# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$'
+# CRLF = '\r\n'
+# L = []
+# for (key, value) in fields:
+# L.append('--' + BOUNDARY)
+# L.append('Content-Disposition: form-data; name="%s"' % key)
+# L.append('')
+# L.append(value)
+# for (key, filename, value) in files:
+# L.append('--' + BOUNDARY)
+# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename))
+# L.append('Content-Type: %s' % get_content_type(filename))
+# L.append('')
+# L.append(value)
+# L.append('--' + BOUNDARY + '--')
+# L.append('')
+# body = CRLF.join(L)
+# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY
+# return content_type, body
+#
+# def get_content_type(filename):
+# return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
+#
+#class TestMultiPartParser(TestCase):
+# def setUp(self):
+# self.req = RequestFactory()
+# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')],
+# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')])
+#
+# def test_multipartparser(self):
+# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters."""
+# post_req = RequestFactory().post('/', self.body, content_type=self.content_type)
+# view = View()
+# view.request = post_req
+# (data, files) = MultiPartParser(view).parse(StringIO(self.body))
+# self.assertEqual(data['key1'], 'val1')
+# self.assertEqual(files['file1'].read(), 'blablabla')
+
+from StringIO import StringIO
+from django import forms
+from django.test import TestCase
+from rest_framework.parsers import FormParser
+from rest_framework.parsers import XMLParser
+import datetime
+
+
+class Form(forms.Form):
+ field1 = forms.CharField(max_length=3)
+ field2 = forms.CharField()
+
+
+class TestFormParser(TestCase):
+ def setUp(self):
+ self.string = "field1=abc&field2=defghijk"
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FormParser()
+
+ stream = StringIO(self.string)
+ data = parser.parse(stream)
+
+ self.assertEqual(Form(data).is_valid(), True)
+
+
+class TestXMLParser(TestCase):
+ def setUp(self):
+ self._input = StringIO(
+ '<?xml version="1.0" encoding="utf-8"?>'
+ '<root>'
+ '<field_a>121.0</field_a>'
+ '<field_b>dasd</field_b>'
+ '<field_c></field_c>'
+ '<field_d>2011-12-25 12:45:00</field_d>'
+ '</root>'
+ )
+ self._data = {
+ 'field_a': 121,
+ 'field_b': 'dasd',
+ 'field_c': None,
+ 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00)
+ }
+ self._complex_data_input = StringIO(
+ '<?xml version="1.0" encoding="utf-8"?>'
+ '<root>'
+ '<creation_date>2011-12-25 12:45:00</creation_date>'
+ '<sub_data_list>'
+ '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>'
+ '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>'
+ '</sub_data_list>'
+ '<name>name</name>'
+ '</root>'
+ )
+ self._complex_data = {
+ "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
+ "name": "name",
+ "sub_data_list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
+
+ def test_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._input)
+ self.assertEqual(data, self._data)
+
+ def test_complex_data_parse(self):
+ parser = XMLParser()
+ data = parser.parse(self._complex_data_input)
+ self.assertEqual(data, self._complex_data)
diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py
new file mode 100644
index 00000000..48d8d9bd
--- /dev/null
+++ b/rest_framework/tests/renderers.py
@@ -0,0 +1,418 @@
+import re
+
+from django.conf.urls.defaults import patterns, url, include
+from django.test import TestCase
+from django.test.client import RequestFactory
+
+from rest_framework import status, permissions
+from rest_framework.compat import yaml
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer
+from rest_framework.parsers import YAMLParser, XMLParser
+from rest_framework.settings import api_settings
+
+from StringIO import StringIO
+import datetime
+from decimal import Decimal
+
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
+RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
+
+
+expected_results = [
+ ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator
+]
+
+
+class BasicRendererTests(TestCase):
+ def test_expected_results(self):
+ for value, renderer_cls, expected in expected_results:
+ output = renderer_cls().render(value)
+ self.assertEquals(output, expected)
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB)
+
+ def get(self, request, **kwargs):
+ response = Response(DUMMYCONTENT, status=DUMMYSTATUS)
+ return response
+
+
+class MockGETView(APIView):
+
+ def get(self, request, **kwargs):
+ return Response({'foo': ['bar', 'baz']})
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+urlpatterns = patterns('',
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])),
+ url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+class POSTDeniedPermission(permissions.BasePermission):
+ def has_permission(self, request, view, obj=None):
+ return request.method != 'POST'
+
+
+class POSTDeniedView(APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = (POSTDeniedPermission,)
+
+ def get(self, request):
+ return Response()
+
+ def post(self, request):
+ return Response()
+
+ def put(self, request):
+ return Response()
+
+
+class DocumentingRendererTests(TestCase):
+ def test_only_permitted_forms_are_displayed(self):
+ view = POSTDeniedView.as_view()
+ request = RequestFactory().get('/')
+ response = view(request).render()
+ self.assertNotContains(response, '>POST<')
+ self.assertContains(response, '>PUT<')
+
+
+class RendererEndToEndTests(TestCase):
+ """
+ End-to-end testing of renderers using an RendererMixin on a generic view.
+ """
+
+ urls = 'rest_framework.tests.renderers'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, '')
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
+ """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
+ resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+
+_flat_repr = '{"foo": ["bar", "baz"]}'
+_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
+
+
+def strip_trailing_whitespace(content):
+ """
+ Seems to be some inconsistencies re. trailing whitespace with
+ different versions of the json lib.
+ """
+ return re.sub(' +\n', '\n', content)
+
+
+class JSONRendererTests(TestCase):
+ """
+ Tests specific to the JSON Renderer
+ """
+
+ def test_without_content_type_args(self):
+ """
+ Test basic JSON rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ # Fix failing test case which depends on version of JSON library.
+ self.assertEquals(content, _flat_repr)
+
+ def test_with_content_type_args(self):
+ """
+ Test JSON rendering with additional content type arguments supplied.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json; indent=2')
+ self.assertEquals(strip_trailing_whitespace(content), _indented_repr)
+
+
+class JSONPRendererTests(TestCase):
+ """
+ Tests specific to the JSONP Renderer
+ """
+
+ urls = 'rest_framework.tests.renderers'
+
+ def test_without_callback_with_json_renderer(self):
+ """
+ Test JSONP rendering with View JSON Renderer.
+ """
+ resp = self.client.get('/jsonp/jsonrenderer',
+ HTTP_ACCEPT='application/javascript')
+ self.assertEquals(resp.status_code, 200)
+ self.assertEquals(resp['Content-Type'], 'application/javascript')
+ self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
+
+ def test_without_callback_without_json_renderer(self):
+ """
+ Test JSONP rendering without View JSON Renderer.
+ """
+ resp = self.client.get('/jsonp/nojsonrenderer',
+ HTTP_ACCEPT='application/javascript')
+ self.assertEquals(resp.status_code, 200)
+ self.assertEquals(resp['Content-Type'], 'application/javascript')
+ self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
+
+ def test_with_callback(self):
+ """
+ Test JSONP rendering with callback function name.
+ """
+ callback_func = 'myjsonpcallback'
+ resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func,
+ HTTP_ACCEPT='application/javascript')
+ self.assertEquals(resp.status_code, 200)
+ self.assertEquals(resp['Content-Type'], 'application/javascript')
+ self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr))
+
+
+if yaml:
+ _yaml_repr = 'foo: [bar, baz]\n'
+
+ class YAMLRendererTests(TestCase):
+ """
+ Tests specific to the JSON Renderer
+ """
+
+ def test_render(self):
+ """
+ Test basic YAML rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = YAMLRenderer()
+ content = renderer.render(obj, 'application/yaml')
+ self.assertEquals(content, _yaml_repr)
+
+ def test_render_and_parse(self):
+ """
+ Test rendering and then parsing returns the original object.
+ IE obj -> render -> parse -> obj.
+ """
+ obj = {'foo': ['bar', 'baz']}
+
+ renderer = YAMLRenderer()
+ parser = YAMLParser()
+
+ content = renderer.render(obj, 'application/yaml')
+ data = parser.parse(StringIO(content))
+ self.assertEquals(obj, data)
+
+
+class XMLRendererTestCase(TestCase):
+ """
+ Tests specific to the XML Renderer
+ """
+
+ _complex_data = {
+ "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00),
+ "name": "name",
+ "sub_data_list": [
+ {
+ "sub_id": 1,
+ "sub_name": "first"
+ },
+ {
+ "sub_id": 2,
+ "sub_name": "second"
+ }
+ ]
+ }
+
+ def test_render_string(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 'astring'}, 'application/xml')
+ self.assertXMLContains(content, '<field>astring</field>')
+
+ def test_render_integer(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 111}, 'application/xml')
+ self.assertXMLContains(content, '<field>111</field>')
+
+ def test_render_datetime(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({
+ 'field': datetime.datetime(2011, 12, 25, 12, 45, 00)
+ }, 'application/xml')
+ self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>')
+
+ def test_render_float(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': 123.4}, 'application/xml')
+ self.assertXMLContains(content, '<field>123.4</field>')
+
+ def test_render_decimal(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': Decimal('111.2')}, 'application/xml')
+ self.assertXMLContains(content, '<field>111.2</field>')
+
+ def test_render_none(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render({'field': None}, 'application/xml')
+ self.assertXMLContains(content, '<field></field>')
+
+ def test_render_complex_data(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = renderer.render(self._complex_data, 'application/xml')
+ self.assertXMLContains(content, '<sub_name>first</sub_name>')
+ self.assertXMLContains(content, '<sub_name>second</sub_name>')
+
+ def test_render_and_parse_complex_data(self):
+ """
+ Test XML rendering.
+ """
+ renderer = XMLRenderer()
+ content = StringIO(renderer.render(self._complex_data, 'application/xml'))
+
+ parser = XMLParser()
+ complex_data_out = parser.parse(content)
+ error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
+ self.assertEqual(self._complex_data, complex_data_out, error_msg)
+
+ def assertXMLContains(self, xml, string):
+ self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
+ self.assertTrue(xml.endswith('</root>'))
+ self.assertTrue(string in xml, '%r not in %r' % (string, xml))
diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py
new file mode 100644
index 00000000..ff48f3fa
--- /dev/null
+++ b/rest_framework/tests/request.py
@@ -0,0 +1,278 @@
+"""
+Tests for content parsing, and form-overloaded content parsing.
+"""
+from django.conf.urls.defaults import patterns
+from django.contrib.auth.models import User
+from django.test import TestCase, Client
+from django.utils import simplejson as json
+
+from rest_framework import status
+from rest_framework.authentication import SessionAuthentication
+from django.test.client import RequestFactory
+from rest_framework.parsers import (
+ BaseParser,
+ FormParser,
+ MultiPartParser,
+ JSONParser
+)
+from rest_framework.request import Request
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.views import APIView
+
+
+factory = RequestFactory()
+
+
+class PlainTextParser(BaseParser):
+ media_type = 'text/plain'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a 2-tuple of `(data, files)`.
+
+ `data` will simply be a string representing the body of the request.
+ `files` will always be `None`.
+ """
+ return stream.read()
+
+
+class TestMethodOverloading(TestCase):
+ def test_method(self):
+ """
+ Request methods should be same as underlying request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.method, 'GET')
+ request = Request(factory.post('/'))
+ self.assertEqual(request.method, 'POST')
+
+ def test_overloaded_method(self):
+ """
+ POST requests can be overloaded to another method by setting a
+ reserved form field
+ """
+ request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
+ self.assertEqual(request.method, 'DELETE')
+
+
+class TestContentParsing(TestCase):
+ def test_standard_behaviour_determines_no_content_GET(self):
+ """
+ Ensure request.DATA returns None for GET request with no content.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.DATA, None)
+
+ def test_standard_behaviour_determines_no_content_HEAD(self):
+ """
+ Ensure request.DATA returns None for HEAD request.
+ """
+ request = Request(factory.head('/'))
+ self.assertEqual(request.DATA, None)
+
+ def test_request_DATA_with_form_content(self):
+ """
+ Ensure request.DATA returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(request.DATA.items(), data.items())
+
+ def test_request_DATA_with_text_content(self):
+ """
+ Ensure request.DATA returns content for POST request with
+ non-form content.
+ """
+ content = 'qwerty'
+ content_type = 'text/plain'
+ request = Request(factory.post('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(),)
+ self.assertEqual(request.DATA, content)
+
+ def test_request_POST_with_form_content(self):
+ """
+ Ensure request.POST returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(request.POST.items(), data.items())
+
+ def test_standard_behaviour_determines_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+
+ from django import VERSION
+
+ if VERSION >= (1, 5):
+ from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart
+ request = Request(factory.put('/', encode_multipart(BOUNDARY, data),
+ content_type=MULTIPART_CONTENT))
+ else:
+ request = Request(factory.put('/', data))
+
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(request.DATA.items(), data.items())
+
+ def test_standard_behaviour_determines_non_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with
+ non-form content.
+ """
+ content = 'qwerty'
+ content_type = 'text/plain'
+ request = Request(factory.put('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(), )
+ self.assertEqual(request.DATA, content)
+
+ def test_overloaded_behaviour_allows_content_tunnelling(self):
+ """
+ Ensure request.DATA returns content for overloaded POST request.
+ """
+ json_data = {'foobar': 'qwerty'}
+ content = json.dumps(json_data)
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = Request(factory.post('/', form_data))
+ request.parsers = (JSONParser(), )
+ self.assertEqual(request.DATA, json_data)
+
+ # def test_accessing_post_after_data_form(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in
+ # form request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # request = factory.post('/', data=data)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), data.items())
+
+ # def test_accessing_post_after_data_for_json(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+
+ # request = factory.post('/', content, content_type=content_type,
+ # parsers=parsers)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), [])
+
+ # def test_accessing_post_after_data_for_overloaded_json(self):
+ # """
+ # Ensures request.POST can be accessed after request.DATA in overloaded
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # form_data = {Request._CONTENT_PARAM: content,
+ # Request._CONTENTTYPE_PARAM: content_type}
+
+ # request = factory.post('/', form_data, parsers=parsers)
+ # self.assertEqual(request.DATA.items(), data.items())
+ # self.assertEqual(request.POST.items(), form_data.items())
+
+ # def test_accessing_data_after_post_form(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in
+ # form request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # parsers = (FormParser, MultiPartParser)
+ # request = factory.post('/', data, parsers=parsers)
+
+ # self.assertEqual(request.POST.items(), data.items())
+ # self.assertEqual(request.DATA.items(), data.items())
+
+ # def test_accessing_data_after_post_for_json(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in
+ # json request.
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # request = factory.post('/', content, content_type=content_type,
+ # parsers=parsers)
+ # self.assertEqual(request.POST.items(), [])
+ # self.assertEqual(request.DATA.items(), data.items())
+
+ # def test_accessing_data_after_post_for_overloaded_json(self):
+ # """
+ # Ensures request.DATA can be accessed after request.POST in overloaded
+ # json request
+ # """
+ # data = {'qwerty': 'uiop'}
+ # content = json.dumps(data)
+ # content_type = 'application/json'
+ # parsers = (JSONParser, )
+ # form_data = {Request._CONTENT_PARAM: content,
+ # Request._CONTENTTYPE_PARAM: content_type}
+
+ # request = factory.post('/', form_data, parsers=parsers)
+ # self.assertEqual(request.POST.items(), form_data.items())
+ # self.assertEqual(request.DATA.items(), data.items())
+
+
+class MockView(APIView):
+ authentication_classes = (SessionAuthentication,)
+
+ def post(self, request):
+ if request.POST.get('example') is not None:
+ return Response(status=status.HTTP_200_OK)
+
+ return Response(status=status.INTERNAL_SERVER_ERROR)
+
+urlpatterns = patterns('',
+ (r'^$', MockView.as_view()),
+)
+
+
+class TestContentParsingWithAuthentication(TestCase):
+ urls = 'rest_framework.tests.request'
+
+ def setUp(self):
+ self.csrf_client = Client(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
+ """
+ Ensures request.POST exists after SessionAuthentication when user
+ doesn't log in.
+ """
+ content = {'example': 'example'}
+
+ response = self.client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ response = self.csrf_client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ # def test_user_logged_in_authentication_has_post_when_logged_in(self):
+ # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in"""
+ # self.client.login(username='john', password='password')
+ # self.csrf_client.login(username='john', password='password')
+ # content = {'example': 'example'}
+
+ # response = self.client.post('/', content)
+ # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
+
+ # response = self.csrf_client.post('/', content)
+ # self.assertEqual(status.OK, response.status_code, "POST data is malformed")
diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py
new file mode 100644
index 00000000..18b6af39
--- /dev/null
+++ b/rest_framework/tests/response.py
@@ -0,0 +1,182 @@
+import unittest
+
+from django.conf.urls.defaults import patterns, url, include
+from django.test import TestCase
+
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework import status
+from rest_framework.renderers import (
+ BaseRenderer,
+ JSONRenderer,
+ BrowsableAPIRenderer
+)
+from rest_framework.settings import api_settings
+
+
+class MockPickleRenderer(BaseRenderer):
+ media_type = 'application/pickle'
+
+
+class MockJsonRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x
+RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS)
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+urlpatterns = patterns('',
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
+class RendererIntegrationTests(TestCase):
+ """
+ End-to-end testing of renderers using an ResponseMixin on a generic view.
+ """
+
+ urls = 'rest_framework.tests.response'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, '')
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEquals(resp['Content-Type'], RendererA.media_type)
+ self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse')
+ def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
+ """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
+ resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEquals(resp['Content-Type'], RendererB.media_type)
+ self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEquals(resp.status_code, DUMMYSTATUS)
+
+
+class Issue122Tests(TestCase):
+ """
+ Tests that covers #122.
+ """
+ urls = 'rest_framework.tests.response'
+
+ def test_only_html_renderer(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html')
+
+ def test_html_renderer_is_first(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html1')
diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py
new file mode 100644
index 00000000..fd9a7d64
--- /dev/null
+++ b/rest_framework/tests/reverse.py
@@ -0,0 +1,26 @@
+from django.conf.urls.defaults import patterns, url
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework.reverse import reverse
+
+factory = RequestFactory()
+
+
+def null_view(request):
+ pass
+
+urlpatterns = patterns('',
+ url(r'^view$', null_view, name='view'),
+)
+
+
+class ReverseTests(TestCase):
+ """
+ Tests for fully qualifed URLs when using `reverse`.
+ """
+ urls = 'rest_framework.tests.reverse'
+
+ def test_reversed_urls_are_fully_qualified(self):
+ request = factory.get('/view')
+ url = reverse('view', request=request)
+ self.assertEqual(url, 'http://testserver/view')
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
new file mode 100644
index 00000000..d4b43862
--- /dev/null
+++ b/rest_framework/tests/serializer.py
@@ -0,0 +1,500 @@
+import datetime
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.tests.models import *
+
+
+class SubComment(object):
+ def __init__(self, sub_comment):
+ self.sub_comment = sub_comment
+
+
+class Comment(object):
+ def __init__(self, email, content, created):
+ self.email = email
+ self.content = content
+ self.created = created or datetime.datetime.now()
+
+ def __eq__(self, other):
+ return all([getattr(self, attr) == getattr(other, attr)
+ for attr in ('email', 'content', 'created')])
+
+ def get_sub_comment(self):
+ sub_comment = SubComment('And Merry Christmas!')
+ return sub_comment
+
+
+class CommentSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ content = serializers.CharField(max_length=1000)
+ created = serializers.DateTimeField()
+ sub_comment = serializers.Field(source='get_sub_comment.sub_comment')
+
+ def restore_object(self, data, instance=None):
+ if instance is None:
+ return Comment(**data)
+ for key, val in data.items():
+ setattr(instance, key, val)
+ return instance
+
+
+class ActionItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ActionItem
+
+
+class PersonSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Person
+
+
+class BasicTests(TestCase):
+ def setUp(self):
+ self.comment = Comment(
+ 'tom@example.com',
+ 'Happy new year!',
+ datetime.datetime(2012, 1, 1)
+ )
+ self.data = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'This wont change'
+ }
+ self.expected = {
+ 'email': 'tom@example.com',
+ 'content': 'Happy new year!',
+ 'created': datetime.datetime(2012, 1, 1),
+ 'sub_comment': 'And Merry Christmas!'
+ }
+
+ def test_empty(self):
+ serializer = CommentSerializer()
+ expected = {
+ 'email': '',
+ 'content': '',
+ 'created': None,
+ 'sub_comment': ''
+ }
+ self.assertEquals(serializer.data, expected)
+
+ def test_retrieve(self):
+ serializer = CommentSerializer(instance=self.comment)
+ self.assertEquals(serializer.data, self.expected)
+
+ def test_create(self):
+ serializer = CommentSerializer(self.data)
+ expected = self.comment
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.object, expected)
+ self.assertFalse(serializer.object is expected)
+ self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+ def test_update(self):
+ serializer = CommentSerializer(self.data, instance=self.comment)
+ expected = self.comment
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.object, expected)
+ self.assertTrue(serializer.object is expected)
+ self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!')
+
+
+class ValidationTests(TestCase):
+ def setUp(self):
+ self.comment = Comment(
+ 'tom@example.com',
+ 'Happy new year!',
+ datetime.datetime(2012, 1, 1)
+ )
+ self.data = {
+ 'email': 'tom@example.com',
+ 'content': 'x' * 1001,
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ self.actionitem = ActionItem('Some to do item',
+ )
+
+ def test_create(self):
+ serializer = CommentSerializer(self.data)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
+
+ def test_update(self):
+ serializer = CommentSerializer(self.data, instance=self.comment)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
+
+ def test_update_missing_field(self):
+ data = {
+ 'content': 'xxx',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = CommentSerializer(data, instance=self.comment)
+ self.assertEquals(serializer.is_valid(), False)
+ self.assertEquals(serializer.errors, {'email': [u'This field is required.']})
+
+ def test_missing_bool_with_default(self):
+ """Make sure that a boolean value with a 'False' value is not
+ mistaken for not having a default."""
+ data = {
+ 'title': 'Some action item',
+ #No 'done' value.
+ }
+ serializer = ActionItemSerializer(data, instance=self.actionitem)
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
+ def test_field_validation(self):
+
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithFieldValidator(data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = CommentSerializerWithFieldValidator(data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+
+ def test_cross_field_validation(self):
+
+ class CommentSerializerWithCrossFieldValidator(CommentSerializer):
+
+ def validate(self, attrs):
+ if attrs["email"] not in attrs["content"]:
+ raise serializers.ValidationError("Email address not in content")
+ return attrs
+
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A comment from tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = CommentSerializerWithCrossFieldValidator(data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'A comment from foo@bar.com'
+
+ serializer = CommentSerializerWithCrossFieldValidator(data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']})
+
+ def test_null_is_true_fields(self):
+ """
+ Omitting a value for null-field should validate.
+ """
+ serializer = PersonSerializer({'name': 'marko'})
+ self.assertEquals(serializer.is_valid(), True)
+ self.assertEquals(serializer.errors, {})
+
+
+class MetadataTests(TestCase):
+ def test_empty(self):
+ serializer = CommentSerializer()
+ expected = {
+ 'email': serializers.CharField,
+ 'content': serializers.CharField,
+ 'created': serializers.DateTimeField
+ }
+ for field_name, field in expected.items():
+ self.assertTrue(isinstance(serializer.data.fields[field_name], field))
+
+
+class ManyToManyTests(TestCase):
+ def setUp(self):
+ class ManyToManySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyModel
+
+ self.serializer_class = ManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'id': 1, 'rel': [self.anchor.id]}
+
+ def test_retrieve(self):
+ """
+ Serialize an instance of a model with a ManyToMany relationship.
+ """
+ serializer = self.serializer_class(instance=self.instance)
+ expected = self.data
+ self.assertEquals(serializer.data, expected)
+
+ def test_create(self):
+ """
+ Create an instance of a model with a ManyToMany relationship.
+ """
+ data = {'rel': [self.anchor.id]}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 2)
+ self.assertEquals(instance.pk, 2)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+ def test_update(self):
+ """
+ Update an instance of a model with a ManyToMany relationship.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(list(instance.rel.all()), [self.anchor, new_anchor])
+
+ def test_create_empty_relationship(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ data = {'rel': []}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 2)
+ self.assertEquals(instance.pk, 2)
+ self.assertEquals(list(instance.rel.all()), [])
+
+ def test_update_empty_relationship(self):
+ """
+ Update an instance of a model with a ManyToMany relationship,
+ containing no items.
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': []}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(list(instance.rel.all()), [])
+
+ def test_create_empty_relationship_flat_data(self):
+ """
+ Create an instance of a model with a ManyToMany relationship,
+ containing no items, using a representation that does not support
+ lists (eg form data).
+ """
+ data = {'rel': ''}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ManyToManyModel.objects.all()), 2)
+ self.assertEquals(instance.pk, 2)
+ self.assertEquals(list(instance.rel.all()), [])
+
+
+class ReadOnlyManyToManyTests(TestCase):
+ def setUp(self):
+ class ReadOnlyManyToManySerializer(serializers.ModelSerializer):
+ rel = serializers.ManyRelatedField(read_only=True)
+
+ class Meta:
+ model = ReadOnlyManyToManyModel
+
+ self.serializer_class = ReadOnlyManyToManySerializer
+
+ # An anchor instance to use for the relationship
+ self.anchor = Anchor()
+ self.anchor.save()
+
+ # A model instance with a many to many relationship to the anchor
+ self.instance = ReadOnlyManyToManyModel()
+ self.instance.save()
+ self.instance.rel.add(self.anchor)
+
+ # A serialized representation of the model instance
+ self.data = {'rel': [self.anchor.id], 'id': 1, 'text': 'anchor'}
+
+ def test_update(self):
+ """
+ Attempt to update an instance of a model with a ManyToMany
+ relationship. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {'rel': [self.anchor.id, new_anchor.id]}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+ def test_update_without_relationship(self):
+ """
+ Attempt to update an instance of a model where many to ManyToMany
+ relationship is not supplied. Not updated due to read_only=True
+ """
+ new_anchor = Anchor()
+ new_anchor.save()
+ data = {}
+ serializer = self.serializer_class(data, instance=self.instance)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(ReadOnlyManyToManyModel.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ # rel is still as original (1 entry)
+ self.assertEquals(list(instance.rel.all()), [self.anchor])
+
+
+class DefaultValueTests(TestCase):
+ def setUp(self):
+ class DefaultValueSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DefaultValueModel
+
+ self.serializer_class = DefaultValueSerializer
+ self.objects = DefaultValueModel.objects
+
+ def test_create_using_default(self):
+ data = {}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(self.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(instance.text, 'foobar')
+
+ def test_create_overriding_default(self):
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(self.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(instance.text, 'overridden')
+
+
+class CallableDefaultValueTests(TestCase):
+ def setUp(self):
+ class CallableDefaultValueSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = CallableDefaultValueModel
+
+ self.serializer_class = CallableDefaultValueSerializer
+ self.objects = CallableDefaultValueModel.objects
+
+ def test_create_using_default(self):
+ data = {}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(self.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(instance.text, 'foobar')
+
+ def test_create_overriding_default(self):
+ data = {'text': 'overridden'}
+ serializer = self.serializer_class(data)
+ self.assertEquals(serializer.is_valid(), True)
+ instance = serializer.save()
+ self.assertEquals(len(self.objects.all()), 1)
+ self.assertEquals(instance.pk, 1)
+ self.assertEquals(instance.text, 'overridden')
+
+
+class ManyRelatedTests(TestCase):
+ def setUp(self):
+
+ class BlogPostCommentSerializer(serializers.Serializer):
+ text = serializers.CharField()
+
+ class BlogPostSerializer(serializers.Serializer):
+ title = serializers.CharField()
+ comments = BlogPostCommentSerializer(source='blogpostcomment_set')
+
+ self.serializer_class = BlogPostSerializer
+
+ def test_reverse_relations(self):
+ post = BlogPost.objects.create(title="Test blog post")
+ post.blogpostcomment_set.create(text="I hate this blog post")
+ post.blogpostcomment_set.create(text="I love this blog post")
+
+ serializer = self.serializer_class(instance=post)
+ expected = {
+ 'title': 'Test blog post',
+ 'comments': [
+ {'text': 'I hate this blog post'},
+ {'text': 'I love this blog post'}
+ ]
+ }
+
+ self.assertEqual(serializer.data, expected)
+
+
+# Test for issue #324
+class BlankFieldTests(TestCase):
+ def setUp(self):
+
+ class BlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlankFieldModel
+
+ class BlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField(blank=True)
+
+ class NotBlankFieldModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+ class NotBlankFieldSerializer(serializers.Serializer):
+ title = serializers.CharField()
+
+ self.model_serializer_class = BlankFieldModelSerializer
+ self.serializer_class = BlankFieldSerializer
+ self.not_blank_model_serializer_class = NotBlankFieldModelSerializer
+ self.not_blank_serializer_class = NotBlankFieldSerializer
+ self.data = {'title': ''}
+
+ def test_create_blank_field(self):
+ serializer = self.serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_model_blank_field(self):
+ serializer = self.model_serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), True)
+
+ def test_create_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a non-model serializer
+ """
+ serializer = self.not_blank_serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), False)
+
+ def test_create_model_not_blank_field(self):
+ """
+ Test to ensure blank data in a field not marked as blank=True
+ is considered invalid in a model serializer
+ """
+ serializer = self.not_blank_model_serializer_class(self.data)
+ self.assertEquals(serializer.is_valid(), False)
diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py
new file mode 100644
index 00000000..30df5cef
--- /dev/null
+++ b/rest_framework/tests/status.py
@@ -0,0 +1,12 @@
+"""Tests for the status module"""
+from django.test import TestCase
+from rest_framework import status
+
+
+class TestStatus(TestCase):
+ """Simple sanity test to check the status module"""
+
+ def test_status(self):
+ """Ensure the status module is present and correct."""
+ self.assertEquals(200, status.HTTP_200_OK)
+ self.assertEquals(404, status.HTTP_404_NOT_FOUND)
diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py
new file mode 100644
index 00000000..c90224aa
--- /dev/null
+++ b/rest_framework/tests/testcases.py
@@ -0,0 +1,63 @@
+# http://djangosnippets.org/snippets/1011/
+from django.conf import settings
+from django.core.management import call_command
+from django.db.models import loading
+from django.test import TestCase
+
+NO_SETTING = ('!', None)
+
+class TestSettingsManager(object):
+ """
+ A class which can modify some Django settings temporarily for a
+ test and then revert them to their original values later.
+
+ Automatically handles resyncing the DB if INSTALLED_APPS is
+ modified.
+
+ """
+ def __init__(self):
+ self._original_settings = {}
+
+ def set(self, **kwargs):
+ for k,v in kwargs.iteritems():
+ self._original_settings.setdefault(k, getattr(settings, k,
+ NO_SETTING))
+ setattr(settings, k, v)
+ if 'INSTALLED_APPS' in kwargs:
+ self.syncdb()
+
+ def syncdb(self):
+ loading.cache.loaded = False
+ call_command('syncdb', verbosity=0)
+
+ def revert(self):
+ for k,v in self._original_settings.iteritems():
+ if v == NO_SETTING:
+ delattr(settings, k)
+ else:
+ setattr(settings, k, v)
+ if 'INSTALLED_APPS' in self._original_settings:
+ self.syncdb()
+ self._original_settings = {}
+
+
+class SettingsTestCase(TestCase):
+ """
+ A subclass of the Django TestCase with a settings_manager
+ attribute which is an instance of TestSettingsManager.
+
+ Comes with a tearDown() method that calls
+ self.settings_manager.revert().
+
+ """
+ def __init__(self, *args, **kwargs):
+ super(SettingsTestCase, self).__init__(*args, **kwargs)
+ self.settings_manager = TestSettingsManager()
+
+ def tearDown(self):
+ self.settings_manager.revert()
+
+class TestModelsTestCase(SettingsTestCase):
+ def setUp(self, *args, **kwargs):
+ installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',)
+ self.settings_manager.set(INSTALLED_APPS=installed_apps)
diff --git a/rest_framework/tests/tests.py b/rest_framework/tests/tests.py
new file mode 100644
index 00000000..adeaf6da
--- /dev/null
+++ b/rest_framework/tests/tests.py
@@ -0,0 +1,13 @@
+"""
+Force import of all modules in this package in order to get the standard test
+runner to pick up the tests. Yowzers.
+"""
+import os
+
+modules = [filename.rsplit('.', 1)[0]
+ for filename in os.listdir(os.path.dirname(__file__))
+ if filename.endswith('.py') and not filename.startswith('_')]
+__test__ = dict()
+
+for module in modules:
+ exec("from rest_framework.tests.%s import *" % module)
diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py
new file mode 100644
index 00000000..0b94c25b
--- /dev/null
+++ b/rest_framework/tests/throttling.py
@@ -0,0 +1,144 @@
+"""
+Tests for the throttling implementations in the permissions module.
+"""
+
+from django.test import TestCase
+from django.contrib.auth.models import User
+from django.core.cache import cache
+
+from django.test.client import RequestFactory
+from rest_framework.views import APIView
+from rest_framework.throttling import UserRateThrottle
+from rest_framework.response import Response
+
+
+class User3SecRateThrottle(UserRateThrottle):
+ rate = '3/sec'
+ scope = 'seconds'
+
+
+class User3MinRateThrottle(UserRateThrottle):
+ rate = '3/min'
+ scope = 'minutes'
+
+
+class MockView(APIView):
+ throttle_classes = (User3SecRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_MinuteThrottling(APIView):
+ throttle_classes = (User3MinRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class ThrottlingTests(TestCase):
+ urls = 'rest_framework.tests.throttling'
+
+ def setUp(self):
+ """
+ Reset the cache so that no throttles will be active
+ """
+ cache.clear()
+ self.factory = RequestFactory()
+
+ def test_requests_are_throttled(self):
+ """
+ Ensure request rate is limited
+ """
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ def set_throttle_timer(self, view, value):
+ """
+ Explicitly set the timer, overriding time.time()
+ """
+ view.throttle_classes[0].timer = lambda self: value
+
+ def test_request_throttling_expires(self):
+ """
+ Ensure request rate is limited for a limited duration only
+ """
+ self.set_throttle_timer(MockView, 0)
+
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ # Advance the timer by one second
+ self.set_throttle_timer(MockView, 1)
+
+ response = MockView.as_view()(request)
+ self.assertEqual(200, response.status_code)
+
+ def ensure_is_throttled(self, view, expect):
+ request = self.factory.get('/')
+ request.user = User.objects.create(username='a')
+ for dummy in range(3):
+ view.as_view()(request)
+ request.user = User.objects.create(username='b')
+ response = view.as_view()(request)
+ self.assertEqual(expect, response.status_code)
+
+ def test_request_throttling_is_per_user(self):
+ """
+ Ensure request rate is only limited per user, not globally for
+ PerUserThrottles
+ """
+ self.ensure_is_throttled(MockView, 200)
+
+ def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ """
+ Ensure the response returns an X-Throttle field with status and next attributes
+ set properly.
+ """
+ request = self.factory.get('/')
+ for timer, expect in expected_headers:
+ self.set_throttle_timer(view, timer)
+ response = view.as_view()(request)
+ if expect is not None:
+ self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
+ else:
+ self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
+
+ def test_seconds_fields(self):
+ """
+ Ensure for second based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView,
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '1')
+ ))
+
+ def test_minutes_fields(self):
+ """
+ Ensure for minute based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, None),
+ (0, None),
+ (0, None),
+ (0, '60')
+ ))
+
+ def test_next_rate_remains_constant_if_followed(self):
+ """
+ If a client follows the recommended next request rate,
+ the throttling rate should stay constant.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
+ ((0, None),
+ (20, None),
+ (40, None),
+ (60, None),
+ (80, None)
+ ))
diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py
new file mode 100644
index 00000000..c032985e
--- /dev/null
+++ b/rest_framework/tests/validators.py
@@ -0,0 +1,329 @@
+# from django import forms
+# from django.db import models
+# from django.test import TestCase
+# from rest_framework.response import ImmediateResponse
+# from rest_framework.views import View
+
+
+# class TestDisabledValidations(TestCase):
+# """Tests on FormValidator with validation disabled by setting form to None"""
+
+# def test_disabled_form_validator_returns_content_unchanged(self):
+# """If the view's form attribute is None then FormValidator(view).validate_request(content, None)
+# should just return the content unmodified."""
+# class DisabledFormResource(FormResource):
+# form = None
+
+# class MockView(View):
+# resource = DisabledFormResource
+
+# view = MockView()
+# content = {'qwerty': 'uiop'}
+# self.assertEqual(FormResource(view).validate_request(content, None), content)
+
+# def test_disabled_form_validator_get_bound_form_returns_none(self):
+# """If the view's form attribute is None on then
+# FormValidator(view).get_bound_form(content) should just return None."""
+# class DisabledFormResource(FormResource):
+# form = None
+
+# class MockView(View):
+# resource = DisabledFormResource
+
+# view = MockView()
+# content = {'qwerty': 'uiop'}
+# self.assertEqual(FormResource(view).get_bound_form(content), None)
+
+# def test_disabled_model_form_validator_returns_content_unchanged(self):
+# """If the view's form is None and does not have a Resource with a model set then
+# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified."""
+
+# class DisabledModelFormView(View):
+# resource = ModelResource
+
+# view = DisabledModelFormView()
+# content = {'qwerty': 'uiop'}
+# self.assertEqual(ModelResource(view).get_bound_form(content), None)
+
+# def test_disabled_model_form_validator_get_bound_form_returns_none(self):
+# """If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None."""
+# class DisabledModelFormView(View):
+# resource = ModelResource
+
+# view = DisabledModelFormView()
+# content = {'qwerty': 'uiop'}
+# self.assertEqual(ModelResource(view).get_bound_form(content), None)
+
+
+# class TestNonFieldErrors(TestCase):
+# """Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)"""
+
+# def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self):
+# """If validation fails with a non-field error, ensure the response a non-field error"""
+# class MockForm(forms.Form):
+# field1 = forms.CharField(required=False)
+# field2 = forms.CharField(required=False)
+# ERROR_TEXT = 'You may not supply both field1 and field2'
+
+# def clean(self):
+# if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data:
+# raise forms.ValidationError(self.ERROR_TEXT)
+# return self.cleaned_data
+
+# class MockResource(FormResource):
+# form = MockForm
+
+# class MockView(View):
+# pass
+
+# view = MockView()
+# content = {'field1': 'example1', 'field2': 'example2'}
+# try:
+# MockResource(view).validate_request(content, None)
+# except ImmediateResponse, exc:
+# response = exc.response
+# self.assertEqual(response.raw_content, {'errors': [MockForm.ERROR_TEXT]})
+# else:
+# self.fail('ImmediateResponse was not raised')
+
+
+# class TestFormValidation(TestCase):
+# """Tests which check basic form validation.
+# Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set.
+# (ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)"""
+# def setUp(self):
+# class MockForm(forms.Form):
+# qwerty = forms.CharField(required=True)
+
+# class MockFormResource(FormResource):
+# form = MockForm
+
+# class MockModelResource(ModelResource):
+# form = MockForm
+
+# class MockFormView(View):
+# resource = MockFormResource
+
+# class MockModelFormView(View):
+# resource = MockModelResource
+
+# self.MockFormResource = MockFormResource
+# self.MockModelResource = MockModelResource
+# self.MockFormView = MockFormView
+# self.MockModelFormView = MockModelFormView
+
+# def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator):
+# """If the content is already valid and clean then validate(content) should just return the content unmodified."""
+# content = {'qwerty': 'uiop'}
+# self.assertEqual(validator.validate_request(content, None), content)
+
+# def validation_failure_raises_response_exception(self, validator):
+# """If form validation fails a ResourceException 400 (Bad Request) should be raised."""
+# content = {}
+# self.assertRaises(ImmediateResponse, validator.validate_request, content, None)
+
+# def validation_does_not_allow_extra_fields_by_default(self, validator):
+# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
+# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
+# broken clients more easily (eg submitting content with a misnamed field)"""
+# content = {'qwerty': 'uiop', 'extra': 'extra'}
+# self.assertRaises(ImmediateResponse, validator.validate_request, content, None)
+
+# def validation_allows_extra_fields_if_explicitly_set(self, validator):
+# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names."""
+# content = {'qwerty': 'uiop', 'extra': 'extra'}
+# validator._validate(content, None, allowed_extra_fields=('extra',))
+
+# def validation_allows_unknown_fields_if_explicitly_allowed(self, validator):
+# """If we set ``unknown_form_fields`` on the form resource, then don't
+# raise errors on unexpected request data"""
+# content = {'qwerty': 'uiop', 'extra': 'extra'}
+# validator.allow_unknown_form_fields = True
+# self.assertEqual({'qwerty': u'uiop'},
+# validator.validate_request(content, None),
+# "Resource didn't accept unknown fields.")
+# validator.allow_unknown_form_fields = False
+
+# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator):
+# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names."""
+# content = {'qwerty': 'uiop'}
+# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content)
+
+# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator):
+# """If validation fails due to no content, ensure the response contains a single non-field error"""
+# content = {}
+# try:
+# validator.validate_request(content, None)
+# except ImmediateResponse, exc:
+# response = exc.response
+# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
+# else:
+# self.fail('ResourceException was not raised')
+
+# def validation_failed_due_to_field_error_returns_appropriate_message(self, validator):
+# """If validation fails due to a field error, ensure the response contains a single field error"""
+# content = {'qwerty': ''}
+# try:
+# validator.validate_request(content, None)
+# except ImmediateResponse, exc:
+# response = exc.response
+# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}})
+# else:
+# self.fail('ResourceException was not raised')
+
+# def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator):
+# """If validation fails due to an invalid field, ensure the response contains a single field error"""
+# content = {'qwerty': 'uiop', 'extra': 'extra'}
+# try:
+# validator.validate_request(content, None)
+# except ImmediateResponse, exc:
+# response = exc.response
+# self.assertEqual(response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}})
+# else:
+# self.fail('ResourceException was not raised')
+
+# def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator):
+# """If validation for multiple reasons, ensure the response contains each error"""
+# content = {'qwerty': '', 'extra': 'extra'}
+# try:
+# validator.validate_request(content, None)
+# except ImmediateResponse, exc:
+# response = exc.response
+# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.'],
+# 'extra': ['This field does not exist.']}})
+# else:
+# self.fail('ResourceException was not raised')
+
+# # Tests on FormResource
+
+# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
+
+# def test_form_validation_failure_raises_response_exception(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_failure_raises_response_exception(validator)
+
+# def test_validation_does_not_allow_extra_fields_by_default(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_does_not_allow_extra_fields_by_default(validator)
+
+# def test_validation_allows_extra_fields_if_explicitly_set(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_allows_extra_fields_if_explicitly_set(validator)
+
+# def test_validation_allows_unknown_fields_if_explicitly_allowed(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_allows_unknown_fields_if_explicitly_allowed(validator)
+
+# def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
+
+# def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
+
+# def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
+
+# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
+
+# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
+# validator = self.MockFormResource(self.MockFormView())
+# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
+
+# # Same tests on ModelResource
+
+# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
+
+# def test_modelform_validation_failure_raises_response_exception(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_failure_raises_response_exception(validator)
+
+# def test_modelform_validation_does_not_allow_extra_fields_by_default(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_does_not_allow_extra_fields_by_default(validator)
+
+# def test_modelform_validation_allows_extra_fields_if_explicitly_set(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_allows_extra_fields_if_explicitly_set(validator)
+
+# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
+
+# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
+
+# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
+
+# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
+
+# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
+# validator = self.MockModelResource(self.MockModelFormView())
+# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
+
+
+# class TestModelFormValidator(TestCase):
+# """Tests specific to ModelFormValidatorMixin"""
+
+# def setUp(self):
+# """Create a validator for a model with two fields and a property."""
+# class MockModel(models.Model):
+# qwerty = models.CharField(max_length=256)
+# uiop = models.CharField(max_length=256, blank=True)
+
+# @property
+# def read_only(self):
+# return 'read only'
+
+# class MockResource(ModelResource):
+# model = MockModel
+
+# class MockView(View):
+# resource = MockResource
+
+# self.validator = MockResource(MockView)
+
+# def test_property_fields_are_allowed_on_model_forms(self):
+# """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
+# self.assertEqual(self.validator.validate_request(content, None), content)
+
+# def test_property_fields_are_not_required_on_model_forms(self):
+# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input."""
+# content = {'qwerty': 'example', 'uiop': 'example'}
+# self.assertEqual(self.validator.validate_request(content, None), content)
+
+# def test_extra_fields_not_allowed_on_model_forms(self):
+# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
+# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
+# broken clients more easily (eg submitting content with a misnamed field)"""
+# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
+# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
+
+# def test_validate_requires_fields_on_model_forms(self):
+# """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
+# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
+# broken clients more easily (eg submitting content with a misnamed field)"""
+# content = {'read_only': 'read only'}
+# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None)
+
+# def test_validate_does_not_require_blankable_fields_on_model_forms(self):
+# """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
+# content = {'qwerty': 'example', 'read_only': 'read only'}
+# self.validator.validate_request(content, None)
+
+# def test_model_form_validator_uses_model_forms(self):
+# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm))
diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py
new file mode 100644
index 00000000..43365e07
--- /dev/null
+++ b/rest_framework/tests/views.py
@@ -0,0 +1,97 @@
+import copy
+from django.test import TestCase
+from django.test.client import RequestFactory
+from rest_framework import status
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.views import APIView
+
+factory = RequestFactory()
+
+
+class BasicView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'method': 'GET'})
+
+ def post(self, request, *args, **kwargs):
+ return Response({'method': 'POST', 'data': request.DATA})
+
+
+@api_view(['GET', 'POST', 'PUT'])
+def basic_view(request):
+ if request.method == 'GET':
+ return {'method': 'GET'}
+ elif request.method == 'POST':
+ return {'method': 'POST', 'data': request.DATA}
+ elif request.method == 'PUT':
+ return {'method': 'PUT', 'data': request.DATA}
+
+
+def sanitise_json_error(error_dict):
+ """
+ Exact contents of JSON error messages depend on the installed version
+ of json.
+ """
+ ret = copy.copy(error_dict)
+ chop = len('JSON parse error - No JSON object could be decoded')
+ ret['detail'] = ret['detail'][:chop]
+ return ret
+
+
+class ClassBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = BasicView.as_view()
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': u'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEquals(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': u'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEquals(sanitise_json_error(response.data), expected)
+
+
+class FunctionBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = basic_view
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': u'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEquals(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': u'JSON parse error - No JSON object could be decoded'
+ }
+ self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEquals(sanitise_json_error(response.data), expected)