aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/models.py9
-rw-r--r--tests/test_authentication.py34
-rw-r--r--tests/test_filters.py99
-rw-r--r--tests/test_genericrelations.py28
-rw-r--r--tests/test_htmlrenderer.py8
-rw-r--r--tests/test_pagination.py33
-rw-r--r--tests/test_parsers.py4
-rw-r--r--tests/test_relations.py24
-rw-r--r--tests/test_relations_nested.py4
-rw-r--r--tests/test_renderers.py17
-rw-r--r--tests/test_serializer.py132
-rw-r--r--tests/test_testing.py10
-rw-r--r--tests/test_urlizer.py38
-rw-r--r--tests/test_validation.py44
-rw-r--r--tests/utils.py25
15 files changed, 430 insertions, 79 deletions
diff --git a/tests/models.py b/tests/models.py
index 32a726c0..6c8f2342 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -103,7 +103,7 @@ class BlogPostComment(RESTFrameworkModel):
class Album(RESTFrameworkModel):
title = models.CharField(max_length=100, unique=True)
-
+ ref = models.CharField(max_length=10, unique=True, null=True, blank=True)
class Photo(RESTFrameworkModel):
description = models.TextField()
@@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel):
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
+
+
+# Models to test filters
+class FilterableItem(models.Model):
+ text = models.CharField(max_length=100)
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
index 4ecfef44..1d90493e 100644
--- a/tests/test_authentication.py
+++ b/tests/test_authentication.py
@@ -3,6 +3,7 @@ from django.contrib.auth.models import User
from django.http import HttpResponse
from django.test import TestCase
from django.utils import unittest
+from django.utils.http import urlencode
from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions
from rest_framework import permissions
@@ -19,7 +20,7 @@ from rest_framework.authentication import (
)
from rest_framework.authtoken.models import Token
from rest_framework.compat import patterns, url, include
-from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
+from rest_framework.compat import oauth2_provider, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
@@ -53,10 +54,14 @@ urlpatterns = patterns('',
permission_classes=[permissions.TokenHasReadWriteScope]))
)
+class OAuth2AuthenticationDebug(OAuth2Authentication):
+ allow_query_params_token = True
+
if oauth2_provider is not None:
urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
+ url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])),
)
@@ -488,7 +493,7 @@ class OAuth2Tests(TestCase):
self.ACCESS_TOKEN = "access_token"
self.REFRESH_TOKEN = "refresh_token"
- self.oauth2_client = oauth2_provider_models.Client.objects.create(
+ self.oauth2_client = oauth2_provider.oauth2.models.Client.objects.create(
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
redirect_uri='',
@@ -497,12 +502,12 @@ class OAuth2Tests(TestCase):
user=None,
)
- self.access_token = oauth2_provider_models.AccessToken.objects.create(
+ self.access_token = oauth2_provider.oauth2.models.AccessToken.objects.create(
token=self.ACCESS_TOKEN,
client=self.oauth2_client,
user=self.user,
)
- self.refresh_token = oauth2_provider_models.RefreshToken.objects.create(
+ self.refresh_token = oauth2_provider.oauth2.models.RefreshToken.objects.create(
user=self.user,
access_token=self.access_token,
client=self.oauth2_client
@@ -546,6 +551,27 @@ class OAuth2Tests(TestCase):
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_post_form_passing_auth_url_transport(self):
+ """Ensure GETing form over OAuth with correct client credentials in form data succeed"""
+ response = self.csrf_client.post('/oauth2-test/',
+ data={'access_token': self.access_token.token})
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_passing_auth_url_transport(self):
+ """Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
+ query = urlencode({'access_token': self.access_token.token})
+ response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
+ self.assertEqual(response.status_code, 200)
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
+ def test_get_form_failing_auth_url_transport(self):
+ """Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
+ query = urlencode({'access_token': self.access_token.token})
+ response = self.csrf_client.get('/oauth2-test/?%s' % query)
+ self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
+
+ @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
diff --git a/tests/test_filters.py b/tests/test_filters.py
index d9d8042e..3c6e8857 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -7,18 +7,14 @@ from django.test import TestCase
from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
+from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
-from tests.models import BasicModel
+from .models import FilterableItem, BasicModel
+from .utils import temporary_setting
factory = APIRequestFactory()
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
-
-
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
@@ -128,7 +124,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@@ -136,7 +132,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter works.
search_date = datetime.date(2012, 9, 22)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22'
+ request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date]
@@ -151,7 +147,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works.
search_decimal = Decimal('2.25')
- request = factory.get('/?decimal=%s' % search_decimal)
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@@ -184,7 +180,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25')
- request = factory.get('/?decimal=%s' % search_decimal)
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal]
@@ -192,7 +188,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02'
+ request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date]
@@ -200,7 +196,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff'
- request = factory.get('/?text=%s' % search_text)
+ request = factory.get('/', {'text': '%s' % search_text})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()]
@@ -209,7 +205,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that multiple filters works.
search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2)
- request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date))
+ request = factory.get('/', {
+ 'decimal': '%s' % (search_decimal,),
+ 'date': '%s' % (search_date,)
+ })
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and
@@ -234,7 +233,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
view = FilterFieldsRootView.as_view()
search_integer = 10
- request = factory.get('/?integer=%s' % search_integer)
+ request = factory.get('/', {'integer': '%s' % search_integer})
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -265,14 +264,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set that should fail.
search_decimal = Decimal('4.25')
high_item = self.objects.filter(decimal__gt=search_decimal)[0]
- response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(high_item)),
+ {'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Tests that the decimal filter set that should succeed.
search_decimal = Decimal('4.25')
low_item = self.objects.filter(decimal__lt=search_decimal)[0]
low_item_data = self._serialize_object(low_item)
- response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(low_item)),
+ {'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data)
@@ -281,7 +284,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
search_date = datetime.date(2012, 10, 2)
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
valid_item_data = self._serialize_object(valid_item)
- response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(valid_item)), {
+ 'decimal': '{decimal}'.format(decimal=search_decimal),
+ 'date': '{date}'.format(date=search_date)
+ })
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data)
@@ -315,7 +322,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text')
view = SearchListView.as_view()
- request = factory.get('?search=b')
+ request = factory.get('/', {'search': 'b'})
response = view(request)
self.assertEqual(
response.data,
@@ -332,7 +339,7 @@ class SearchFilterTests(TestCase):
search_fields = ('=title', 'text')
view = SearchListView.as_view()
- request = factory.get('?search=zzz')
+ request = factory.get('/', {'search': 'zzz'})
response = view(request)
self.assertEqual(
response.data,
@@ -348,7 +355,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', '^text')
view = SearchListView.as_view()
- request = factory.get('?search=b')
+ request = factory.get('/', {'search': 'b'})
response = view(request)
self.assertEqual(
response.data,
@@ -357,6 +364,24 @@ class SearchFilterTests(TestCase):
]
)
+ def test_search_with_nonstandard_search_param(self):
+ with temporary_setting('SEARCH_PARAM', 'query', module=filters):
+ class SearchListView(generics.ListAPIView):
+ model = SearchFilterModel
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'query': 'b'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
class OrdringFilterModel(models.Model):
title = models.CharField(max_length=20)
@@ -396,7 +421,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('?ordering=text')
+ request = factory.get('/', {'ordering': 'text'})
response = view(request)
self.assertEqual(
response.data,
@@ -415,7 +440,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('?ordering=-text')
+ request = factory.get('/', {'ordering': '-text'})
response = view(request)
self.assertEqual(
response.data,
@@ -434,7 +459,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',)
view = OrderingListView.as_view()
- request = factory.get('?ordering=foobar')
+ request = factory.get('/', {'ordering': 'foobar'})
response = view(request)
self.assertEqual(
response.data,
@@ -503,7 +528,7 @@ class OrderingFilterTests(TestCase):
models.Count("relateds"))
view = OrderingListView.as_view()
- request = factory.get('?ordering=relateds__count')
+ request = factory.get('/', {'ordering': 'relateds__count'})
response = view(request)
self.assertEqual(
response.data,
@@ -514,6 +539,26 @@ class OrderingFilterTests(TestCase):
]
)
+ def test_ordering_with_nonstandard_ordering_param(self):
+ with temporary_setting('ORDERING_PARAM', 'order', filters):
+ class OrderingListView(generics.ListAPIView):
+ model = OrdringFilterModel
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'order': 'text'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
class SensitiveOrderingFilterModel(models.Model):
username = models.CharField(max_length=20)
@@ -566,7 +611,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('?ordering=-username')
+ request = factory.get('/', {'ordering': '-username'})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
@@ -596,7 +641,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('?ordering=password')
+ request = factory.get('/', {'ordering': 'password'})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
diff --git a/tests/test_genericrelations.py b/tests/test_genericrelations.py
index 2d341344..46a2d863 100644
--- a/tests/test_genericrelations.py
+++ b/tests/test_genericrelations.py
@@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK
from django.db import models
from django.test import TestCase
from rest_framework import serializers
+from rest_framework.compat import python_2_unicode_compatible
+@python_2_unicode_compatible
class Tag(models.Model):
"""
Tags have a descriptive slug, and are attached to an arbitrary object.
@@ -15,10 +17,11 @@ class Tag(models.Model):
object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id')
- def __unicode__(self):
+ def __str__(self):
return self.tag
+@python_2_unicode_compatible
class Bookmark(models.Model):
"""
A URL bookmark that may have multiple tags attached.
@@ -26,10 +29,11 @@ class Bookmark(models.Model):
url = models.URLField()
tags = GenericRelation(Tag)
- def __unicode__(self):
+ def __str__(self):
return 'Bookmark: %s' % self.url
+@python_2_unicode_compatible
class Note(models.Model):
"""
A textual note that may have multiple tags attached.
@@ -37,7 +41,7 @@ class Note(models.Model):
text = models.TextField()
tags = GenericRelation(Tag)
- def __unicode__(self):
+ def __str__(self):
return 'Note: %s' % self.text
@@ -127,3 +131,21 @@ class TestGenericRelations(TestCase):
}
]
self.assertEqual(serializer.data, expected)
+
+ def test_restore_object_generic_fk(self):
+ """
+ Ensure an object with a generic foreign key can be restored.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Tag
+ exclude = ('content_type', 'object_id')
+
+ serializer = TagSerializer()
+
+ bookmark = Bookmark(url='http://example.com')
+ attrs = {'tagged_item': bookmark, 'tag': 'example'}
+
+ tag = serializer.restore_object(attrs)
+ self.assertEqual(tag.tagged_item, bookmark)
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
index c748fbdb..8af5bb50 100644
--- a/tests/test_htmlrenderer.py
+++ b/tests/test_htmlrenderer.py
@@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase):
"""
self.get_template = django.template.loader.get_template
- def get_template(template_name):
+ def get_template(template_name, dirs=None):
if template_name == 'example.html':
return Template("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
@@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertEqual(response.content, six.b("404: Not found"))
+ self.assertTrue(response.content in (
+ six.b("404: Not found"), six.b("404 Not Found")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(response.content, six.b("403: Permission denied"))
+ self.assertTrue(response.content in (
+ six.b("403: Permission denied"), six.b("403 Forbidden")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 65fa9dcd..293146c0 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -8,15 +8,18 @@ from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
-from tests.models import BasicModel
+from .models import BasicModel, FilterableItem
factory = APIRequestFactory()
+# Helper function to split arguments out of an url
+def split_arguments_from_url(url):
+ if '?' not in url:
+ return url
-class FilterableItem(models.Model):
- text = models.CharField(max_length=100)
- decimal = models.DecimalField(max_digits=4, decimal_places=2)
- date = models.DateField()
+ path, args = url.split('?')
+ args = dict(r.split('=') for r in args.split('&'))
+ return path, args
class RootView(generics.ListCreateAPIView):
@@ -84,7 +87,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -93,7 +96,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -146,7 +149,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
EXPECTED_NUM_QUERIES = 2
- request = factory.get('/?decimal=15.20')
+ request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -155,7 +158,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -164,7 +167,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None)
- request = factory.get(response.data['previous'])
+ request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -191,7 +194,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = BasicFilterFieldsRootView.as_view()
- request = factory.get('/?decimal=15.20')
+ request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(2):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -200,7 +203,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None)
- request = factory.get(response.data['next'])
+ request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -209,7 +212,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None)
- request = factory.get(response.data['previous'])
+ request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(2):
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -317,7 +320,7 @@ class TestCustomPaginateByParam(TestCase):
"""
If paginate_by_param is set, the new kwarg should limit per view requests.
"""
- request = factory.get('/?page_size=5')
+ request = factory.get('/', {'page_size': 5})
response = self.view(request).render()
self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5])
@@ -345,7 +348,7 @@ class TestMaxPaginateByParam(TestCase):
"""
If max_paginate_by is set, it should limit page size for the view.
"""
- request = factory.get('/?page_size=10')
+ request = factory.get('/', data={'page_size': 10})
response = self.view(request).render()
self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5])
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index 7699e10c..8af90677 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -96,7 +96,7 @@ class TestFileUploadParser(TestCase):
request = MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),)
request.META = {
- 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt'.encode('utf-8'),
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
'HTTP_CONTENT_LENGTH': 14,
}
self.parser_context = {'request': request, 'kwargs': {}}
@@ -112,4 +112,4 @@ class TestFileUploadParser(TestCase):
def test_get_filename(self):
parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context)
- self.assertEqual(filename, 'file.txt'.encode('utf-8'))
+ self.assertEqual(filename, 'file.txt')
diff --git a/tests/test_relations.py b/tests/test_relations.py
index bfc8d487..cd276d30 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -2,8 +2,10 @@
General tests for relational fields.
"""
from __future__ import unicode_literals
+from django import get_version
from django.db import models
from django.test import TestCase
+from django.utils import unittest
from rest_framework import serializers
from tests.models import BlogPost
@@ -118,3 +120,25 @@ class RelatedFieldSourceTests(TestCase):
(serializers.ModelSerializer,), attrs)
with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'})
+
+@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
+class RelatedFieldChoicesTests(TestCase):
+ """
+ Tests for #1408 "Web browseable API doesn't have blank option on drop down list box"
+ https://github.com/tomchristie/django-rest-framework/issues/1408
+ """
+ def test_blank_option_is_added_to_choice_if_required_equals_false(self):
+ """
+
+ """
+ post = BlogPost(title="Checking blank option is added")
+ post.save()
+
+ queryset = BlogPost.objects.all()
+ field = serializers.RelatedField(required=False, queryset=queryset)
+
+ choice_count = BlogPost.objects.count()
+ widget_count = len(field.widget.choices)
+
+ self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')
+
diff --git a/tests/test_relations_nested.py b/tests/test_relations_nested.py
index d393b0c3..4d9da489 100644
--- a/tests/test_relations_nested.py
+++ b/tests/test_relations_nested.py
@@ -3,9 +3,7 @@ from django.db import models
from django.test import TestCase
from rest_framework import serializers
-
-class OneToOneTarget(models.Model):
- name = models.CharField(max_length=100)
+from .models import OneToOneTarget
class OneToOneSource(models.Model):
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
index b41cff39..f733d6b6 100644
--- a/tests/test_renderers.py
+++ b/tests/test_renderers.py
@@ -12,7 +12,7 @@ from rest_framework.compat import yaml, etree, patterns, url, include, six, Stri
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
- XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer
+ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer, UnicodeYAMLRenderer
from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
@@ -467,6 +467,17 @@ if yaml:
self.assertTrue(string in content, '%r not in %r' % (string, content))
+ class UnicodeYAMLRendererTests(TestCase):
+ """
+ Tests specific for the Unicode YAML Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = UnicodeYAMLRenderer()
+ content = renderer.render(obj, 'application/yaml')
+ self.assertEqual(content.strip(), 'countries: [United Kingdom, France, España]'.encode('utf-8'))
+
+
class XMLRendererTestCase(TestCase):
"""
Tests specific to the XML Renderer
@@ -613,6 +624,10 @@ class CacheRenderTest(TestCase):
method = getattr(self.client, http_method)
resp = method(url)
del resp.client, resp.request
+ try:
+ del resp.wsgi_request
+ except AttributeError:
+ pass
return resp
def test_obj_pickling(self):
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
index 18484afe..f8966886 100644
--- a/tests/test_serializer.py
+++ b/tests/test_serializer.py
@@ -3,6 +3,7 @@ from __future__ import unicode_literals
from django.db import models
from django.db.models.fields import BLANK_CHOICE_DASH
from django.test import TestCase
+from django.utils import unittest
from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations
@@ -12,6 +13,31 @@ from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor,
from tests.models import BasicModelSerializer
import datetime
import pickle
+try:
+ import PIL
+except:
+ PIL = None
+
+
+if PIL is not None:
+ class AMOAFModel(RESTFrameworkModel):
+ char_field = models.CharField(max_length=1024, blank=True)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
+ decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
+ email_field = models.EmailField(max_length=1024, blank=True)
+ file_field = models.FileField(upload_to='test', max_length=1024, blank=True)
+ image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
+ slug_field = models.SlugField(max_length=1024, blank=True)
+ url_field = models.URLField(max_length=1024, blank=True)
+
+ class DVOAFModel(RESTFrameworkModel):
+ positive_integer_field = models.PositiveIntegerField(blank=True)
+ positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
+ email_field = models.EmailField(blank=True)
+ file_field = models.FileField(upload_to='test', blank=True)
+ image_field = models.ImageField(upload_to='test', blank=True)
+ slug_field = models.SlugField(blank=True)
+ url_field = models.URLField(blank=True)
class SubComment(object):
@@ -141,7 +167,7 @@ class AlbumsSerializer(serializers.ModelSerializer):
class Meta:
model = Album
- fields = ['title'] # lists are also valid options
+ fields = ['title', 'ref'] # lists are also valid options
class PositiveIntegerAsChoiceSerializer(serializers.ModelSerializer):
@@ -482,6 +508,32 @@ class ValidationTests(TestCase):
)
self.assertEqual(serializer.is_valid(), True)
+ def test_writable_star_source_on_nested_serializer_with_parent_object(self):
+ class TitleSerializer(serializers.Serializer):
+ title = serializers.WritableField(source='title')
+
+ class AlbumSerializer(serializers.ModelSerializer):
+ nested = TitleSerializer(source='*')
+
+ class Meta:
+ model = Album
+ fields = ('nested',)
+
+ class PhotoSerializer(serializers.ModelSerializer):
+ album = AlbumSerializer(source='album')
+
+ class Meta:
+ model = Photo
+ fields = ('album', )
+
+ photo = Photo(album=Album())
+
+ data = {'album': {'nested': {'title': 'test'}}}
+
+ serializer = PhotoSerializer(photo, data=data)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.data, data)
+
def test_writable_star_source_with_inner_source_fields(self):
"""
Tests that a serializer with source="*" correctly expands the
@@ -591,12 +643,15 @@ class ModelValidationTests(TestCase):
"""
Just check if serializers.ModelSerializer handles unique checks via .full_clean()
"""
- serializer = AlbumsSerializer(data={'title': 'a'})
+ serializer = AlbumsSerializer(data={'title': 'a', 'ref': '1'})
serializer.is_valid()
serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid())
- self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
+ self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],})
+ third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}])
+ self.assertFalse(third_serializer.is_valid())
+ self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
def test_foreign_key_is_null_with_partial(self):
"""
@@ -880,6 +935,58 @@ class DefaultValueTests(TestCase):
self.assertEqual(instance.text, 'overridden')
+class WritableFieldDefaultValueTests(TestCase):
+
+ def setUp(self):
+ self.expected = {'default': 'value'}
+ self.create_field = fields.WritableField
+
+ def test_get_default_value_with_noncallable(self):
+ field = self.create_field(default=self.expected)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_with_callable(self):
+ field = self.create_field(default=lambda : self.expected)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_when_not_required(self):
+ field = self.create_field(default=self.expected, required=False)
+ got = field.get_default_value()
+ self.assertEqual(got, self.expected)
+
+ def test_get_default_value_returns_None(self):
+ field = self.create_field()
+ got = field.get_default_value()
+ self.assertIsNone(got)
+
+ def test_get_default_value_returns_non_True_values(self):
+ values = [None, '', False, 0, [], (), {}] # values that assumed as 'False' in the 'if' clause
+ for expected in values:
+ field = self.create_field(default=expected)
+ got = field.get_default_value()
+ self.assertEqual(got, expected)
+
+
+class RelatedFieldDefaultValueTests(WritableFieldDefaultValueTests):
+
+ def setUp(self):
+ self.expected = {'foo': 'bar'}
+ self.create_field = relations.RelatedField
+
+ def test_get_default_value_returns_empty_list(self):
+ field = self.create_field(many=True)
+ got = field.get_default_value()
+ self.assertListEqual(got, [])
+
+ def test_get_default_value_returns_expected(self):
+ expected = [1, 2, 3]
+ field = self.create_field(many=True, default=expected)
+ got = field.get_default_value()
+ self.assertListEqual(got, expected)
+
+
class CallableDefaultValueTests(TestCase):
def setUp(self):
class CallableDefaultValueSerializer(serializers.ModelSerializer):
@@ -1493,18 +1600,10 @@ class ManyFieldHelpTextTest(TestCase):
self.assertEqual('Some help text.', rel_field.help_text)
+@unittest.skipUnless(PIL is not None, 'PIL is not installed')
class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def setUp(self):
- class AMOAFModel(RESTFrameworkModel):
- char_field = models.CharField(max_length=1024, blank=True)
- comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
- decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
- email_field = models.EmailField(max_length=1024, blank=True)
- file_field = models.FileField(max_length=1024, blank=True)
- image_field = models.ImageField(max_length=1024, blank=True)
- slug_field = models.SlugField(max_length=1024, blank=True)
- url_field = models.URLField(max_length=1024, blank=True)
class AMOAFSerializer(serializers.ModelSerializer):
class Meta:
@@ -1574,17 +1673,10 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
self.field_test('url_field')
+@unittest.skipUnless(PIL is not None, 'PIL is not installed')
class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
def setUp(self):
- class DVOAFModel(RESTFrameworkModel):
- positive_integer_field = models.PositiveIntegerField(blank=True)
- positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
- email_field = models.EmailField(blank=True)
- file_field = models.FileField(blank=True)
- image_field = models.ImageField(blank=True)
- slug_field = models.SlugField(blank=True)
- url_field = models.URLField(blank=True)
class DVOAFSerializer(serializers.ModelSerializer):
class Meta:
diff --git a/tests/test_testing.py b/tests/test_testing.py
index 8c6086a2..bd3e1329 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -152,3 +152,13 @@ class TestAPIRequestFactory(TestCase):
simple_png.name = 'test.png'
factory = APIRequestFactory()
factory.post('/', data={'image': simple_png})
+
+ def test_request_factory_url_arguments(self):
+ """
+ This is a non regression test against #1461
+ """
+ factory = APIRequestFactory()
+ request = factory.get('/view/?demo=test')
+ self.assertEqual(dict(request.GET), {'demo': ['test']})
+ request = factory.get('/view/', {'demo': 'test'})
+ self.assertEqual(dict(request.GET), {'demo': ['test']})
diff --git a/tests/test_urlizer.py b/tests/test_urlizer.py
new file mode 100644
index 00000000..3dc8e8fe
--- /dev/null
+++ b/tests/test_urlizer.py
@@ -0,0 +1,38 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.templatetags.rest_framework import urlize_quoted_links
+import sys
+
+
+class URLizerTests(TestCase):
+ """
+ Test if both JSON and YAML URLs are transformed into links well
+ """
+ def _urlize_dict_check(self, data):
+ """
+ For all items in dict test assert that the value is urlized key
+ """
+ for original, urlized in data.items():
+ assert urlize_quoted_links(original, nofollow=False) == urlized
+
+ def test_json_with_url(self):
+ """
+ Test if JSON URLs are transformed into links well
+ """
+ data = {}
+ data['"url": "http://api/users/1/", '] = \
+ '&quot;url&quot;: &quot;<a href="http://api/users/1/">http://api/users/1/</a>&quot;, '
+ data['"foo_set": [\n "http://api/foos/1/"\n], '] = \
+ '&quot;foo_set&quot;: [\n &quot;<a href="http://api/foos/1/">http://api/foos/1/</a>&quot;\n], '
+ self._urlize_dict_check(data)
+
+ def test_yaml_with_url(self):
+ """
+ Test if YAML URLs are transformed into links well
+ """
+ data = {}
+ data['''{users: 'http://api/users/'}'''] = \
+ '''{users: &#39;<a href="http://api/users/">http://api/users/</a>&#39;}'''
+ data['''foo_set: ['http://api/foos/1/']'''] = \
+ '''foo_set: [&#39;<a href="http://api/foos/1/">http://api/foos/1/</a>&#39;]'''
+ self._urlize_dict_check(data)
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 124c874d..e13e4078 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -1,4 +1,5 @@
from __future__ import unicode_literals
+from django.core.validators import MaxValueValidator
from django.db import models
from django.test import TestCase
from rest_framework import generics, serializers, status
@@ -102,3 +103,46 @@ class TestAvoidValidation(TestCase):
self.assertFalse(serializer.is_valid())
self.assertDictEqual(serializer.errors,
{'non_field_errors': ['Invalid data']})
+
+
+# regression tests for issue: 1493
+
+class ValidationMaxValueValidatorModel(models.Model):
+ number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)])
+
+
+class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationMaxValueValidatorModel
+
+
+class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ model = ValidationMaxValueValidatorModel
+ serializer_class = ValidationMaxValueValidatorModelSerializer
+
+
+class TestMaxValueValidatorValidation(TestCase):
+
+ def test_max_value_validation_serializer_success(self):
+ serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99})
+ self.assertTrue(serializer.is_valid())
+
+ def test_max_value_validation_serializer_fails(self):
+ serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101})
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors)
+
+ def test_max_value_validation_success(self):
+ obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
+ request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json')
+ view = UpdateMaxValueValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_max_value_validation_fail(self):
+ obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
+ request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json')
+ view = UpdateMaxValueValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.content, b'{"number_value": ["Ensure this value is less than or equal to 100."]}')
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..a8f2eb0b
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,25 @@
+from contextlib import contextmanager
+from rest_framework.compat import six
+from rest_framework.settings import api_settings
+
+
+@contextmanager
+def temporary_setting(setting, value, module=None):
+ """
+ Temporarily change value of setting for test.
+
+ Optionally reload given module, useful when module uses value of setting on
+ import.
+ """
+ original_value = getattr(api_settings, setting)
+ setattr(api_settings, setting, value)
+
+ if module is not None:
+ six.moves.reload_module(module)
+
+ yield
+
+ setattr(api_settings, setting, original_value)
+
+ if module is not None:
+ six.moves.reload_module(module)