diff options
Diffstat (limited to 'rest_framework/tests')
22 files changed, 2853 insertions, 0 deletions
diff --git a/rest_framework/tests/__init__.py b/rest_framework/tests/__init__.py new file mode 100644 index 00000000..ba3a27c0 --- /dev/null +++ b/rest_framework/tests/__init__.py @@ -0,0 +1,12 @@ +"""Force import of all modules in this package in order to get the standard test runner to pick up the tests. Yowzers.""" +import os + +modules = [filename.rsplit('.', 1)[0] + for filename in os.listdir(os.path.dirname(__file__)) + if filename.endswith('.py') and not filename.startswith('_')] +__test__ = dict() + +for module in modules: + exec("from rest_framework.tests.%s import __doc__ as module_doc" % module) + exec("from rest_framework.tests.%s import *" % module) + __test__[module] = module_doc or "" diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py new file mode 100644 index 00000000..0a3b2e02 --- /dev/null +++ b/rest_framework/tests/authentication.py @@ -0,0 +1,153 @@ +from django.conf.urls.defaults import patterns +from django.contrib.auth.models import User +from django.test import Client, TestCase + +from django.utils import simplejson as json +from django.http import HttpResponse + +from rest_framework.views import APIView +from rest_framework import permissions + +from rest_framework.authtoken.models import Token +from rest_framework.authentication import TokenAuthentication + +import base64 + + +class MockView(APIView): + permission_classes = (permissions.IsAuthenticated,) + + def post(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + + def put(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + +MockView.authentication_classes += (TokenAuthentication,) + +urlpatterns = patterns('', + (r'^$', MockView.as_view()), +) + + +class BasicAuthTests(TestCase): + """Basic authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + def test_post_form_passing_basic_auth(self): + """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" + auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() + response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_post_json_passing_basic_auth(self): + """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" + auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() + response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_post_form_failing_basic_auth(self): + """Ensure POSTing form over basic auth without correct credentials fails""" + response = self.csrf_client.post('/', {'example': 'example'}) + self.assertEqual(response.status_code, 403) + + def test_post_json_failing_basic_auth(self): + """Ensure POSTing json over basic auth without correct credentials fails""" + response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 403) + + +class SessionAuthTests(TestCase): + """User session authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.non_csrf_client = Client(enforce_csrf_checks=False) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + def tearDown(self): + self.csrf_client.logout() + + def test_post_form_session_auth_failing_csrf(self): + """ + Ensure POSTing form over session authentication without CSRF token fails. + """ + self.csrf_client.login(username=self.username, password=self.password) + response = self.csrf_client.post('/', {'example': 'example'}) + self.assertEqual(response.status_code, 403) + + def test_post_form_session_auth_passing(self): + """ + Ensure POSTing form over session authentication with logged in user and CSRF token passes. + """ + self.non_csrf_client.login(username=self.username, password=self.password) + response = self.non_csrf_client.post('/', {'example': 'example'}) + self.assertEqual(response.status_code, 200) + + def test_put_form_session_auth_passing(self): + """ + Ensure PUTting form over session authentication with logged in user and CSRF token passes. + """ + self.non_csrf_client.login(username=self.username, password=self.password) + response = self.non_csrf_client.put('/', {'example': 'example'}) + self.assertEqual(response.status_code, 200) + + def test_post_form_session_auth_failing(self): + """ + Ensure POSTing form over session authentication without logged in user fails. + """ + response = self.csrf_client.post('/', {'example': 'example'}) + self.assertEqual(response.status_code, 403) + + +class TokenAuthTests(TestCase): + """Token authentication""" + urls = 'rest_framework.tests.authentication' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + self.key = 'abcd1234' + self.token = Token.objects.create(key=self.key, user=self.user) + + def test_post_form_passing_token_auth(self): + """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" + auth = "Token " + self.key + response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_post_json_passing_token_auth(self): + """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" + auth = "Token " + self.key + response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + + def test_post_form_failing_token_auth(self): + """Ensure POSTing form over token auth without correct credentials fails""" + response = self.csrf_client.post('/', {'example': 'example'}) + self.assertEqual(response.status_code, 403) + + def test_post_json_failing_token_auth(self): + """Ensure POSTing json over token auth without correct credentials fails""" + response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json') + self.assertEqual(response.status_code, 403) + + def test_token_has_auto_assigned_key_if_none_provided(self): + """Ensure creating a token with no key will auto-assign a key""" + token = Token.objects.create(user=self.user) + self.assertTrue(bool(token.key)) diff --git a/rest_framework/tests/breadcrumbs.py b/rest_framework/tests/breadcrumbs.py new file mode 100644 index 00000000..647ab96d --- /dev/null +++ b/rest_framework/tests/breadcrumbs.py @@ -0,0 +1,72 @@ +from django.conf.urls.defaults import patterns, url +from django.test import TestCase +from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.views import APIView + + +class Root(APIView): + pass + + +class ResourceRoot(APIView): + pass + + +class ResourceInstance(APIView): + pass + + +class NestedResourceRoot(APIView): + pass + + +class NestedResourceInstance(APIView): + pass + +urlpatterns = patterns('', + url(r'^$', Root.as_view()), + url(r'^resource/$', ResourceRoot.as_view()), + url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), + url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), + url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), +) + + +class BreadcrumbTests(TestCase): + """Tests the breadcrumb functionality used by the HTML renderer.""" + + urls = 'rest_framework.tests.breadcrumbs' + + def test_root_breadcrumbs(self): + url = '/' + self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) + + def test_resource_root_breadcrumbs(self): + url = '/resource/' + self.assertEqual(get_breadcrumbs(url), [('Root', '/'), + ('Resource Root', '/resource/')]) + + def test_resource_instance_breadcrumbs(self): + url = '/resource/123' + self.assertEqual(get_breadcrumbs(url), [('Root', '/'), + ('Resource Root', '/resource/'), + ('Resource Instance', '/resource/123')]) + + def test_nested_resource_breadcrumbs(self): + url = '/resource/123/' + self.assertEqual(get_breadcrumbs(url), [('Root', '/'), + ('Resource Root', '/resource/'), + ('Resource Instance', '/resource/123'), + ('Nested Resource Root', '/resource/123/')]) + + def test_nested_resource_instance_breadcrumbs(self): + url = '/resource/123/abc' + self.assertEqual(get_breadcrumbs(url), [('Root', '/'), + ('Resource Root', '/resource/'), + ('Resource Instance', '/resource/123'), + ('Nested Resource Root', '/resource/123/'), + ('Nested Resource Instance', '/resource/123/abc')]) + + def test_broken_url_breadcrumbs_handled_gracefully(self): + url = '/foobar' + self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) diff --git a/rest_framework/tests/description.py b/rest_framework/tests/description.py new file mode 100644 index 00000000..d958b840 --- /dev/null +++ b/rest_framework/tests/description.py @@ -0,0 +1,113 @@ +from django.test import TestCase +from rest_framework.views import APIView +from rest_framework.compat import apply_markdown + +# We check that docstrings get nicely un-indented. +DESCRIPTION = """an example docstring +==================== + +* list +* list + +another header +-------------- + + code block + +indented + +# hash style header #""" + +# If markdown is installed we also test it's working +# (and that our wrapped forces '=' to h2 and '-' to h3) + +# We support markdown < 2.1 and markdown >= 2.1 +MARKED_DOWN_lt_21 = """<h2>an example docstring</h2> +<ul> +<li>list</li> +<li>list</li> +</ul> +<h3>another header</h3> +<pre><code>code block +</code></pre> +<p>indented</p> +<h2 id="hash_style_header">hash style header</h2>""" + +MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2> +<ul> +<li>list</li> +<li>list</li> +</ul> +<h3 id="another-header">another header</h3> +<pre><code>code block +</code></pre> +<p>indented</p> +<h2 id="hash-style-header">hash style header</h2>""" + + +class TestViewNamesAndDescriptions(TestCase): + def test_resource_name_uses_classname_by_default(self): + """Ensure Resource names are based on the classname by default.""" + class MockView(APIView): + pass + self.assertEquals(MockView().get_name(), 'Mock') + + def test_resource_name_can_be_set_explicitly(self): + """Ensure Resource names can be set using the 'get_name' method.""" + example = 'Some Other Name' + class MockView(APIView): + def get_name(self): + return example + self.assertEquals(MockView().get_name(), example) + + def test_resource_description_uses_docstring_by_default(self): + """Ensure Resource names are based on the docstring by default.""" + class MockView(APIView): + """an example docstring + ==================== + + * list + * list + + another header + -------------- + + code block + + indented + + # hash style header #""" + + self.assertEquals(MockView().get_description(), DESCRIPTION) + + def test_resource_description_can_be_set_explicitly(self): + """Ensure Resource descriptions can be set using the 'get_description' method.""" + example = 'Some other description' + + class MockView(APIView): + """docstring""" + def get_description(self): + return example + self.assertEquals(MockView().get_description(), example) + + def test_resource_description_does_not_require_docstring(self): + """Ensure that empty docstrings do not affect the Resource's description if it has been set using the 'get_description' method.""" + example = 'Some other description' + + class MockView(APIView): + def get_description(self): + return example + self.assertEquals(MockView().get_description(), example) + + def test_resource_description_can_be_empty(self): + """Ensure that if a resource has no doctring or 'description' class attribute, then it's description is the empty string.""" + class MockView(APIView): + pass + self.assertEquals(MockView().get_description(), '') + + def test_markdown(self): + """Ensure markdown to HTML works as expected""" + if apply_markdown: + gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21 + lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21 + self.assertTrue(gte_21_match or lt_21_match) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py new file mode 100644 index 00000000..eb5c7741 --- /dev/null +++ b/rest_framework/tests/files.py @@ -0,0 +1,34 @@ +# from django.test import TestCase +# from django import forms + +# from rest_framework.compat import RequestFactory +# from rest_framework.views import View +# from rest_framework.response import Response + +# import StringIO + + +# class UploadFilesTests(TestCase): +# """Check uploading of files""" +# def setUp(self): +# self.factory = RequestFactory() + +# def test_upload_file(self): + +# class FileForm(forms.Form): +# file = forms.FileField() + +# class MockView(View): +# permissions = () +# form = FileForm + +# def post(self, request, *args, **kwargs): +# return Response({'FILE_NAME': self.CONTENT['file'].name, +# 'FILE_CONTENT': self.CONTENT['file'].read()}) + +# file = StringIO.StringIO('stuff') +# file.name = 'stuff.txt' +# request = self.factory.post('/', {'file': file}) +# view = MockView.as_view() +# response = view(request) +# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}) diff --git a/rest_framework/tests/methods.py b/rest_framework/tests/methods.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rest_framework/tests/methods.py diff --git a/rest_framework/tests/mixins.py b/rest_framework/tests/mixins.py new file mode 100644 index 00000000..def06464 --- /dev/null +++ b/rest_framework/tests/mixins.py @@ -0,0 +1,285 @@ +# """Tests for the mixin module""" +# from django.test import TestCase +# from rest_framework import status +# from rest_framework.compat import RequestFactory +# from django.contrib.auth.models import Group, User +# from rest_framework.mixins import CreateModelMixin, PaginatorMixin, ReadModelMixin +# from rest_framework.resources import ModelResource +# from rest_framework.response import Response, ImmediateResponse +# from rest_framework.tests.models import CustomUser +# from rest_framework.tests.testcases import TestModelsTestCase +# from rest_framework.views import View + + +# class TestModelRead(TestModelsTestCase): +# """Tests on ReadModelMixin""" + +# def setUp(self): +# super(TestModelRead, self).setUp() +# self.req = RequestFactory() + +# def test_read(self): +# Group.objects.create(name='other group') +# group = Group.objects.create(name='my group') + +# class GroupResource(ModelResource): +# model = Group + +# request = self.req.get('/groups') +# mixin = ReadModelMixin() +# mixin.resource = GroupResource + +# response = mixin.get(request, id=group.id) +# self.assertEquals(group.name, response.raw_content.name) + +# def test_read_404(self): +# class GroupResource(ModelResource): +# model = Group + +# request = self.req.get('/groups') +# mixin = ReadModelMixin() +# mixin.resource = GroupResource + +# self.assertRaises(ImmediateResponse, mixin.get, request, id=12345) + + +# class TestModelCreation(TestModelsTestCase): +# """Tests on CreateModelMixin""" + +# def setUp(self): +# super(TestModelsTestCase, self).setUp() +# self.req = RequestFactory() + +# def test_creation(self): +# self.assertEquals(0, Group.objects.count()) + +# class GroupResource(ModelResource): +# model = Group + +# form_data = {'name': 'foo'} +# request = self.req.post('/groups', data=form_data) +# mixin = CreateModelMixin() +# mixin.resource = GroupResource +# mixin.CONTENT = form_data + +# response = mixin.post(request) +# self.assertEquals(1, Group.objects.count()) +# self.assertEquals('foo', response.raw_content.name) + +# def test_creation_with_m2m_relation(self): +# class UserResource(ModelResource): +# model = User + +# def url(self, instance): +# return "/users/%i" % instance.id + +# group = Group(name='foo') +# group.save() + +# form_data = { +# 'username': 'bar', +# 'password': 'baz', +# 'groups': [group.id] +# } +# request = self.req.post('/groups', data=form_data) +# cleaned_data = dict(form_data) +# cleaned_data['groups'] = [group] +# mixin = CreateModelMixin() +# mixin.resource = UserResource +# mixin.CONTENT = cleaned_data + +# response = mixin.post(request) +# self.assertEquals(1, User.objects.count()) +# self.assertEquals(1, response.raw_content.groups.count()) +# self.assertEquals('foo', response.raw_content.groups.all()[0].name) + +# def test_creation_with_m2m_relation_through(self): +# """ +# Tests creation where the m2m relation uses a through table +# """ +# class UserResource(ModelResource): +# model = CustomUser + +# def url(self, instance): +# return "/customusers/%i" % instance.id + +# form_data = {'username': 'bar0', 'groups': []} +# request = self.req.post('/groups', data=form_data) +# cleaned_data = dict(form_data) +# cleaned_data['groups'] = [] +# mixin = CreateModelMixin() +# mixin.resource = UserResource +# mixin.CONTENT = cleaned_data + +# response = mixin.post(request) +# self.assertEquals(1, CustomUser.objects.count()) +# self.assertEquals(0, response.raw_content.groups.count()) + +# group = Group(name='foo1') +# group.save() + +# form_data = {'username': 'bar1', 'groups': [group.id]} +# request = self.req.post('/groups', data=form_data) +# cleaned_data = dict(form_data) +# cleaned_data['groups'] = [group] +# mixin = CreateModelMixin() +# mixin.resource = UserResource +# mixin.CONTENT = cleaned_data + +# response = mixin.post(request) +# self.assertEquals(2, CustomUser.objects.count()) +# self.assertEquals(1, response.raw_content.groups.count()) +# self.assertEquals('foo1', response.raw_content.groups.all()[0].name) + +# group2 = Group(name='foo2') +# group2.save() + +# form_data = {'username': 'bar2', 'groups': [group.id, group2.id]} +# request = self.req.post('/groups', data=form_data) +# cleaned_data = dict(form_data) +# cleaned_data['groups'] = [group, group2] +# mixin = CreateModelMixin() +# mixin.resource = UserResource +# mixin.CONTENT = cleaned_data + +# response = mixin.post(request) +# self.assertEquals(3, CustomUser.objects.count()) +# self.assertEquals(2, response.raw_content.groups.count()) +# self.assertEquals('foo1', response.raw_content.groups.all()[0].name) +# self.assertEquals('foo2', response.raw_content.groups.all()[1].name) + + +# class MockPaginatorView(PaginatorMixin, View): +# total = 60 + +# def get(self, request): +# return Response(range(0, self.total)) + +# def post(self, request): +# return Response({'status': 'OK'}, status=status.HTTP_201_CREATED) + + +# class TestPagination(TestCase): +# def setUp(self): +# self.req = RequestFactory() + +# def test_default_limit(self): +# """ Tests if pagination works without overwriting the limit """ +# request = self.req.get('/paginator') +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content + +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertEqual(MockPaginatorView.total, content['total']) +# self.assertEqual(MockPaginatorView.limit, content['per_page']) + +# self.assertEqual(range(0, MockPaginatorView.limit), content['results']) + +# def test_overwriting_limit(self): +# """ Tests if the limit can be overwritten """ +# limit = 10 + +# request = self.req.get('/paginator') +# response = MockPaginatorView.as_view(limit=limit)(request) +# content = response.raw_content + +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertEqual(content['per_page'], limit) + +# self.assertEqual(range(0, limit), content['results']) + +# def test_limit_param(self): +# """ Tests if the client can set the limit """ +# from math import ceil + +# limit = 5 +# num_pages = int(ceil(MockPaginatorView.total / float(limit))) + +# request = self.req.get('/paginator/?limit=%d' % limit) +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content + +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertEqual(MockPaginatorView.total, content['total']) +# self.assertEqual(limit, content['per_page']) +# self.assertEqual(num_pages, content['pages']) + +# def test_exceeding_limit(self): +# """ Makes sure the client cannot exceed the default limit """ +# from math import ceil + +# limit = MockPaginatorView.limit + 10 +# num_pages = int(ceil(MockPaginatorView.total / float(limit))) + +# request = self.req.get('/paginator/?limit=%d' % limit) +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content + +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertEqual(MockPaginatorView.total, content['total']) +# self.assertNotEqual(limit, content['per_page']) +# self.assertNotEqual(num_pages, content['pages']) +# self.assertEqual(MockPaginatorView.limit, content['per_page']) + +# def test_only_works_for_get(self): +# """ Pagination should only work for GET requests """ +# request = self.req.post('/paginator', data={'content': 'spam'}) +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content + +# self.assertEqual(response.status_code, status.HTTP_201_CREATED) +# self.assertEqual(None, content.get('per_page')) +# self.assertEqual('OK', content['status']) + +# def test_non_int_page(self): +# """ Tests that it can handle invalid values """ +# request = self.req.get('/paginator/?page=spam') +# response = MockPaginatorView.as_view()(request) + +# self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +# def test_page_range(self): +# """ Tests that the page range is handle correctly """ +# request = self.req.get('/paginator/?page=0') +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content +# self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +# request = self.req.get('/paginator/') +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertEqual(range(0, MockPaginatorView.limit), content['results']) + +# num_pages = content['pages'] + +# request = self.req.get('/paginator/?page=%d' % num_pages) +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results']) + +# request = self.req.get('/paginator/?page=%d' % (num_pages + 1,)) +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content +# self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + +# def test_existing_query_parameters_are_preserved(self): +# """ Tests that existing query parameters are preserved when +# generating next/previous page links """ +# request = self.req.get('/paginator/?foo=bar&another=something') +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content +# self.assertEqual(response.status_code, status.HTTP_200_OK) +# self.assertTrue('foo=bar' in content['next']) +# self.assertTrue('another=something' in content['next']) +# self.assertTrue('page=2' in content['next']) + +# def test_duplicate_parameters_are_not_created(self): +# """ Regression: ensure duplicate "page" parameters are not added to +# paginated URLs. So page 1 should contain ?page=2, not ?page=1&page=2 """ +# request = self.req.get('/paginator/?page=1') +# response = MockPaginatorView.as_view()(request) +# content = response.raw_content +# self.assertTrue('page=2' in content['next']) +# self.assertFalse('page=1' in content['next']) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py new file mode 100644 index 00000000..4cae68b6 --- /dev/null +++ b/rest_framework/tests/models.py @@ -0,0 +1,28 @@ +from django.db import models +from django.contrib.auth.models import Group + +class CustomUser(models.Model): + """ + A custom user model, which uses a 'through' table for the foreign key + """ + username = models.CharField(max_length=255, unique=True) + groups = models.ManyToManyField( + to=Group, blank=True, null=True, through='UserGroupMap' + ) + + @models.permalink + def get_absolute_url(self): + return ('custom_user', (), { + 'pk': self.id + }) + + +class UserGroupMap(models.Model): + user = models.ForeignKey(to=CustomUser) + group = models.ForeignKey(to=Group) + + @models.permalink + def get_absolute_url(self): + return ('user_group_map', (), { + 'pk': self.id + }) diff --git a/rest_framework/tests/modelviews.py b/rest_framework/tests/modelviews.py new file mode 100644 index 00000000..1f8468e8 --- /dev/null +++ b/rest_framework/tests/modelviews.py @@ -0,0 +1,90 @@ +# from django.conf.urls.defaults import patterns, url +# from django.forms import ModelForm +# from django.contrib.auth.models import Group, User +# from rest_framework.resources import ModelResource +# from rest_framework.views import ListOrCreateModelView, InstanceModelView +# from rest_framework.tests.models import CustomUser +# from rest_framework.tests.testcases import TestModelsTestCase + + +# class GroupResource(ModelResource): +# model = Group + + +# class UserForm(ModelForm): +# class Meta: +# model = User +# exclude = ('last_login', 'date_joined') + + +# class UserResource(ModelResource): +# model = User +# form = UserForm + + +# class CustomUserResource(ModelResource): +# model = CustomUser + +# urlpatterns = patterns('', +# url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), +# url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), +# url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'), +# url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)), +# url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), +# url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), +# ) + + +# class ModelViewTests(TestModelsTestCase): +# """Test the model views rest_framework provides""" +# urls = 'rest_framework.tests.modelviews' + +# def test_creation(self): +# """Ensure that a model object can be created""" +# self.assertEqual(0, Group.objects.count()) + +# response = self.client.post('/groups/', {'name': 'foo'}) + +# self.assertEqual(response.status_code, 201) +# self.assertEqual(1, Group.objects.count()) +# self.assertEqual('foo', Group.objects.all()[0].name) + +# def test_creation_with_m2m_relation(self): +# """Ensure that a model object with a m2m relation can be created""" +# group = Group(name='foo') +# group.save() +# self.assertEqual(0, User.objects.count()) + +# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]}) + +# self.assertEqual(response.status_code, 201) +# self.assertEqual(1, User.objects.count()) + +# user = User.objects.all()[0] +# self.assertEqual('bar', user.username) +# self.assertEqual('baz', user.password) +# self.assertEqual(1, user.groups.count()) + +# group = user.groups.all()[0] +# self.assertEqual('foo', group.name) + +# def test_creation_with_m2m_relation_through(self): +# """ +# Ensure that a model object with a m2m relation can be created where that +# relation uses a through table +# """ +# group = Group(name='foo') +# group.save() +# self.assertEqual(0, User.objects.count()) + +# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]}) + +# self.assertEqual(response.status_code, 201) +# self.assertEqual(1, CustomUser.objects.count()) + +# user = CustomUser.objects.all()[0] +# self.assertEqual('bar', user.username) +# self.assertEqual(1, user.groups.count()) + +# group = user.groups.all()[0] +# self.assertEqual('foo', group.name) diff --git a/rest_framework/tests/oauthentication.py b/rest_framework/tests/oauthentication.py new file mode 100644 index 00000000..6e7af52d --- /dev/null +++ b/rest_framework/tests/oauthentication.py @@ -0,0 +1,211 @@ +import time + +from django.conf.urls.defaults import patterns, url, include +from django.contrib.auth.models import User +from django.test import Client, TestCase + +from rest_framework.views import APIView + +# Since oauth2 / django-oauth-plus are optional dependancies, we don't want to +# always run these tests. + +# Unfortunatly we can't skip tests easily until 2.7, se we'll just do this for now. +try: + import oauth2 as oauth + from oauth_provider.decorators import oauth_required + from oauth_provider.models import Resource, Consumer, Token + +except ImportError: + pass + +else: + # Alrighty, we're good to go here. + class ClientView(APIView): + def get(self, request): + return {'resource': 'Protected!'} + + urlpatterns = patterns('', + url(r'^$', oauth_required(ClientView.as_view())), + url(r'^oauth/', include('oauth_provider.urls')), + url(r'^restframework/', include('rest_framework.urls', namespace='rest_framework')), + ) + + class OAuthTests(TestCase): + """ + OAuth authentication: + * the user would like to access his API data from a third-party website + * the third-party website proposes a link to get that API data + * the user is redirected to the API and must log in if not authenticated + * the API displays a webpage to confirm that the user trusts the third-party website + * if confirmed, the user is redirected to the third-party website through the callback view + * the third-party website is able to retrieve data from the API + """ + urls = 'rest_framework.tests.oauthentication' + + def setUp(self): + self.client = Client() + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + # OAuth requirements + self.resource = Resource(name='data', url='/') + self.resource.save() + self.CONSUMER_KEY = 'dpf43f3p2l4k3l03' + self.CONSUMER_SECRET = 'kd94hf93k423kf44' + self.consumer = Consumer(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, + name='api.example.com', user=self.user) + self.consumer.save() + + def test_oauth_invalid_and_anonymous_access(self): + """ + Verify that the resource is protected and the OAuth authorization view + require the user to be logged in. + """ + response = self.client.get('/') + self.assertEqual(response.content, 'Invalid request parameters.') + self.assertEqual(response.status_code, 401) + response = self.client.get('/oauth/authorize/', follow=True) + self.assertRedirects(response, '/accounts/login/?next=/oauth/authorize/') + + def test_oauth_authorize_access(self): + """ + Verify that once logged in, the user can access the authorization page + but can't display the page because the request token is not specified. + """ + self.client.login(username=self.username, password=self.password) + response = self.client.get('/oauth/authorize/', follow=True) + self.assertEqual(response.content, 'No request token specified.') + + def _create_request_token_parameters(self): + """ + A shortcut to create request's token parameters. + """ + return { + 'oauth_consumer_key': self.CONSUMER_KEY, + 'oauth_signature_method': 'PLAINTEXT', + 'oauth_signature': '%s&' % self.CONSUMER_SECRET, + 'oauth_timestamp': str(int(time.time())), + 'oauth_nonce': 'requestnonce', + 'oauth_version': '1.0', + 'oauth_callback': 'http://api.example.com/request_token_ready', + 'scope': 'data', + } + + def test_oauth_request_token_retrieval(self): + """ + Verify that the request token can be retrieved by the server. + """ + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + self.assertEqual(response.status_code, 200) + token = list(Token.objects.all())[-1] + self.failIf(token.key not in response.content) + self.failIf(token.secret not in response.content) + + def test_oauth_user_request_authorization(self): + """ + Verify that the user can access the authorization page once logged in + and the request token has been retrieved. + """ + # Setup + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + token = list(Token.objects.all())[-1] + + # Starting the test here + self.client.login(username=self.username, password=self.password) + parameters = {'oauth_token': token.key} + response = self.client.get("/oauth/authorize/", parameters) + self.assertEqual(response.status_code, 200) + self.failIf(not response.content.startswith('Fake authorize view for api.example.com with params: oauth_token=')) + self.assertEqual(token.is_approved, 0) + parameters['authorize_access'] = 1 # fake authorization by the user + response = self.client.post("/oauth/authorize/", parameters) + self.assertEqual(response.status_code, 302) + self.failIf(not response['Location'].startswith('http://api.example.com/request_token_ready?oauth_verifier=')) + token = Token.objects.get(key=token.key) + self.failIf(token.key not in response['Location']) + self.assertEqual(token.is_approved, 1) + + def _create_access_token_parameters(self, token): + """ + A shortcut to create access' token parameters. + """ + return { + 'oauth_consumer_key': self.CONSUMER_KEY, + 'oauth_token': token.key, + 'oauth_signature_method': 'PLAINTEXT', + 'oauth_signature': '%s&%s' % (self.CONSUMER_SECRET, token.secret), + 'oauth_timestamp': str(int(time.time())), + 'oauth_nonce': 'accessnonce', + 'oauth_version': '1.0', + 'oauth_verifier': token.verifier, + 'scope': 'data', + } + + def test_oauth_access_token_retrieval(self): + """ + Verify that the request token can be retrieved by the server. + """ + # Setup + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + token = list(Token.objects.all())[-1] + self.client.login(username=self.username, password=self.password) + parameters = {'oauth_token': token.key,} + response = self.client.get("/oauth/authorize/", parameters) + parameters['authorize_access'] = 1 # fake authorization by the user + response = self.client.post("/oauth/authorize/", parameters) + token = Token.objects.get(key=token.key) + + # Starting the test here + response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token)) + self.assertEqual(response.status_code, 200) + self.failIf(not response.content.startswith('oauth_token_secret=')) + access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1] + self.failIf(access_token.key not in response.content) + self.failIf(access_token.secret not in response.content) + self.assertEqual(access_token.user.username, 'john') + + def _create_access_parameters(self, access_token): + """ + A shortcut to create access' parameters. + """ + parameters = { + 'oauth_consumer_key': self.CONSUMER_KEY, + 'oauth_token': access_token.key, + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': str(int(time.time())), + 'oauth_nonce': 'accessresourcenonce', + 'oauth_version': '1.0', + } + oauth_request = oauth.Request.from_token_and_callback(access_token, + http_url='http://testserver/', parameters=parameters) + signature_method = oauth.SignatureMethod_HMAC_SHA1() + signature = signature_method.sign(oauth_request, self.consumer, access_token) + parameters['oauth_signature'] = signature + return parameters + + def test_oauth_protected_resource_access(self): + """ + Verify that the request token can be retrieved by the server. + """ + # Setup + response = self.client.get("/oauth/request_token/", + self._create_request_token_parameters()) + token = list(Token.objects.all())[-1] + self.client.login(username=self.username, password=self.password) + parameters = {'oauth_token': token.key,} + response = self.client.get("/oauth/authorize/", parameters) + parameters['authorize_access'] = 1 # fake authorization by the user + response = self.client.post("/oauth/authorize/", parameters) + token = Token.objects.get(key=token.key) + response = self.client.get("/oauth/access_token/", self._create_access_token_parameters(token)) + access_token = list(Token.objects.filter(token_type=Token.ACCESS))[-1] + + # Starting the test here + response = self.client.get("/", self._create_access_token_parameters(access_token)) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, '{"resource": "Protected!"}') diff --git a/rest_framework/tests/package.py b/rest_framework/tests/package.py new file mode 100644 index 00000000..db133b09 --- /dev/null +++ b/rest_framework/tests/package.py @@ -0,0 +1,11 @@ +"""Tests for the rest_framework package setup.""" +from django.test import TestCase +import rest_framework + +class TestVersion(TestCase): + """Simple sanity test to check the VERSION exists""" + + def test_version(self): + """Ensure the VERSION exists.""" + rest_framework.VERSION + diff --git a/rest_framework/tests/parsers.py b/rest_framework/tests/parsers.py new file mode 100644 index 00000000..4cafd660 --- /dev/null +++ b/rest_framework/tests/parsers.py @@ -0,0 +1,212 @@ +# """ +# .. +# >>> from rest_framework.parsers import FormParser +# >>> from rest_framework.compat import RequestFactory +# >>> from rest_framework.views import View +# >>> from StringIO import StringIO +# >>> from urllib import urlencode +# >>> req = RequestFactory().get('/') +# >>> some_view = View() +# >>> some_view.request = req # Make as if this request had been dispatched +# +# FormParser +# ============ +# +# Data flatening +# ---------------- +# +# Here is some example data, which would eventually be sent along with a post request : +# +# >>> inpt = urlencode([ +# ... ('key1', 'bla1'), +# ... ('key2', 'blo1'), ('key2', 'blo2'), +# ... ]) +# +# Default behaviour for :class:`parsers.FormParser`, is to return a single value for each parameter : +# +# >>> (data, files) = FormParser(some_view).parse(StringIO(inpt)) +# >>> data == {'key1': 'bla1', 'key2': 'blo1'} +# True +# +# However, you can customize this behaviour by subclassing :class:`parsers.FormParser`, and overriding :meth:`parsers.FormParser.is_a_list` : +# +# >>> class MyFormParser(FormParser): +# ... +# ... def is_a_list(self, key, val_list): +# ... return len(val_list) > 1 +# +# This new parser only flattens the lists of parameters that contain a single value. +# +# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) +# >>> data == {'key1': 'bla1', 'key2': ['blo1', 'blo2']} +# True +# +# .. note:: The same functionality is available for :class:`parsers.MultiPartParser`. +# +# Submitting an empty list +# -------------------------- +# +# When submitting an empty select multiple, like this one :: +# +# <select multiple="multiple" name="key2"></select> +# +# The browsers usually strip the parameter completely. A hack to avoid this, and therefore being able to submit an empty select multiple, is to submit a value that tells the server that the list is empty :: +# +# <select multiple="multiple" name="key2"><option value="_empty"></select> +# +# :class:`parsers.FormParser` provides the server-side implementation for this hack. Considering the following posted data : +# +# >>> inpt = urlencode([ +# ... ('key1', 'blo1'), ('key1', '_empty'), +# ... ('key2', '_empty'), +# ... ]) +# +# :class:`parsers.FormParser` strips the values ``_empty`` from all the lists. +# +# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) +# >>> data == {'key1': 'blo1'} +# True +# +# Oh ... but wait a second, the parameter ``key2`` isn't even supposed to be a list, so the parser just stripped it. +# +# >>> class MyFormParser(FormParser): +# ... +# ... def is_a_list(self, key, val_list): +# ... return key == 'key2' +# ... +# >>> (data, files) = MyFormParser(some_view).parse(StringIO(inpt)) +# >>> data == {'key1': 'blo1', 'key2': []} +# True +# +# Better like that. Note that you can configure something else than ``_empty`` for the empty value by setting :attr:`parsers.FormParser.EMPTY_VALUE`. +# """ +# import httplib, mimetypes +# from tempfile import TemporaryFile +# from django.test import TestCase +# from rest_framework.compat import RequestFactory +# from rest_framework.parsers import MultiPartParser +# from rest_framework.views import View +# from StringIO import StringIO +# +# def encode_multipart_formdata(fields, files): +# """For testing multipart parser. +# fields is a sequence of (name, value) elements for regular form fields. +# files is a sequence of (name, filename, value) elements for data to be uploaded as files +# Return (content_type, body).""" +# BOUNDARY = '----------ThIs_Is_tHe_bouNdaRY_$' +# CRLF = '\r\n' +# L = [] +# for (key, value) in fields: +# L.append('--' + BOUNDARY) +# L.append('Content-Disposition: form-data; name="%s"' % key) +# L.append('') +# L.append(value) +# for (key, filename, value) in files: +# L.append('--' + BOUNDARY) +# L.append('Content-Disposition: form-data; name="%s"; filename="%s"' % (key, filename)) +# L.append('Content-Type: %s' % get_content_type(filename)) +# L.append('') +# L.append(value) +# L.append('--' + BOUNDARY + '--') +# L.append('') +# body = CRLF.join(L) +# content_type = 'multipart/form-data; boundary=%s' % BOUNDARY +# return content_type, body +# +# def get_content_type(filename): +# return mimetypes.guess_type(filename)[0] or 'application/octet-stream' +# +#class TestMultiPartParser(TestCase): +# def setUp(self): +# self.req = RequestFactory() +# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')], +# [('file1', 'pic.jpg', 'blablabla'), ('file1', 't.txt', 'blobloblo')]) +# +# def test_multipartparser(self): +# """Ensure that MultiPartParser can parse multipart/form-data that contains a mix of several files and parameters.""" +# post_req = RequestFactory().post('/', self.body, content_type=self.content_type) +# view = View() +# view.request = post_req +# (data, files) = MultiPartParser(view).parse(StringIO(self.body)) +# self.assertEqual(data['key1'], 'val1') +# self.assertEqual(files['file1'].read(), 'blablabla') + +from StringIO import StringIO +from django import forms +from django.test import TestCase +from rest_framework.parsers import FormParser +from rest_framework.parsers import XMLParser +import datetime + + +class Form(forms.Form): + field1 = forms.CharField(max_length=3) + field2 = forms.CharField() + + +class TestFormParser(TestCase): + def setUp(self): + self.string = "field1=abc&field2=defghijk" + + def test_parse(self): + """ Make sure the `QueryDict` works OK """ + parser = FormParser() + + stream = StringIO(self.string) + data = parser.parse(stream) + + self.assertEqual(Form(data).is_valid(), True) + + +class TestXMLParser(TestCase): + def setUp(self): + self._input = StringIO( + '<?xml version="1.0" encoding="utf-8"?>' + '<root>' + '<field_a>121.0</field_a>' + '<field_b>dasd</field_b>' + '<field_c></field_c>' + '<field_d>2011-12-25 12:45:00</field_d>' + '</root>' + ) + self._data = { + 'field_a': 121, + 'field_b': 'dasd', + 'field_c': None, + 'field_d': datetime.datetime(2011, 12, 25, 12, 45, 00) + } + self._complex_data_input = StringIO( + '<?xml version="1.0" encoding="utf-8"?>' + '<root>' + '<creation_date>2011-12-25 12:45:00</creation_date>' + '<sub_data_list>' + '<list-item><sub_id>1</sub_id><sub_name>first</sub_name></list-item>' + '<list-item><sub_id>2</sub_id><sub_name>second</sub_name></list-item>' + '</sub_data_list>' + '<name>name</name>' + '</root>' + ) + self._complex_data = { + "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), + "name": "name", + "sub_data_list": [ + { + "sub_id": 1, + "sub_name": "first" + }, + { + "sub_id": 2, + "sub_name": "second" + } + ] + } + + def test_parse(self): + parser = XMLParser() + data = parser.parse(self._input) + self.assertEqual(data, self._data) + + def test_complex_data_parse(self): + parser = XMLParser() + data = parser.parse(self._complex_data_input) + self.assertEqual(data, self._complex_data) diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py new file mode 100644 index 00000000..06954412 --- /dev/null +++ b/rest_framework/tests/renderers.py @@ -0,0 +1,375 @@ +import re + +from django.conf.urls.defaults import patterns, url, include +from django.test import TestCase + +from rest_framework import status +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ + XMLRenderer, JSONPRenderer, DocumentingHTMLRenderer +from rest_framework.parsers import YAMLParser, XMLParser + +from StringIO import StringIO +import datetime +from decimal import Decimal + + +DUMMYSTATUS = status.HTTP_200_OK +DUMMYCONTENT = 'dummycontent' + +RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x +RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x + + +expected_results = [ + ((elem for elem in [1, 2, 3]), JSONRenderer, '[1, 2, 3]') # Generator +] + + +class BasicRendererTests(TestCase): + def test_expected_results(self): + for value, renderer_cls, expected in expected_results: + output = renderer_cls().render(value) + self.assertEquals(output, expected) + + +class RendererA(BaseRenderer): + media_type = 'mock/renderera' + format = "formata" + + def render(self, obj=None, media_type=None): + return RENDERER_A_SERIALIZER(obj) + + +class RendererB(BaseRenderer): + media_type = 'mock/rendererb' + format = "formatb" + + def render(self, obj=None, media_type=None): + return RENDERER_B_SERIALIZER(obj) + + +class MockView(APIView): + renderer_classes = (RendererA, RendererB) + + def get(self, request, **kwargs): + response = Response(DUMMYCONTENT, status=DUMMYSTATUS) + return response + + +class MockGETView(APIView): + + def get(self, request, **kwargs): + return Response({'foo': ['bar', 'baz']}) + + +class HTMLView(APIView): + renderer_classes = (DocumentingHTMLRenderer, ) + + def get(self, request, **kwargs): + return Response('text') + + +class HTMLView1(APIView): + renderer_classes = (DocumentingHTMLRenderer, JSONRenderer) + + def get(self, request, **kwargs): + return Response('text') + +urlpatterns = patterns('', + url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), + url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), + url(r'^html$', HTMLView.as_view()), + url(r'^html1$', HTMLView1.as_view()), + url(r'^api', include('rest_framework.urls', namespace='rest_framework')) +) + + +class RendererEndToEndTests(TestCase): + """ + End-to-end testing of renderers using an RendererMixin on a generic view. + """ + + urls = 'rest_framework.tests.renderers' + + def test_default_renderer_serializes_content(self): + """If the Accept header is not set the default renderer should serialize the response.""" + resp = self.client.get('/') + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_head_method_serializes_no_content(self): + """No response must be included in HEAD requests.""" + resp = self.client.head('/') + self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, '') + + def test_default_renderer_serializes_content_on_accept_any(self): + """If the Accept header is set to */* the default renderer should serialize the response.""" + resp = self.client.get('/', HTTP_ACCEPT='*/*') + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_default_case(self): + """If the Accept header is set the specified renderer should serialize the response. + (In this case we check that works for the default renderer)""" + resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_non_default_case(self): + """If the Accept header is set the specified renderer should serialize the response. + (In this case we check that works for a non-default renderer)""" + resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_on_accept_query(self): + """The '_accept' query string should behave in the same way as the Accept header.""" + resp = self.client.get('/?_accept=%s' % RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_unsatisfiable_accept_header_on_request_returns_406_status(self): + """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" + resp = self.client.get('/', HTTP_ACCEPT='foo/bar') + self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_specified_renderer_serializes_content_on_format_query(self): + """If a 'format' query is specified, the renderer with the matching + format attribute should serialize the response.""" + resp = self.client.get('/?format=%s' % RendererB.format) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_on_format_kwargs(self): + """If a 'format' keyword arg is specified, the renderer with the matching + format attribute should serialize the response.""" + resp = self.client.get('/something.formatb') + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): + """If both a 'format' query and a matching Accept header specified, + the renderer with the matching format attribute should serialize the response.""" + resp = self.client.get('/?format=%s' % RendererB.format, + HTTP_ACCEPT=RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + +_flat_repr = '{"foo": ["bar", "baz"]}' +_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}' + + +def strip_trailing_whitespace(content): + """ + Seems to be some inconsistencies re. trailing whitespace with + different versions of the json lib. + """ + return re.sub(' +\n', '\n', content) + + +class JSONRendererTests(TestCase): + """ + Tests specific to the JSON Renderer + """ + + def test_without_content_type_args(self): + """ + Test basic JSON rendering. + """ + obj = {'foo': ['bar', 'baz']} + renderer = JSONRenderer(None) + content = renderer.render(obj, 'application/json') + # Fix failing test case which depends on version of JSON library. + self.assertEquals(content, _flat_repr) + + def test_with_content_type_args(self): + """ + Test JSON rendering with additional content type arguments supplied. + """ + obj = {'foo': ['bar', 'baz']} + renderer = JSONRenderer(None) + content = renderer.render(obj, 'application/json; indent=2') + self.assertEquals(strip_trailing_whitespace(content), _indented_repr) + + +class JSONPRendererTests(TestCase): + """ + Tests specific to the JSONP Renderer + """ + + urls = 'rest_framework.tests.renderers' + + def test_without_callback_with_json_renderer(self): + """ + Test JSONP rendering with View JSON Renderer. + """ + resp = self.client.get('/jsonp/jsonrenderer', + HTTP_ACCEPT='application/javascript') + self.assertEquals(resp.status_code, 200) + self.assertEquals(resp['Content-Type'], 'application/javascript') + self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) + + def test_without_callback_without_json_renderer(self): + """ + Test JSONP rendering without View JSON Renderer. + """ + resp = self.client.get('/jsonp/nojsonrenderer', + HTTP_ACCEPT='application/javascript') + self.assertEquals(resp.status_code, 200) + self.assertEquals(resp['Content-Type'], 'application/javascript') + self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) + + def test_with_callback(self): + """ + Test JSONP rendering with callback function name. + """ + callback_func = 'myjsonpcallback' + resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, + HTTP_ACCEPT='application/javascript') + self.assertEquals(resp.status_code, 200) + self.assertEquals(resp['Content-Type'], 'application/javascript') + self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) + + +if YAMLRenderer: + _yaml_repr = 'foo: [bar, baz]\n' + + class YAMLRendererTests(TestCase): + """ + Tests specific to the JSON Renderer + """ + + def test_render(self): + """ + Test basic YAML rendering. + """ + obj = {'foo': ['bar', 'baz']} + renderer = YAMLRenderer(None) + content = renderer.render(obj, 'application/yaml') + self.assertEquals(content, _yaml_repr) + + def test_render_and_parse(self): + """ + Test rendering and then parsing returns the original object. + IE obj -> render -> parse -> obj. + """ + obj = {'foo': ['bar', 'baz']} + + renderer = YAMLRenderer(None) + parser = YAMLParser() + + content = renderer.render(obj, 'application/yaml') + data = parser.parse(StringIO(content)) + self.assertEquals(obj, data) + + +class XMLRendererTestCase(TestCase): + """ + Tests specific to the XML Renderer + """ + + _complex_data = { + "creation_date": datetime.datetime(2011, 12, 25, 12, 45, 00), + "name": "name", + "sub_data_list": [ + { + "sub_id": 1, + "sub_name": "first" + }, + { + "sub_id": 2, + "sub_name": "second" + } + ] + } + + def test_render_string(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 'astring'}, 'application/xml') + self.assertXMLContains(content, '<field>astring</field>') + + def test_render_integer(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 111}, 'application/xml') + self.assertXMLContains(content, '<field>111</field>') + + def test_render_datetime(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({ + 'field': datetime.datetime(2011, 12, 25, 12, 45, 00) + }, 'application/xml') + self.assertXMLContains(content, '<field>2011-12-25 12:45:00</field>') + + def test_render_float(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': 123.4}, 'application/xml') + self.assertXMLContains(content, '<field>123.4</field>') + + def test_render_decimal(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': Decimal('111.2')}, 'application/xml') + self.assertXMLContains(content, '<field>111.2</field>') + + def test_render_none(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render({'field': None}, 'application/xml') + self.assertXMLContains(content, '<field></field>') + + def test_render_complex_data(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = renderer.render(self._complex_data, 'application/xml') + self.assertXMLContains(content, '<sub_name>first</sub_name>') + self.assertXMLContains(content, '<sub_name>second</sub_name>') + + def test_render_and_parse_complex_data(self): + """ + Test XML rendering. + """ + renderer = XMLRenderer(None) + content = StringIO(renderer.render(self._complex_data, 'application/xml')) + + parser = XMLParser() + complex_data_out = parser.parse(content) + error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) + self.assertEqual(self._complex_data, complex_data_out, error_msg) + + def assertXMLContains(self, xml, string): + self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) + self.assertTrue(xml.endswith('</root>')) + self.assertTrue(string in xml, '%r not in %r' % (string, xml)) diff --git a/rest_framework/tests/request.py b/rest_framework/tests/request.py new file mode 100644 index 00000000..805f6efc --- /dev/null +++ b/rest_framework/tests/request.py @@ -0,0 +1,252 @@ +""" +Tests for content parsing, and form-overloaded content parsing. +""" +from django.conf.urls.defaults import patterns +from django.contrib.auth.models import User +from django.test import TestCase, Client + +from rest_framework import status +from rest_framework.authentication import SessionAuthentication +from rest_framework.compat import RequestFactory +from rest_framework.parsers import ( + FormParser, + MultiPartParser, + PlainTextParser, +) +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.views import APIView + + +factory = RequestFactory() + + +class TestMethodOverloading(TestCase): + def test_method(self): + """ + Request methods should be same as underlying request. + """ + request = Request(factory.get('/')) + self.assertEqual(request.method, 'GET') + request = Request(factory.post('/')) + self.assertEqual(request.method, 'POST') + + def test_overloaded_method(self): + """ + POST requests can be overloaded to another method by setting a + reserved form field + """ + request = Request(factory.post('/', {Request._METHOD_PARAM: 'DELETE'})) + self.assertEqual(request.method, 'DELETE') + + +class TestContentParsing(TestCase): + def test_standard_behaviour_determines_no_content_GET(self): + """ + Ensure request.DATA returns None for GET request with no content. + """ + request = Request(factory.get('/')) + self.assertEqual(request.DATA, None) + + def test_standard_behaviour_determines_no_content_HEAD(self): + """ + Ensure request.DATA returns None for HEAD request. + """ + request = Request(factory.head('/')) + self.assertEqual(request.DATA, None) + + def test_standard_behaviour_determines_form_content_POST(self): + """ + Ensure request.DATA returns content for POST request with form content. + """ + data = {'qwerty': 'uiop'} + request = Request(factory.post('/', data)) + request.parser_classes = (FormParser, MultiPartParser) + self.assertEqual(request.DATA.items(), data.items()) + + def test_standard_behaviour_determines_non_form_content_POST(self): + """ + Ensure request.DATA returns content for POST request with + non-form content. + """ + content = 'qwerty' + content_type = 'text/plain' + request = Request(factory.post('/', content, content_type=content_type)) + request.parser_classes = (PlainTextParser,) + self.assertEqual(request.DATA, content) + + def test_standard_behaviour_determines_form_content_PUT(self): + """ + Ensure request.DATA returns content for PUT request with form content. + """ + data = {'qwerty': 'uiop'} + + from django import VERSION + + if VERSION >= (1, 5): + from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart + request = Request(factory.put('/', encode_multipart(BOUNDARY, data), + content_type=MULTIPART_CONTENT)) + else: + request = Request(factory.put('/', data)) + + request.parser_classes = (FormParser, MultiPartParser) + self.assertEqual(request.DATA.items(), data.items()) + + def test_standard_behaviour_determines_non_form_content_PUT(self): + """ + Ensure request.DATA returns content for PUT request with + non-form content. + """ + content = 'qwerty' + content_type = 'text/plain' + request = Request(factory.put('/', content, content_type=content_type)) + request.parser_classes = (PlainTextParser, ) + self.assertEqual(request.DATA, content) + + def test_overloaded_behaviour_allows_content_tunnelling(self): + """ + Ensure request.DATA returns content for overloaded POST request. + """ + content = 'qwerty' + content_type = 'text/plain' + data = { + Request._CONTENT_PARAM: content, + Request._CONTENTTYPE_PARAM: content_type + } + request = Request(factory.post('/', data)) + request.parser_classes = (PlainTextParser, ) + self.assertEqual(request.DATA, content) + + # def test_accessing_post_after_data_form(self): + # """ + # Ensures request.POST can be accessed after request.DATA in + # form request. + # """ + # data = {'qwerty': 'uiop'} + # request = factory.post('/', data=data) + # self.assertEqual(request.DATA.items(), data.items()) + # self.assertEqual(request.POST.items(), data.items()) + + # def test_accessing_post_after_data_for_json(self): + # """ + # Ensures request.POST can be accessed after request.DATA in + # json request. + # """ + # data = {'qwerty': 'uiop'} + # content = json.dumps(data) + # content_type = 'application/json' + # parsers = (JSONParser, ) + + # request = factory.post('/', content, content_type=content_type, + # parsers=parsers) + # self.assertEqual(request.DATA.items(), data.items()) + # self.assertEqual(request.POST.items(), []) + + # def test_accessing_post_after_data_for_overloaded_json(self): + # """ + # Ensures request.POST can be accessed after request.DATA in overloaded + # json request. + # """ + # data = {'qwerty': 'uiop'} + # content = json.dumps(data) + # content_type = 'application/json' + # parsers = (JSONParser, ) + # form_data = {Request._CONTENT_PARAM: content, + # Request._CONTENTTYPE_PARAM: content_type} + + # request = factory.post('/', form_data, parsers=parsers) + # self.assertEqual(request.DATA.items(), data.items()) + # self.assertEqual(request.POST.items(), form_data.items()) + + # def test_accessing_data_after_post_form(self): + # """ + # Ensures request.DATA can be accessed after request.POST in + # form request. + # """ + # data = {'qwerty': 'uiop'} + # parsers = (FormParser, MultiPartParser) + # request = factory.post('/', data, parsers=parsers) + + # self.assertEqual(request.POST.items(), data.items()) + # self.assertEqual(request.DATA.items(), data.items()) + + # def test_accessing_data_after_post_for_json(self): + # """ + # Ensures request.DATA can be accessed after request.POST in + # json request. + # """ + # data = {'qwerty': 'uiop'} + # content = json.dumps(data) + # content_type = 'application/json' + # parsers = (JSONParser, ) + # request = factory.post('/', content, content_type=content_type, + # parsers=parsers) + # self.assertEqual(request.POST.items(), []) + # self.assertEqual(request.DATA.items(), data.items()) + + # def test_accessing_data_after_post_for_overloaded_json(self): + # """ + # Ensures request.DATA can be accessed after request.POST in overloaded + # json request + # """ + # data = {'qwerty': 'uiop'} + # content = json.dumps(data) + # content_type = 'application/json' + # parsers = (JSONParser, ) + # form_data = {Request._CONTENT_PARAM: content, + # Request._CONTENTTYPE_PARAM: content_type} + + # request = factory.post('/', form_data, parsers=parsers) + # self.assertEqual(request.POST.items(), form_data.items()) + # self.assertEqual(request.DATA.items(), data.items()) + + +class MockView(APIView): + authentication_classes = (SessionAuthentication,) + + def post(self, request): + if request.POST.get('example') is not None: + return Response(status=status.HTTP_200_OK) + + return Response(status=status.INTERNAL_SERVER_ERROR) + +urlpatterns = patterns('', + (r'^$', MockView.as_view()), +) + + +class TestContentParsingWithAuthentication(TestCase): + urls = 'rest_framework.tests.request' + + def setUp(self): + self.csrf_client = Client(enforce_csrf_checks=True) + self.username = 'john' + self.email = 'lennon@thebeatles.com' + self.password = 'password' + self.user = User.objects.create_user(self.username, self.email, self.password) + + def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): + """ + Ensures request.POST exists after SessionAuthentication when user + doesn't log in. + """ + content = {'example': 'example'} + + response = self.client.post('/', content) + self.assertEqual(status.HTTP_200_OK, response.status_code) + + response = self.csrf_client.post('/', content) + self.assertEqual(status.HTTP_200_OK, response.status_code) + + # def test_user_logged_in_authentication_has_post_when_logged_in(self): + # """Ensures request.POST exists after UserLoggedInAuthentication when user does log in""" + # self.client.login(username='john', password='password') + # self.csrf_client.login(username='john', password='password') + # content = {'example': 'example'} + + # response = self.client.post('/', content) + # self.assertEqual(status.OK, response.status_code, "POST data is malformed") + + # response = self.csrf_client.post('/', content) + # self.assertEqual(status.OK, response.status_code, "POST data is malformed") diff --git a/rest_framework/tests/response.py b/rest_framework/tests/response.py new file mode 100644 index 00000000..af70a387 --- /dev/null +++ b/rest_framework/tests/response.py @@ -0,0 +1,177 @@ +import unittest + +from django.conf.urls.defaults import patterns, url, include +from django.test import TestCase + +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework import status +from rest_framework.renderers import ( + BaseRenderer, + JSONRenderer, + DocumentingHTMLRenderer +) + + +class MockPickleRenderer(BaseRenderer): + media_type = 'application/pickle' + + +class MockJsonRenderer(BaseRenderer): + media_type = 'application/json' + + +DUMMYSTATUS = status.HTTP_200_OK +DUMMYCONTENT = 'dummycontent' + +RENDERER_A_SERIALIZER = lambda x: 'Renderer A: %s' % x +RENDERER_B_SERIALIZER = lambda x: 'Renderer B: %s' % x + + +class RendererA(BaseRenderer): + media_type = 'mock/renderera' + format = "formata" + + def render(self, obj=None, media_type=None): + return RENDERER_A_SERIALIZER(obj) + + +class RendererB(BaseRenderer): + media_type = 'mock/rendererb' + format = "formatb" + + def render(self, obj=None, media_type=None): + return RENDERER_B_SERIALIZER(obj) + + +class MockView(APIView): + renderer_classes = (RendererA, RendererB) + + def get(self, request, **kwargs): + return Response(DUMMYCONTENT, status=DUMMYSTATUS) + + +class HTMLView(APIView): + renderer_classes = (DocumentingHTMLRenderer, ) + + def get(self, request, **kwargs): + return Response('text') + + +class HTMLView1(APIView): + renderer_classes = (DocumentingHTMLRenderer, JSONRenderer) + + def get(self, request, **kwargs): + return Response('text') + + +urlpatterns = patterns('', + url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), + url(r'^html$', HTMLView.as_view()), + url(r'^html1$', HTMLView1.as_view()), + url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) +) + + +# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... +class RendererIntegrationTests(TestCase): + """ + End-to-end testing of renderers using an ResponseMixin on a generic view. + """ + + urls = 'rest_framework.tests.response' + + def test_default_renderer_serializes_content(self): + """If the Accept header is not set the default renderer should serialize the response.""" + resp = self.client.get('/') + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_head_method_serializes_no_content(self): + """No response must be included in HEAD requests.""" + resp = self.client.head('/') + self.assertEquals(resp.status_code, DUMMYSTATUS) + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, '') + + def test_default_renderer_serializes_content_on_accept_any(self): + """If the Accept header is set to */* the default renderer should serialize the response.""" + resp = self.client.get('/', HTTP_ACCEPT='*/*') + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_default_case(self): + """If the Accept header is set the specified renderer should serialize the response. + (In this case we check that works for the default renderer)""" + resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) + self.assertEquals(resp['Content-Type'], RendererA.media_type) + self.assertEquals(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_non_default_case(self): + """If the Accept header is set the specified renderer should serialize the response. + (In this case we check that works for a non-default renderer)""" + resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_on_accept_query(self): + """The '_accept' query string should behave in the same way as the Accept header.""" + resp = self.client.get('/?_accept=%s' % RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + @unittest.skip('can\'t pass because view is a simple Django view and response is an ImmediateResponse') + def test_unsatisfiable_accept_header_on_request_returns_406_status(self): + """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" + resp = self.client.get('/', HTTP_ACCEPT='foo/bar') + self.assertEquals(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) + + def test_specified_renderer_serializes_content_on_format_query(self): + """If a 'format' query is specified, the renderer with the matching + format attribute should serialize the response.""" + resp = self.client.get('/?format=%s' % RendererB.format) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_serializes_content_on_format_kwargs(self): + """If a 'format' keyword arg is specified, the renderer with the matching + format attribute should serialize the response.""" + resp = self.client.get('/something.formatb') + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): + """If both a 'format' query and a matching Accept header specified, + the renderer with the matching format attribute should serialize the response.""" + resp = self.client.get('/?format=%s' % RendererB.format, + HTTP_ACCEPT=RendererB.media_type) + self.assertEquals(resp['Content-Type'], RendererB.media_type) + self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) + self.assertEquals(resp.status_code, DUMMYSTATUS) + + +class Issue122Tests(TestCase): + """ + Tests that covers #122. + """ + urls = 'rest_framework.tests.response' + + def test_only_html_renderer(self): + """ + Test if no infinite recursion occurs. + """ + self.client.get('/html') + + def test_html_renderer_is_first(self): + """ + Test if no infinite recursion occurs. + """ + self.client.get('/html1') diff --git a/rest_framework/tests/reverse.py b/rest_framework/tests/reverse.py new file mode 100644 index 00000000..4027e42f --- /dev/null +++ b/rest_framework/tests/reverse.py @@ -0,0 +1,35 @@ +from django.conf.urls.defaults import patterns, url +from django.test import TestCase +from django.utils import simplejson as json + +from rest_framework.renderers import JSONRenderer +from rest_framework.reverse import reverse +from rest_framework.views import APIView +from rest_framework.response import Response + + +class MyView(APIView): + """ + Mock resource which simply returns a URL, so that we can ensure + that reversed URLs are fully qualified. + """ + renderers = (JSONRenderer, ) + + def get(self, request): + return Response(reverse('myview', request=request)) + + +urlpatterns = patterns('', + url(r'^myview$', MyView.as_view(), name='myview'), +) + + +class ReverseTests(TestCase): + """ + Tests for fully qualifed URLs when using `reverse`. + """ + urls = 'rest_framework.tests.reverse' + + def test_reversed_urls_are_fully_qualified(self): + response = self.client.get('/myview') + self.assertEqual(json.loads(response.content), 'http://testserver/myview') diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py new file mode 100644 index 00000000..16de2c66 --- /dev/null +++ b/rest_framework/tests/serializer.py @@ -0,0 +1,117 @@ +import datetime +from django.test import TestCase +from rest_framework import serializers + + +class Comment(object): + def __init__(self, email, content, created): + self.email = email + self.content = content + self.created = created or datetime.datetime.now() + + def __eq__(self, other): + return all([getattr(self, attr) == getattr(other, attr) + for attr in ('email', 'content', 'created')]) + + +class CommentSerializer(serializers.Serializer): + email = serializers.EmailField() + content = serializers.CharField(max_length=1000) + created = serializers.DateTimeField() + + def restore_object(self, data, instance=None): + if instance is None: + return Comment(**data) + for key, val in data.items(): + setattr(instance, key, val) + return instance + + +class BasicTests(TestCase): + def setUp(self): + self.comment = Comment( + 'tom@example.com', + 'Happy new year!', + datetime.datetime(2012, 1, 1) + ) + self.data = { + 'email': 'tom@example.com', + 'content': 'Happy new year!', + 'created': datetime.datetime(2012, 1, 1) + } + + def test_empty(self): + serializer = CommentSerializer() + expected = { + 'email': '', + 'content': '', + 'created': None + } + self.assertEquals(serializer.data, expected) + + def test_serialization(self): + serializer = CommentSerializer(instance=self.comment) + expected = self.data + self.assertEquals(serializer.data, expected) + + def test_deserialization_for_create(self): + serializer = CommentSerializer(self.data) + expected = self.comment + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.object, expected) + self.assertFalse(serializer.object is expected) + + def test_deserialization_for_update(self): + serializer = CommentSerializer(self.data, instance=self.comment) + expected = self.comment + self.assertEquals(serializer.is_valid(), True) + self.assertEquals(serializer.object, expected) + self.assertTrue(serializer.object is expected) + + +class ValidationTests(TestCase): + def setUp(self): + self.comment = Comment( + 'tom@example.com', + 'Happy new year!', + datetime.datetime(2012, 1, 1) + ) + self.data = { + 'email': 'tom@example.com', + 'content': 'x' * 1001, + 'created': datetime.datetime(2012, 1, 1) + } + + def test_deserialization_for_create(self): + serializer = CommentSerializer(self.data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) + + def test_deserialization_for_update(self): + serializer = CommentSerializer(self.data, instance=self.comment) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) + + +class MetadataTests(TestCase): + # def setUp(self): + # self.comment = Comment( + # 'tomchristie', + # 'Happy new year!', + # datetime.datetime(2012, 1, 1) + # ) + # self.data = { + # 'email': 'tomchristie', + # 'content': 'Happy new year!', + # 'created': datetime.datetime(2012, 1, 1) + # } + + def test_empty(self): + serializer = CommentSerializer() + expected = { + 'email': serializers.CharField, + 'content': serializers.CharField, + 'created': serializers.DateTimeField + } + for field_name, field in expected.items(): + self.assertTrue(isinstance(serializer.data.fields[field_name], field)) diff --git a/rest_framework/tests/status.py b/rest_framework/tests/status.py new file mode 100644 index 00000000..30df5cef --- /dev/null +++ b/rest_framework/tests/status.py @@ -0,0 +1,12 @@ +"""Tests for the status module""" +from django.test import TestCase +from rest_framework import status + + +class TestStatus(TestCase): + """Simple sanity test to check the status module""" + + def test_status(self): + """Ensure the status module is present and correct.""" + self.assertEquals(200, status.HTTP_200_OK) + self.assertEquals(404, status.HTTP_404_NOT_FOUND) diff --git a/rest_framework/tests/testcases.py b/rest_framework/tests/testcases.py new file mode 100644 index 00000000..c90224aa --- /dev/null +++ b/rest_framework/tests/testcases.py @@ -0,0 +1,63 @@ +# http://djangosnippets.org/snippets/1011/ +from django.conf import settings +from django.core.management import call_command +from django.db.models import loading +from django.test import TestCase + +NO_SETTING = ('!', None) + +class TestSettingsManager(object): + """ + A class which can modify some Django settings temporarily for a + test and then revert them to their original values later. + + Automatically handles resyncing the DB if INSTALLED_APPS is + modified. + + """ + def __init__(self): + self._original_settings = {} + + def set(self, **kwargs): + for k,v in kwargs.iteritems(): + self._original_settings.setdefault(k, getattr(settings, k, + NO_SETTING)) + setattr(settings, k, v) + if 'INSTALLED_APPS' in kwargs: + self.syncdb() + + def syncdb(self): + loading.cache.loaded = False + call_command('syncdb', verbosity=0) + + def revert(self): + for k,v in self._original_settings.iteritems(): + if v == NO_SETTING: + delattr(settings, k) + else: + setattr(settings, k, v) + if 'INSTALLED_APPS' in self._original_settings: + self.syncdb() + self._original_settings = {} + + +class SettingsTestCase(TestCase): + """ + A subclass of the Django TestCase with a settings_manager + attribute which is an instance of TestSettingsManager. + + Comes with a tearDown() method that calls + self.settings_manager.revert(). + + """ + def __init__(self, *args, **kwargs): + super(SettingsTestCase, self).__init__(*args, **kwargs) + self.settings_manager = TestSettingsManager() + + def tearDown(self): + self.settings_manager.revert() + +class TestModelsTestCase(SettingsTestCase): + def setUp(self, *args, **kwargs): + installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) + self.settings_manager.set(INSTALLED_APPS=installed_apps) diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py new file mode 100644 index 00000000..0058a28e --- /dev/null +++ b/rest_framework/tests/throttling.py @@ -0,0 +1,144 @@ +""" +Tests for the throttling implementations in the permissions module. +""" + +from django.test import TestCase +from django.contrib.auth.models import User +from django.core.cache import cache + +from rest_framework.compat import RequestFactory +from rest_framework.views import APIView +from rest_framework.throttling import UserRateThrottle +from rest_framework.response import Response + + +class User3SecRateThrottle(UserRateThrottle): + rate = '3/sec' + scope = 'seconds' + + +class User3MinRateThrottle(UserRateThrottle): + rate = '3/min' + scope = 'minutes' + + +class MockView(APIView): + throttle_classes = (User3SecRateThrottle,) + + def get(self, request): + return Response('foo') + + +class MockView_MinuteThrottling(APIView): + throttle_classes = (User3MinRateThrottle,) + + def get(self, request): + return Response('foo') + + +class ThrottlingTests(TestCase): + urls = 'rest_framework.tests.throttling' + + def setUp(self): + """ + Reset the cache so that no throttles will be active + """ + cache.clear() + self.factory = RequestFactory() + + def test_requests_are_throttled(self): + """ + Ensure request rate is limited + """ + request = self.factory.get('/') + for dummy in range(4): + response = MockView.as_view()(request) + self.assertEqual(429, response.status_code) + + def set_throttle_timer(self, view, value): + """ + Explicitly set the timer, overriding time.time() + """ + view.throttle_classes[0].timer = lambda self: value + + def test_request_throttling_expires(self): + """ + Ensure request rate is limited for a limited duration only + """ + self.set_throttle_timer(MockView, 0) + + request = self.factory.get('/') + for dummy in range(4): + response = MockView.as_view()(request) + self.assertEqual(429, response.status_code) + + # Advance the timer by one second + self.set_throttle_timer(MockView, 1) + + response = MockView.as_view()(request) + self.assertEqual(200, response.status_code) + + def ensure_is_throttled(self, view, expect): + request = self.factory.get('/') + request.user = User.objects.create(username='a') + for dummy in range(3): + view.as_view()(request) + request.user = User.objects.create(username='b') + response = view.as_view()(request) + self.assertEqual(expect, response.status_code) + + def test_request_throttling_is_per_user(self): + """ + Ensure request rate is only limited per user, not globally for + PerUserThrottles + """ + self.ensure_is_throttled(MockView, 200) + + def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers): + """ + Ensure the response returns an X-Throttle field with status and next attributes + set properly. + """ + request = self.factory.get('/') + for timer, expect in expected_headers: + self.set_throttle_timer(view, timer) + response = view.as_view()(request) + if expect is not None: + self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) + else: + self.assertFalse('X-Throttle-Wait-Seconds' in response.headers) + + def test_seconds_fields(self): + """ + Ensure for second based throttles. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView, + ((0, None), + (0, None), + (0, None), + (0, '1') + )) + + def test_minutes_fields(self): + """ + Ensure for minute based throttles. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, + ((0, None), + (0, None), + (0, None), + (0, '60') + )) + + def test_next_rate_remains_constant_if_followed(self): + """ + If a client follows the recommended next request rate, + the throttling rate should stay constant. + """ + self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, + ((0, None), + (20, None), + (40, None), + (60, None), + (80, None) + )) diff --git a/rest_framework/tests/validators.py b/rest_framework/tests/validators.py new file mode 100644 index 00000000..b390c42f --- /dev/null +++ b/rest_framework/tests/validators.py @@ -0,0 +1,329 @@ +# from django import forms +# from django.db import models +# from django.test import TestCase +# from rest_framework.response import ImmediateResponse +# from rest_framework.views import View + + +# class TestDisabledValidations(TestCase): +# """Tests on FormValidator with validation disabled by setting form to None""" + +# def test_disabled_form_validator_returns_content_unchanged(self): +# """If the view's form attribute is None then FormValidator(view).validate_request(content, None) +# should just return the content unmodified.""" +# class DisabledFormResource(FormResource): +# form = None + +# class MockView(View): +# resource = DisabledFormResource + +# view = MockView() +# content = {'qwerty': 'uiop'} +# self.assertEqual(FormResource(view).validate_request(content, None), content) + +# def test_disabled_form_validator_get_bound_form_returns_none(self): +# """If the view's form attribute is None on then +# FormValidator(view).get_bound_form(content) should just return None.""" +# class DisabledFormResource(FormResource): +# form = None + +# class MockView(View): +# resource = DisabledFormResource + +# view = MockView() +# content = {'qwerty': 'uiop'} +# self.assertEqual(FormResource(view).get_bound_form(content), None) + +# def test_disabled_model_form_validator_returns_content_unchanged(self): +# """If the view's form is None and does not have a Resource with a model set then +# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified.""" + +# class DisabledModelFormView(View): +# resource = ModelResource + +# view = DisabledModelFormView() +# content = {'qwerty': 'uiop'} +# self.assertEqual(ModelResource(view).get_bound_form(content), None) + +# def test_disabled_model_form_validator_get_bound_form_returns_none(self): +# """If the form attribute is None on FormValidatorMixin then get_bound_form(content) should just return None.""" +# class DisabledModelFormView(View): +# resource = ModelResource + +# view = DisabledModelFormView() +# content = {'qwerty': 'uiop'} +# self.assertEqual(ModelResource(view).get_bound_form(content), None) + + +# class TestNonFieldErrors(TestCase): +# """Tests against form validation errors caused by non-field errors. (eg as might be caused by some custom form validation)""" + +# def test_validate_failed_due_to_non_field_error_returns_appropriate_message(self): +# """If validation fails with a non-field error, ensure the response a non-field error""" +# class MockForm(forms.Form): +# field1 = forms.CharField(required=False) +# field2 = forms.CharField(required=False) +# ERROR_TEXT = 'You may not supply both field1 and field2' + +# def clean(self): +# if 'field1' in self.cleaned_data and 'field2' in self.cleaned_data: +# raise forms.ValidationError(self.ERROR_TEXT) +# return self.cleaned_data + +# class MockResource(FormResource): +# form = MockForm + +# class MockView(View): +# pass + +# view = MockView() +# content = {'field1': 'example1', 'field2': 'example2'} +# try: +# MockResource(view).validate_request(content, None) +# except ImmediateResponse, exc: +# response = exc.response +# self.assertEqual(response.raw_content, {'errors': [MockForm.ERROR_TEXT]}) +# else: +# self.fail('ImmediateResponse was not raised') + + +# class TestFormValidation(TestCase): +# """Tests which check basic form validation. +# Also includes the same set of tests with a ModelFormValidator for which the form has been explicitly set. +# (ModelFormValidator should behave as FormValidator if a form is set rather than relying on the default ModelForm)""" +# def setUp(self): +# class MockForm(forms.Form): +# qwerty = forms.CharField(required=True) + +# class MockFormResource(FormResource): +# form = MockForm + +# class MockModelResource(ModelResource): +# form = MockForm + +# class MockFormView(View): +# resource = MockFormResource + +# class MockModelFormView(View): +# resource = MockModelResource + +# self.MockFormResource = MockFormResource +# self.MockModelResource = MockModelResource +# self.MockFormView = MockFormView +# self.MockModelFormView = MockModelFormView + +# def validation_returns_content_unchanged_if_already_valid_and_clean(self, validator): +# """If the content is already valid and clean then validate(content) should just return the content unmodified.""" +# content = {'qwerty': 'uiop'} +# self.assertEqual(validator.validate_request(content, None), content) + +# def validation_failure_raises_response_exception(self, validator): +# """If form validation fails a ResourceException 400 (Bad Request) should be raised.""" +# content = {} +# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) + +# def validation_does_not_allow_extra_fields_by_default(self, validator): +# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. +# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up +# broken clients more easily (eg submitting content with a misnamed field)""" +# content = {'qwerty': 'uiop', 'extra': 'extra'} +# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) + +# def validation_allows_extra_fields_if_explicitly_set(self, validator): +# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names.""" +# content = {'qwerty': 'uiop', 'extra': 'extra'} +# validator._validate(content, None, allowed_extra_fields=('extra',)) + +# def validation_allows_unknown_fields_if_explicitly_allowed(self, validator): +# """If we set ``unknown_form_fields`` on the form resource, then don't +# raise errors on unexpected request data""" +# content = {'qwerty': 'uiop', 'extra': 'extra'} +# validator.allow_unknown_form_fields = True +# self.assertEqual({'qwerty': u'uiop'}, +# validator.validate_request(content, None), +# "Resource didn't accept unknown fields.") +# validator.allow_unknown_form_fields = False + +# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator): +# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names.""" +# content = {'qwerty': 'uiop'} +# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content) + +# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator): +# """If validation fails due to no content, ensure the response contains a single non-field error""" +# content = {} +# try: +# validator.validate_request(content, None) +# except ImmediateResponse, exc: +# response = exc.response +# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}}) +# else: +# self.fail('ResourceException was not raised') + +# def validation_failed_due_to_field_error_returns_appropriate_message(self, validator): +# """If validation fails due to a field error, ensure the response contains a single field error""" +# content = {'qwerty': ''} +# try: +# validator.validate_request(content, None) +# except ImmediateResponse, exc: +# response = exc.response +# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.']}}) +# else: +# self.fail('ResourceException was not raised') + +# def validation_failed_due_to_invalid_field_returns_appropriate_message(self, validator): +# """If validation fails due to an invalid field, ensure the response contains a single field error""" +# content = {'qwerty': 'uiop', 'extra': 'extra'} +# try: +# validator.validate_request(content, None) +# except ImmediateResponse, exc: +# response = exc.response +# self.assertEqual(response.raw_content, {'field_errors': {'extra': ['This field does not exist.']}}) +# else: +# self.fail('ResourceException was not raised') + +# def validation_failed_due_to_multiple_errors_returns_appropriate_message(self, validator): +# """If validation for multiple reasons, ensure the response contains each error""" +# content = {'qwerty': '', 'extra': 'extra'} +# try: +# validator.validate_request(content, None) +# except ImmediateResponse, exc: +# response = exc.response +# self.assertEqual(response.raw_content, {'field_errors': {'qwerty': ['This field is required.'], +# 'extra': ['This field does not exist.']}}) +# else: +# self.fail('ResourceException was not raised') + +# # Tests on FormResource + +# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) + +# def test_form_validation_failure_raises_response_exception(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_failure_raises_response_exception(validator) + +# def test_validation_does_not_allow_extra_fields_by_default(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_does_not_allow_extra_fields_by_default(validator) + +# def test_validation_allows_extra_fields_if_explicitly_set(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_allows_extra_fields_if_explicitly_set(validator) + +# def test_validation_allows_unknown_fields_if_explicitly_allowed(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_allows_unknown_fields_if_explicitly_allowed(validator) + +# def test_validation_does_not_require_extra_fields_if_explicitly_set(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) + +# def test_validation_failed_due_to_no_content_returns_appropriate_message(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) + +# def test_validation_failed_due_to_field_error_returns_appropriate_message(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) + +# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) + +# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): +# validator = self.MockFormResource(self.MockFormView()) +# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) + +# # Same tests on ModelResource + +# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) + +# def test_modelform_validation_failure_raises_response_exception(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_failure_raises_response_exception(validator) + +# def test_modelform_validation_does_not_allow_extra_fields_by_default(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_does_not_allow_extra_fields_by_default(validator) + +# def test_modelform_validation_allows_extra_fields_if_explicitly_set(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_allows_extra_fields_if_explicitly_set(validator) + +# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) + +# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) + +# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) + +# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) + +# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): +# validator = self.MockModelResource(self.MockModelFormView()) +# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) + + +# class TestModelFormValidator(TestCase): +# """Tests specific to ModelFormValidatorMixin""" + +# def setUp(self): +# """Create a validator for a model with two fields and a property.""" +# class MockModel(models.Model): +# qwerty = models.CharField(max_length=256) +# uiop = models.CharField(max_length=256, blank=True) + +# @property +# def readonly(self): +# return 'read only' + +# class MockResource(ModelResource): +# model = MockModel + +# class MockView(View): +# resource = MockResource + +# self.validator = MockResource(MockView) + +# def test_property_fields_are_allowed_on_model_forms(self): +# """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" +# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only'} +# self.assertEqual(self.validator.validate_request(content, None), content) + +# def test_property_fields_are_not_required_on_model_forms(self): +# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input.""" +# content = {'qwerty': 'example', 'uiop': 'example'} +# self.assertEqual(self.validator.validate_request(content, None), content) + +# def test_extra_fields_not_allowed_on_model_forms(self): +# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. +# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up +# broken clients more easily (eg submitting content with a misnamed field)""" +# content = {'qwerty': 'example', 'uiop': 'example', 'readonly': 'read only', 'extra': 'extra'} +# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) + +# def test_validate_requires_fields_on_model_forms(self): +# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. +# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up +# broken clients more easily (eg submitting content with a misnamed field)""" +# content = {'readonly': 'read only'} +# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) + +# def test_validate_does_not_require_blankable_fields_on_model_forms(self): +# """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" +# content = {'qwerty': 'example', 'readonly': 'read only'} +# self.validator.validate_request(content, None) + +# def test_model_form_validator_uses_model_forms(self): +# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm)) diff --git a/rest_framework/tests/views.py b/rest_framework/tests/views.py new file mode 100644 index 00000000..cd1f73c3 --- /dev/null +++ b/rest_framework/tests/views.py @@ -0,0 +1,128 @@ +# from django.core.urlresolvers import reverse +# from django.conf.urls.defaults import patterns, url, include +# from django.http import HttpResponse +# from django.test import TestCase +# from django.utils import simplejson as json + +# from rest_framework.views import View + + +# class MockView(View): +# """This is a basic mock view""" +# pass + + +# class MockViewFinal(View): +# """View with final() override""" + +# def final(self, request, response, *args, **kwargs): +# return HttpResponse('{"test": "passed"}', content_type="application/json") + + +# # class ResourceMockView(View): +# # """This is a resource-based mock view""" + +# # class MockForm(forms.Form): +# # foo = forms.BooleanField(required=False) +# # bar = forms.IntegerField(help_text='Must be an integer.') +# # baz = forms.CharField(max_length=32) + +# # form = MockForm + + +# # class MockResource(ModelResource): +# # """This is a mock model-based resource""" + +# # class MockResourceModel(models.Model): +# # foo = models.BooleanField() +# # bar = models.IntegerField(help_text='Must be an integer.') +# # baz = models.CharField(max_length=32, help_text='Free text. Max length 32 chars.') + +# # model = MockResourceModel +# # fields = ('foo', 'bar', 'baz') + +# urlpatterns = patterns('', +# url(r'^mock/$', MockView.as_view()), +# url(r'^mock/final/$', MockViewFinal.as_view()), +# # url(r'^resourcemock/$', ResourceMockView.as_view()), +# # url(r'^model/$', ListOrCreateModelView.as_view(resource=MockResource)), +# # url(r'^model/(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource=MockResource)), +# url(r'^restframework/', include('rest_framework.urls', namespace='rest_framework')), +# ) + + +# class BaseViewTests(TestCase): +# """Test the base view class of rest_framework""" +# urls = 'rest_framework.tests.views' + +# def test_view_call_final(self): +# response = self.client.options('/mock/final/') +# self.assertEqual(response['Content-Type'].split(';')[0], "application/json") +# data = json.loads(response.content) +# self.assertEqual(data['test'], 'passed') + +# def test_options_method_simple_view(self): +# response = self.client.options('/mock/') +# self._verify_options_response(response, +# name='Mock', +# description='This is a basic mock view') + +# def test_options_method_resource_view(self): +# response = self.client.options('/resourcemock/') +# self._verify_options_response(response, +# name='Resource Mock', +# description='This is a resource-based mock view', +# fields={'foo': 'BooleanField', +# 'bar': 'IntegerField', +# 'baz': 'CharField', +# }) + +# def test_options_method_model_resource_list_view(self): +# response = self.client.options('/model/') +# self._verify_options_response(response, +# name='Mock List', +# description='This is a mock model-based resource', +# fields={'foo': 'BooleanField', +# 'bar': 'IntegerField', +# 'baz': 'CharField', +# }) + +# def test_options_method_model_resource_detail_view(self): +# response = self.client.options('/model/0/') +# self._verify_options_response(response, +# name='Mock Instance', +# description='This is a mock model-based resource', +# fields={'foo': 'BooleanField', +# 'bar': 'IntegerField', +# 'baz': 'CharField', +# }) + +# def _verify_options_response(self, response, name, description, fields=None, status=200, +# mime_type='application/json'): +# self.assertEqual(response.status_code, status) +# self.assertEqual(response['Content-Type'].split(';')[0], mime_type) +# data = json.loads(response.content) +# self.assertTrue('application/json' in data['renders']) +# self.assertEqual(name, data['name']) +# self.assertEqual(description, data['description']) +# if fields is None: +# self.assertFalse(hasattr(data, 'fields')) +# else: +# self.assertEqual(data['fields'], fields) + + +# class ExtraViewsTests(TestCase): +# """Test the extra views rest_framework provides""" +# urls = 'rest_framework.tests.views' + +# def test_login_view(self): +# """Ensure the login view exists""" +# response = self.client.get(reverse('rest_framework:login')) +# self.assertEqual(response.status_code, 200) +# self.assertEqual(response['Content-Type'].split(';')[0], 'text/html') + +# def test_logout_view(self): +# """Ensure the logout view exists""" +# response = self.client.get(reverse('rest_framework:logout')) +# self.assertEqual(response.status_code, 200) +# self.assertEqual(response['Content-Type'].split(';')[0], 'text/html') |
