aboutsummaryrefslogtreecommitdiffstats
path: root/rest_framework/tests
diff options
context:
space:
mode:
authorTom Christie2013-01-26 21:37:43 +0000
committerTom Christie2013-01-26 21:37:43 +0000
commitb5d8f50f9dcace3ad3c708ed518f23ff260f6bea (patch)
tree7a0fb931ad5841b863359719075dd55f8370b3a1 /rest_framework/tests
parent4eb5861f3676781493af29f8e9fd87ec22e591aa (diff)
parenta75db4cfb8ed756c451bfda7ea0c73a73859216f (diff)
downloaddjango-rest-framework-b5d8f50f9dcace3ad3c708ed518f23ff260f6bea.tar.bz2
Merge branch 'master' into many-fields
Diffstat (limited to 'rest_framework/tests')
-rw-r--r--rest_framework/tests/authentication.py41
-rw-r--r--rest_framework/tests/decorators.py22
-rw-r--r--rest_framework/tests/genericrelations.py96
-rw-r--r--rest_framework/tests/models.py21
-rw-r--r--rest_framework/tests/pagination.py43
-rw-r--r--rest_framework/tests/serializer.py88
-rw-r--r--rest_framework/tests/urlpatterns.py78
7 files changed, 300 insertions, 89 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py
index e86041bc..1f17e8d2 100644
--- a/rest_framework/tests/authentication.py
+++ b/rest_framework/tests/authentication.py
@@ -4,7 +4,7 @@ from django.test import Client, TestCase
from rest_framework import permissions
from rest_framework.authtoken.models import Token
-from rest_framework.authentication import TokenAuthentication
+from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication
from rest_framework.compat import patterns
from rest_framework.views import APIView
@@ -21,10 +21,10 @@ class MockView(APIView):
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
-MockView.authentication_classes += (TokenAuthentication,)
-
urlpatterns = patterns('',
- (r'^$', MockView.as_view()),
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
)
@@ -43,24 +43,25 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip()
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
class SessionAuthTests(TestCase):
@@ -83,7 +84,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
def test_post_form_session_auth_passing(self):
@@ -91,7 +92,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.post('/', {'example': 'example'})
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_put_form_session_auth_passing(self):
@@ -99,14 +100,14 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with logged in user and CSRF token passes.
"""
self.non_csrf_client.login(username=self.username, password=self.password)
- response = self.non_csrf_client.put('/', {'example': 'example'})
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 200)
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/', {'example': 'example'})
+ response = self.csrf_client.post('/session/', {'example': 'example'})
self.assertEqual(response.status_code, 403)
@@ -127,24 +128,24 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', {'example': 'example'})
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, 401)
def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails"""
- response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json')
- self.assertEqual(response.status_code, 403)
+ response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json')
+ self.assertEqual(response.status_code, 401)
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
diff --git a/rest_framework/tests/decorators.py b/rest_framework/tests/decorators.py
index 5e6bce4e..82f912e9 100644
--- a/rest_framework/tests/decorators.py
+++ b/rest_framework/tests/decorators.py
@@ -28,13 +28,27 @@ class DecoratorTestCase(TestCase):
response.request = request
return APIView.finalize_response(self, request, response, *args, **kwargs)
- def test_wrap_view(self):
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
- @api_view(['GET'])
+ @api_view
def view(request):
- return Response({})
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
- self.assertTrue(isinstance(view.cls_instance, APIView))
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
def test_calling_method(self):
diff --git a/rest_framework/tests/genericrelations.py b/rest_framework/tests/genericrelations.py
index bc7378e1..146ad1e4 100644
--- a/rest_framework/tests/genericrelations.py
+++ b/rest_framework/tests/genericrelations.py
@@ -1,25 +1,61 @@
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
from django.test import TestCase
from rest_framework import serializers
-from rest_framework.tests.models import *
+
+
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __unicode__(self):
+ return self.tag
+
+
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Bookmark: %s' % self.url
+
+
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __unicode__(self):
+ return 'Note: %s' % self.text
class TestGenericRelations(TestCase):
def setUp(self):
- bookmark = Bookmark(url='https://www.djangoproject.com/')
- bookmark.save()
- django = Tag(tag_name='django')
- django.save()
- python = Tag(tag_name='python')
- python.save()
- t1 = TaggedItem(content_object=bookmark, tag=django)
- t1.save()
- t2 = TaggedItem(content_object=bookmark, tag=python)
- t2.save()
- self.bookmark = bookmark
-
- def test_reverse_generic_relation(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
class BookmarkSerializer(serializers.ModelSerializer):
- tags = serializers.ManyRelatedField(source='tags')
+ tags = serializers.ManyRelatedField()
class Meta:
model = Bookmark
@@ -31,3 +67,33 @@ class TestGenericRelations(TestCase):
'url': u'https://www.djangoproject.com/'
}
self.assertEquals(serializer.data, expected)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.RelatedField()
+
+ class Meta:
+ model = Tag
+ exclude = ('id', 'content_type', 'object_id')
+
+ serializer = TagSerializer(Tag.objects.all())
+ expected = [
+ {
+ 'tag': u'django',
+ 'tagged_item': u'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': u'python',
+ 'tagged_item': u'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': u'reminder',
+ 'tagged_item': u'Note: Remember the milk'
+ }
+ ]
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py
index 93f09761..9ab15328 100644
--- a/rest_framework/tests/models.py
+++ b/rest_framework/tests/models.py
@@ -86,27 +86,6 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
text = models.CharField(max_length=100, default='anchor')
rel = models.ManyToManyField(Anchor)
-# Models to test generic relations
-
-
-class Tag(RESTFrameworkModel):
- tag_name = models.SlugField()
-
-
-class TaggedItem(RESTFrameworkModel):
- tag = models.ForeignKey(Tag, related_name='items')
- content_type = models.ForeignKey(ContentType)
- object_id = models.PositiveIntegerField()
- content_object = GenericForeignKey('content_type', 'object_id')
-
- def __unicode__(self):
- return self.tag.tag_name
-
-
-class Bookmark(RESTFrameworkModel):
- url = models.URLField()
- tags = GenericRelation(TaggedItem)
-
# Model to test filtering.
class FilterableItem(RESTFrameworkModel):
diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py
index 3b550877..697dfb5b 100644
--- a/rest_framework/tests/pagination.py
+++ b/rest_framework/tests/pagination.py
@@ -252,6 +252,8 @@ class TestCustomPaginateByParam(TestCase):
self.assertEquals(response.data['results'], self.data[:5])
+### Tests for context in pagination serializers
+
class CustomField(serializers.Field):
def to_native(self, value):
if not 'view' in self.context:
@@ -262,6 +264,11 @@ class CustomField(serializers.Field):
class BasicModelSerializer(serializers.Serializer):
text = CustomField()
+ def __init__(self, *args, **kwargs):
+ super(BasicModelSerializer, self).__init__(*args, **kwargs)
+ if not 'view' in self.context:
+ raise RuntimeError("context isn't getting passed into serializer init")
+
class TestContextPassedToCustomField(TestCase):
def setUp(self):
@@ -279,3 +286,39 @@ class TestContextPassedToCustomField(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK)
+
+### Tests for custom pagination serializers
+
+class LinksSerializer(serializers.Serializer):
+ next = pagination.NextPageField(source='*')
+ prev = pagination.PreviousPageField(source='*')
+
+
+class CustomPaginationSerializer(pagination.BasePaginationSerializer):
+ links = LinksSerializer(source='*') # Takes the page object as the source
+ total_results = serializers.Field(source='paginator.count')
+
+ results_field = 'objects'
+
+
+class TestCustomPaginationSerializer(TestCase):
+ def setUp(self):
+ objects = ['john', 'paul', 'george', 'ringo']
+ paginator = Paginator(objects, 2)
+ self.page = paginator.page(1)
+
+ def test_custom_pagination_serializer(self):
+ request = RequestFactory().get('/foobar')
+ serializer = CustomPaginationSerializer(
+ instance=self.page,
+ context={'request': request}
+ )
+ expected = {
+ 'links': {
+ 'next': 'http://testserver/foobar?page=2',
+ 'prev': None
+ },
+ 'total_results': 4,
+ 'objects': ['john', 'paul']
+ }
+ self.assertEquals(serializer.data, expected)
diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py
index b2d62ade..4724348e 100644
--- a/rest_framework/tests/serializer.py
+++ b/rest_framework/tests/serializer.py
@@ -163,7 +163,6 @@ class BasicTests(TestCase):
"""
Attempting to update fields set as read_only should have no effect.
"""
-
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99})
self.assertEquals(serializer.is_valid(), True)
instance = serializer.save()
@@ -184,8 +183,7 @@ class ValidationTests(TestCase):
'content': 'x' * 1001,
'created': datetime.datetime(2012, 1, 1)
}
- self.actionitem = ActionItem(title='Some to do item',
- )
+ self.actionitem = ActionItem(title='Some to do item',)
def test_create(self):
serializer = CommentSerializer(data=self.data)
@@ -217,31 +215,6 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.errors, {})
- def test_field_validation(self):
-
- class CommentSerializerWithFieldValidator(CommentSerializer):
-
- def validate_content(self, attrs, source):
- value = attrs[source]
- if "test" not in value:
- raise serializers.ValidationError("Test not in value")
- return attrs
-
- data = {
- 'email': 'tom@example.com',
- 'content': 'A test comment',
- 'created': datetime.datetime(2012, 1, 1)
- }
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertTrue(serializer.is_valid())
-
- data['content'] = 'This should not validate'
-
- serializer = CommentSerializerWithFieldValidator(data=data)
- self.assertFalse(serializer.is_valid())
- self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
-
def test_bad_type_data_is_false(self):
"""
Data of the wrong type is not valid.
@@ -311,12 +284,69 @@ class ValidationTests(TestCase):
self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']})
+class CustomValidationTests(TestCase):
+ class CommentSerializerWithFieldValidator(CommentSerializer):
+
+ def validate_email(self, attrs, source):
+ value = attrs[source]
+
+ return attrs
+
+ def validate_content(self, attrs, source):
+ value = attrs[source]
+ if "test" not in value:
+ raise serializers.ValidationError("Test not in value")
+ return attrs
+
+ def test_field_validation(self):
+ data = {
+ 'email': 'tom@example.com',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertTrue(serializer.is_valid())
+
+ data['content'] = 'This should not validate'
+
+ serializer = self.CommentSerializerWithFieldValidator(data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'Test not in value']})
+
+ def test_missing_data(self):
+ """
+ Make sure that validate_content isn't called if the field is missing
+ """
+ incomplete_data = {
+ 'email': 'tom@example.com',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'content': [u'This field is required.']})
+
+ def test_wrong_data(self):
+ """
+ Make sure that validate_content isn't called if the field input is wrong
+ """
+ wrong_data = {
+ 'email': 'not an email',
+ 'content': 'A test comment',
+ 'created': datetime.datetime(2012, 1, 1)
+ }
+ serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']})
+
+
class PositiveIntegerAsChoiceTests(TestCase):
def test_positive_integer_in_json_is_correctly_parsed(self):
- data = {'some_integer':1}
+ data = {'some_integer': 1}
serializer = PositiveIntegerAsChoiceSerializer(data=data)
self.assertEquals(serializer.is_valid(), True)
+
class ModelValidationTests(TestCase):
def test_validate_unique(self):
"""
diff --git a/rest_framework/tests/urlpatterns.py b/rest_framework/tests/urlpatterns.py
new file mode 100644
index 00000000..43e8ef69
--- /dev/null
+++ b/rest_framework/tests/urlpatterns.py
@@ -0,0 +1,78 @@
+from collections import namedtuple
+
+from django.core import urlresolvers
+
+from django.test import TestCase
+from django.test.client import RequestFactory
+
+from rest_framework.compat import patterns, url, include
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = RequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEquals(callback_args, test_path.args)
+ self.assertEquals(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)