diff options
Diffstat (limited to 'rest_framework/tests')
| -rw-r--r-- | rest_framework/tests/authentication.py | 33 | ||||
| -rw-r--r-- | rest_framework/tests/files.py | 55 | ||||
| -rw-r--r-- | rest_framework/tests/hyperlinkedserializers.py | 9 | ||||
| -rw-r--r-- | rest_framework/tests/models.py | 13 | ||||
| -rw-r--r-- | rest_framework/tests/pagination.py | 82 | ||||
| -rw-r--r-- | rest_framework/tests/serializer.py | 85 | ||||
| -rw-r--r-- | rest_framework/tests/throttling.py | 2 |
7 files changed, 240 insertions, 39 deletions
diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index 8ab4c4e4..96ca9f52 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,4 +1,4 @@ -from django.conf.urls.defaults import patterns +from django.conf.urls.defaults import patterns, include from django.contrib.auth.models import User from django.test import Client, TestCase @@ -27,6 +27,7 @@ MockView.authentication_classes += (TokenAuthentication,) urlpatterns = patterns('', (r'^$', MockView.as_view()), + (r'^auth-token/', 'rest_framework.authtoken.views.obtain_auth_token'), ) @@ -152,3 +153,33 @@ class TokenAuthTests(TestCase): self.token.delete() token = Token.objects.create(user=self.user) self.assertTrue(bool(token.key)) + + def test_token_login_json(self): + """Ensure token login view using JSON POST works.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/login/', + json.dumps({'username': self.username, 'password': self.password}), 'application/json') + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.content)['token'], self.key) + + def test_token_login_json_bad_creds(self): + """Ensure token login view using JSON POST fails if bad credentials are used.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/login/', + json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') + self.assertEqual(response.status_code, 400) + + def test_token_login_json_missing_fields(self): + """Ensure token login view using JSON POST fails if missing fields.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/login/', + json.dumps({'username': self.username}), 'application/json') + self.assertEqual(response.status_code, 400) + + def test_token_login_form(self): + """Ensure token login view using form POST works.""" + client = Client(enforce_csrf_checks=True) + response = client.post('/auth-token/login/', + {'username': self.username, 'password': self.password}) + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.content)['token'], self.key) diff --git a/rest_framework/tests/files.py b/rest_framework/tests/files.py index 61d7f7b1..5dd57b7c 100644 --- a/rest_framework/tests/files.py +++ b/rest_framework/tests/files.py @@ -1,34 +1,39 @@ -# from django.test import TestCase -# from django import forms +import StringIO +import datetime -# from django.test.client import RequestFactory -# from rest_framework.views import View -# from rest_framework.response import Response +from django.test import TestCase -# import StringIO +from rest_framework import serializers -# class UploadFilesTests(TestCase): -# """Check uploading of files""" -# def setUp(self): -# self.factory = RequestFactory() +class UploadedFile(object): + def __init__(self, file, created=None): + self.file = file + self.created = created or datetime.datetime.now() -# def test_upload_file(self): -# class FileForm(forms.Form): -# file = forms.FileField() +class UploadedFileSerializer(serializers.Serializer): + file = serializers.FileField() + created = serializers.DateTimeField() -# class MockView(View): -# permissions = () -# form = FileForm + def restore_object(self, attrs, instance=None): + if instance: + instance.file = attrs['file'] + instance.created = attrs['created'] + return instance + return UploadedFile(**attrs) -# def post(self, request, *args, **kwargs): -# return Response({'FILE_NAME': self.CONTENT['file'].name, -# 'FILE_CONTENT': self.CONTENT['file'].read()}) -# file = StringIO.StringIO('stuff') -# file.name = 'stuff.txt' -# request = self.factory.post('/', {'file': file}) -# view = MockView.as_view() -# response = view(request) -# self.assertEquals(response.raw_content, {"FILE_CONTENT": "stuff", "FILE_NAME": "stuff.txt"}) +class FileSerializerTests(TestCase): + + def test_create(self): + now = datetime.datetime.now() + file = StringIO.StringIO('stuff') + file.name = 'stuff.txt' + file.size = file.len + serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) + uploaded_file = UploadedFile(file=file, created=now) + self.assertTrue(serializer.is_valid()) + self.assertEquals(serializer.object.created, uploaded_file.created) + self.assertEquals(serializer.object.file, uploaded_file.file) + self.assertFalse(serializer.object is uploaded_file) diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 5ab850af..d7effce7 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -8,12 +8,13 @@ factory = RequestFactory() class BlogPostCommentSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') text = serializers.CharField() blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') class Meta: model = BlogPostComment - fields = ('text', 'blog_post_url') + fields = ('text', 'blog_post_url', 'url') class PhotoSerializer(serializers.Serializer): @@ -53,6 +54,9 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView): model = BlogPostComment serializer_class = BlogPostCommentSerializer +class BlogPostCommentDetail(generics.RetrieveAPIView): + model = BlogPostComment + serializer_class = BlogPostCommentSerializer class BlogPostDetail(generics.RetrieveAPIView): model = BlogPost @@ -80,6 +84,7 @@ urlpatterns = patterns('', url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), + url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), @@ -191,6 +196,7 @@ class TestCreateWithForeignKeys(TestCase): request = factory.post('/comments/', data=data) response = self.create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response['Location'], 'http://testserver/comments/1/') self.assertEqual(self.post.blogpostcomment_set.count(), 1) self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') @@ -215,6 +221,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): request = factory.post('/photos/', data=data) response = self.list_create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index f6e5333b..70523fc0 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -35,6 +35,13 @@ def foobar(): return 'foobar' +class CustomField(models.CharField): + + def __init__(self, *args, **kwargs): + kwargs['max_length'] = 12 + super(CustomField, self).__init__(*args, **kwargs) + + class RESTFrameworkModel(models.Model): """ Base for test models that sets app_label, so they play nicely. @@ -113,12 +120,16 @@ class Comment(RESTFrameworkModel): class ActionItem(RESTFrameworkModel): title = models.CharField(max_length=200) done = models.BooleanField(default=False) + info = CustomField(default='---', max_length=12) # Models for reverse relations class BlogPost(RESTFrameworkModel): title = models.CharField(max_length=100) + def get_first_comment(self): + return self.blogpostcomment_set.all()[0] + class BlogPostComment(RESTFrameworkModel): text = models.TextField() @@ -157,4 +168,4 @@ class OptionalRelationModel(RESTFrameworkModel): # Model for RegexField class Book(RESTFrameworkModel): - isbn = models.CharField(max_length=13)
\ No newline at end of file + isbn = models.CharField(max_length=13) diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py index 713a7255..3062007d 100644 --- a/rest_framework/tests/pagination.py +++ b/rest_framework/tests/pagination.py @@ -34,6 +34,21 @@ if django_filters: filter_backend = filters.DjangoFilterBackend +class DefaultPageSizeKwargView(generics.ListAPIView): + """ + View for testing default paginate_by_param usage + """ + model = BasicModel + + +class PaginateByParamView(generics.ListAPIView): + """ + View for testing custom paginate_by_param usage + """ + model = BasicModel + paginate_by_param = 'page_size' + + class IntegrationTestPagination(TestCase): """ Integration tests for paginated list views. @@ -135,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase): class UnitTestPagination(TestCase): """ - Unit tests for pagination of primative objects. + Unit tests for pagination of primitive objects. """ def setUp(self): @@ -156,3 +171,68 @@ class UnitTestPagination(TestCase): self.assertEquals(serializer.data['next'], None) self.assertEquals(serializer.data['previous'], '?page=2') self.assertEquals(serializer.data['results'], self.objects[20:]) + + +class TestUnpaginated(TestCase): + """ + Tests for list views without pagination. + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = DefaultPageSizeKwargView.as_view() + + def test_unpaginated(self): + """ + Tests the default page size for this view. + no page size --> no limit --> no meta data + """ + request = factory.get('/') + response = self.view(request) + self.assertEquals(response.data, self.data) + + +class TestCustomPaginateByParam(TestCase): + """ + Tests for list views with default page size kwarg + """ + + def setUp(self): + """ + Create 13 BasicModel instances. + """ + for i in range(13): + BasicModel(text=i).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = PaginateByParamView.as_view() + + def test_default_page_size(self): + """ + Tests the default page size for this view. + no page size --> no limit --> no meta data + """ + request = factory.get('/') + response = self.view(request).render() + self.assertEquals(response.data, self.data) + + def test_paginate_by_param(self): + """ + If paginate_by_param is set, the new kwarg should limit per view requests. + """ + request = factory.get('/?page_size=5') + response = self.view(request).render() + self.assertEquals(response.data['count'], 13) + self.assertEquals(response.data['results'], self.data[:5]) diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index ad100e53..520029ec 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -48,6 +48,7 @@ class BookSerializer(serializers.ModelSerializer): class ActionItemSerializer(serializers.ModelSerializer): + class Meta: model = ActionItem @@ -246,6 +247,23 @@ class ValidationTests(TestCase): self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.errors, {}) + def test_modelserializer_max_length_exceeded(self): + data = { + 'title': 'x' * 201, + } + serializer = ActionItemSerializer(data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) + + def test_default_modelfield_max_length_exceeded(self): + data = { + 'title': 'Testing "info" field...', + 'info': 'x' * 13, + } + serializer = ActionItemSerializer(data=data) + self.assertEquals(serializer.is_valid(), False) + self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) + class RegexValidationTest(TestCase): def test_create_failed(self): @@ -487,7 +505,10 @@ class CallableDefaultValueTests(TestCase): class ManyRelatedTests(TestCase): - def setUp(self): + def test_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") class BlogPostCommentSerializer(serializers.Serializer): text = serializers.CharField() @@ -496,14 +517,7 @@ class ManyRelatedTests(TestCase): title = serializers.CharField() comments = BlogPostCommentSerializer(source='blogpostcomment_set') - self.serializer_class = BlogPostSerializer - - def test_reverse_relations(self): - post = BlogPost.objects.create(title="Test blog post") - post.blogpostcomment_set.create(text="I hate this blog post") - post.blogpostcomment_set.create(text="I love this blog post") - - serializer = self.serializer_class(instance=post) + serializer = BlogPostSerializer(instance=post) expected = { 'title': 'Test blog post', 'comments': [ @@ -514,6 +528,59 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_callable_source(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostCommentSerializer(serializers.Serializer): + text = serializers.CharField() + + class BlogPostSerializer(serializers.Serializer): + title = serializers.CharField() + first_comment = BlogPostCommentSerializer(source='get_first_comment') + + serializer = BlogPostSerializer(post) + + expected = { + 'title': 'Test blog post', + 'first_comment': {'text': 'I love this blog post'} + } + self.assertEqual(serializer.data, expected) + + +class SerializerMethodFieldTests(TestCase): + def setUp(self): + + class BoopSerializer(serializers.Serializer): + beep = serializers.SerializerMethodField('get_beep') + boop = serializers.Field() + boop_count = serializers.SerializerMethodField('get_boop_count') + + def get_beep(self, obj): + return 'hello!' + + def get_boop_count(self, obj): + return len(obj.boop) + + self.serializer_class = BoopSerializer + + def test_serializer_method_field(self): + + class MyModel(object): + boop = ['a', 'b', 'c'] + + source_data = MyModel() + + serializer = self.serializer_class(source_data) + + expected = { + 'beep': u'hello!', + 'boop': [u'a', u'b', u'c'], + 'boop_count': 3, + } + + self.assertEqual(serializer.data, expected) + # Test for issue #324 class BlankFieldTests(TestCase): diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py index 0b94c25b..4b98b941 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/throttling.py @@ -106,7 +106,7 @@ class ThrottlingTests(TestCase): if expect is not None: self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) else: - self.assertFalse('X-Throttle-Wait-Seconds' in response.headers) + self.assertFalse('X-Throttle-Wait-Seconds' in response) def test_seconds_fields(self): """ |
