aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py0
-rw-r--r--tests/browsable_api/__init__.py0
-rw-r--r--tests/browsable_api/auth_urls.py11
-rw-r--r--tests/browsable_api/no_auth_urls.py9
-rw-r--r--tests/browsable_api/test_browsable_api.py65
-rw-r--r--tests/browsable_api/views.py15
-rw-r--r--tests/conftest.py66
-rw-r--r--tests/description.py26
-rw-r--r--tests/models.py70
-rw-r--r--tests/test_authentication.py288
-rw-r--r--tests/test_bound_fields.py69
-rw-r--r--tests/test_decorators.py157
-rw-r--r--tests/test_description.py131
-rw-r--r--tests/test_fields.py1212
-rw-r--r--tests/test_filters.py823
-rw-r--r--tests/test_generics.py506
-rw-r--r--tests/test_htmlrenderer.py127
-rw-r--r--tests/test_metadata.py209
-rw-r--r--tests/test_middleware.py37
-rw-r--r--tests/test_model_serializer.py641
-rw-r--r--tests/test_multitable_inheritance.py67
-rw-r--r--tests/test_negotiation.py45
-rw-r--r--tests/test_pagination.py671
-rw-r--r--tests/test_parsers.py103
-rw-r--r--tests/test_permissions.py312
-rw-r--r--tests/test_relations.py169
-rw-r--r--tests/test_relations_generic.py104
-rw-r--r--tests/test_relations_hyperlink.py444
-rw-r--r--tests/test_relations_pk.py450
-rw-r--r--tests/test_relations_slug.py281
-rw-r--r--tests/test_renderers.py473
-rw-r--r--tests/test_request.py278
-rw-r--r--tests/test_response.py292
-rw-r--r--tests/test_reverse.py28
-rw-r--r--tests/test_routers.py348
-rw-r--r--tests/test_serializer.py297
-rw-r--r--tests/test_serializer_bulk_update.py123
-rw-r--r--tests/test_serializer_lists.py290
-rw-r--r--tests/test_serializer_nested.py40
-rw-r--r--tests/test_settings.py17
-rw-r--r--tests/test_status.py33
-rw-r--r--tests/test_templatetags.py75
-rw-r--r--tests/test_testing.py234
-rw-r--r--tests/test_throttling.py353
-rw-r--r--tests/test_urlpatterns.py76
-rw-r--r--tests/test_utils.py166
-rw-r--r--tests/test_validation.py183
-rw-r--r--tests/test_validators.py347
-rw-r--r--tests/test_versioning.py264
-rw-r--r--tests/test_views.py148
-rw-r--r--tests/test_viewsets.py35
-rw-r--r--tests/test_write_only_fields.py31
-rw-r--r--tests/urls.py6
-rw-r--r--tests/utils.py77
54 files changed, 11322 insertions, 0 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/browsable_api/__init__.py b/tests/browsable_api/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/tests/browsable_api/__init__.py
diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py
new file mode 100644
index 00000000..97bc1036
--- /dev/null
+++ b/tests/browsable_api/auth_urls.py
@@ -0,0 +1,11 @@
+from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
+
+from .views import MockView
+
+
+urlpatterns = patterns(
+ '',
+ (r'^$', MockView.as_view()),
+ url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
+)
diff --git a/tests/browsable_api/no_auth_urls.py b/tests/browsable_api/no_auth_urls.py
new file mode 100644
index 00000000..5e3604a6
--- /dev/null
+++ b/tests/browsable_api/no_auth_urls.py
@@ -0,0 +1,9 @@
+from __future__ import unicode_literals
+from django.conf.urls import patterns
+
+from .views import MockView
+
+urlpatterns = patterns(
+ '',
+ (r'^$', MockView.as_view()),
+)
diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py
new file mode 100644
index 00000000..5f264783
--- /dev/null
+++ b/tests/browsable_api/test_browsable_api.py
@@ -0,0 +1,65 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User
+from django.test import TestCase
+
+from rest_framework.test import APIClient
+
+
+class DropdownWithAuthTests(TestCase):
+ """Tests correct dropdown behaviour with Auth views enabled."""
+
+ urls = 'tests.browsable_api.auth_urls'
+
+ def setUp(self):
+ self.client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def tearDown(self):
+ self.client.logout()
+
+ def test_name_shown_when_logged_in(self):
+ self.client.login(username=self.username, password=self.password)
+ response = self.client.get('/')
+ self.assertContains(response, 'john')
+
+ def test_logout_shown_when_logged_in(self):
+ self.client.login(username=self.username, password=self.password)
+ response = self.client.get('/')
+ self.assertContains(response, '>Log out<')
+
+ def test_login_shown_when_logged_out(self):
+ response = self.client.get('/')
+ self.assertContains(response, '>Log in<')
+
+
+class NoDropdownWithoutAuthTests(TestCase):
+ """Tests correct dropdown behaviour with Auth views NOT enabled."""
+
+ urls = 'tests.browsable_api.no_auth_urls'
+
+ def setUp(self):
+ self.client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def tearDown(self):
+ self.client.logout()
+
+ def test_name_shown_when_logged_in(self):
+ self.client.login(username=self.username, password=self.password)
+ response = self.client.get('/')
+ self.assertContains(response, 'john')
+
+ def test_dropdown_not_shown_when_logged_in(self):
+ self.client.login(username=self.username, password=self.password)
+ response = self.client.get('/')
+ self.assertNotContains(response, '<li class="dropdown">')
+
+ def test_dropdown_not_shown_when_logged_out(self):
+ response = self.client.get('/')
+ self.assertNotContains(response, '<li class="dropdown">')
diff --git a/tests/browsable_api/views.py b/tests/browsable_api/views.py
new file mode 100644
index 00000000..000f4e80
--- /dev/null
+++ b/tests/browsable_api/views.py
@@ -0,0 +1,15 @@
+from __future__ import unicode_literals
+
+from rest_framework.views import APIView
+from rest_framework import authentication
+from rest_framework import renderers
+from rest_framework.response import Response
+
+
+class MockView(APIView):
+
+ authentication_classes = (authentication.SessionAuthentication,)
+ renderer_classes = (renderers.BrowsableAPIRenderer,)
+
+ def get(self, request):
+ return Response({'a': 1, 'b': 2, 'c': 3})
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..44ed070b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,66 @@
+def pytest_configure():
+ from django.conf import settings
+
+ settings.configure(
+ DEBUG_PROPAGATE_EXCEPTIONS=True,
+ DATABASES={'default': {'ENGINE': 'django.db.backends.sqlite3',
+ 'NAME': ':memory:'}},
+ SITE_ID=1,
+ SECRET_KEY='not very secret in tests',
+ USE_I18N=True,
+ USE_L10N=True,
+ STATIC_URL='/static/',
+ ROOT_URLCONF='tests.urls',
+ TEMPLATE_LOADERS=(
+ 'django.template.loaders.filesystem.Loader',
+ 'django.template.loaders.app_directories.Loader',
+ ),
+ MIDDLEWARE_CLASSES=(
+ 'django.middleware.common.CommonMiddleware',
+ 'django.contrib.sessions.middleware.SessionMiddleware',
+ 'django.middleware.csrf.CsrfViewMiddleware',
+ 'django.contrib.auth.middleware.AuthenticationMiddleware',
+ 'django.contrib.messages.middleware.MessageMiddleware',
+ ),
+ INSTALLED_APPS=(
+ 'django.contrib.auth',
+ 'django.contrib.contenttypes',
+ 'django.contrib.sessions',
+ 'django.contrib.sites',
+ 'django.contrib.messages',
+ 'django.contrib.staticfiles',
+
+ 'rest_framework',
+ 'rest_framework.authtoken',
+ 'tests',
+ ),
+ PASSWORD_HASHERS=(
+ 'django.contrib.auth.hashers.SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2PasswordHasher',
+ 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
+ 'django.contrib.auth.hashers.BCryptPasswordHasher',
+ 'django.contrib.auth.hashers.MD5PasswordHasher',
+ 'django.contrib.auth.hashers.CryptPasswordHasher',
+ ),
+ )
+
+ # guardian is optional
+ try:
+ import guardian # NOQA
+ except ImportError:
+ pass
+ else:
+ settings.ANONYMOUS_USER_ID = -1
+ settings.AUTHENTICATION_BACKENDS = (
+ 'django.contrib.auth.backends.ModelBackend',
+ 'guardian.backends.ObjectPermissionBackend',
+ )
+ settings.INSTALLED_APPS += (
+ 'guardian',
+ )
+
+ try:
+ import django
+ django.setup()
+ except AttributeError:
+ pass
diff --git a/tests/description.py b/tests/description.py
new file mode 100644
index 00000000..b46d7f54
--- /dev/null
+++ b/tests/description.py
@@ -0,0 +1,26 @@
+# -- coding: utf-8 --
+
+# Apparently there is a python 2.6 issue where docstrings of imported view classes
+# do not retain their encoding information even if a module has a proper
+# encoding declaration at the top of its source file. Therefore for tests
+# to catch unicode related errors, a mock view has to be declared in a separate
+# module.
+
+from rest_framework.views import APIView
+
+
+# test strings snatched from http://www.columbia.edu/~fdc/utf8/,
+# http://winrus.com/utf8-jap.htm and memory
+UTF8_TEST_DOCSTRING = (
+ 'zażółć gęślą jaźń'
+ 'Sîne klâwen durh die wolken sint geslagen'
+ 'Τη γλώσσα μου έδωσαν ελληνική'
+ 'யாமறிந்த மொழிகளிலே தமிழ்மொழி'
+ 'На берегу пустынных волн'
+ 'てすと'
+ 'アイウエオカキクケコサシスセソタチツテ'
+)
+
+
+class ViewWithNonASCIICharactersInDocstring(APIView):
+ __doc__ = UTF8_TEST_DOCSTRING
diff --git a/tests/models.py b/tests/models.py
new file mode 100644
index 00000000..456b0a0b
--- /dev/null
+++ b/tests/models.py
@@ -0,0 +1,70 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.utils.translation import ugettext_lazy as _
+
+
+class RESTFrameworkModel(models.Model):
+ """
+ Base for test models that sets app_label, so they play nicely.
+ """
+
+ class Meta:
+ app_label = 'tests'
+ abstract = True
+
+
+class BasicModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100, verbose_name=_("Text comes here"), help_text=_("Text description."))
+
+
+class BaseFilterableItem(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+
+ class Meta:
+ abstract = True
+
+
+class FilterableItem(BaseFilterableItem):
+ decimal = models.DecimalField(max_digits=4, decimal_places=2)
+ date = models.DateField()
+
+
+# Models for relations tests
+# ManyToMany
+class ManyToManyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+
+
+# ForeignKey
+class ForeignKeyTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class ForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, related_name='sources',
+ help_text='Target', verbose_name='Target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
+ related_name='nullable_sources',
+ verbose_name='Optional target object')
+
+
+# OneToOne
+class OneToOneTarget(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+
+
+class NullableOneToOneSource(RESTFrameworkModel):
+ name = models.CharField(max_length=100)
+ target = models.OneToOneField(OneToOneTarget, null=True, blank=True,
+ related_name='nullable_source')
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
new file mode 100644
index 00000000..91e49f9d
--- /dev/null
+++ b/tests/test_authentication.py
@@ -0,0 +1,288 @@
+from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
+from django.contrib.auth.models import User
+from django.http import HttpResponse
+from django.test import TestCase
+from django.utils import six
+from rest_framework import HTTP_HEADER_ENCODING
+from rest_framework import exceptions
+from rest_framework import permissions
+from rest_framework import renderers
+from rest_framework.response import Response
+from rest_framework import status
+from rest_framework.authentication import (
+ BaseAuthentication,
+ TokenAuthentication,
+ BasicAuthentication,
+ SessionAuthentication,
+)
+from rest_framework.authtoken.models import Token
+from rest_framework.test import APIRequestFactory, APIClient
+from rest_framework.views import APIView
+import base64
+
+factory = APIRequestFactory()
+
+
+class MockView(APIView):
+ permission_classes = (permissions.IsAuthenticated,)
+
+ def get(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def post(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+ def put(self, request):
+ return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+
+
+urlpatterns = patterns(
+ '',
+ (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
+ (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
+ (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
+ (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
+ url(r'^auth/', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+class BasicAuthTests(TestCase):
+ """Basic authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_post_form_passing_basic_auth(self):
+ """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_passing_basic_auth(self):
+ """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
+ credentials = ('%s:%s' % (self.username, self.password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ auth = 'Basic %s' % base64_credentials
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_failing_basic_auth(self):
+ """Ensure POSTing form over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/basic/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_post_json_failing_basic_auth(self):
+ """Ensure POSTing json over basic auth without correct credentials fails"""
+ response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+ self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
+
+
+class SessionAuthTests(TestCase):
+ """User session authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.non_csrf_client = APIClient(enforce_csrf_checks=False)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def tearDown(self):
+ self.csrf_client.logout()
+
+ def test_login_view_renders_on_get(self):
+ """
+ Ensure the login template renders for a basic GET.
+
+ cf. [#1810](https://github.com/tomchristie/django-rest-framework/pull/1810)
+ """
+ response = self.csrf_client.get('/auth/login/')
+ self.assertContains(response, '<label for="id_username">Username:</label>')
+
+ def test_post_form_session_auth_failing_csrf(self):
+ """
+ Ensure POSTing form over session authentication without CSRF token fails.
+ """
+ self.csrf_client.login(username=self.username, password=self.password)
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_post_form_session_auth_passing(self):
+ """
+ Ensure POSTing form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_put_form_session_auth_passing(self):
+ """
+ Ensure PUTting form over session authentication with logged in user and CSRF token passes.
+ """
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.put('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_form_session_auth_failing(self):
+ """
+ Ensure POSTing form over session authentication without logged in user fails.
+ """
+ response = self.csrf_client.post('/session/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+
+class TokenAuthTests(TestCase):
+ """Token authentication"""
+ urls = 'tests.test_authentication'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ self.key = 'abcd1234'
+ self.token = Token.objects.create(key=self.key, user=self.user)
+
+ def test_post_form_passing_token_auth(self):
+ """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
+ auth = 'Token ' + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_passing_token_auth(self):
+ """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
+ auth = "Token " + self.key
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_post_json_makes_one_db_query(self):
+ """Ensure that authenticating a user using a token performs only one DB query"""
+ auth = "Token " + self.key
+
+ def func_to_test():
+ return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
+
+ self.assertNumQueries(1, func_to_test)
+
+ def test_post_form_failing_token_auth(self):
+ """Ensure POSTing form over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/token/', {'example': 'example'})
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_post_json_failing_token_auth(self):
+ """Ensure POSTing json over token auth without correct credentials fails"""
+ response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
+
+ def test_token_has_auto_assigned_key_if_none_provided(self):
+ """Ensure creating a token with no key will auto-assign a key"""
+ self.token.delete()
+ token = Token.objects.create(user=self.user)
+ self.assertTrue(bool(token.key))
+
+ def test_generate_key_returns_string(self):
+ """Ensure generate_key returns a string"""
+ token = Token()
+ key = token.generate_key()
+ self.assertTrue(isinstance(key, six.string_types))
+
+ def test_token_login_json(self):
+ """Ensure token login view using JSON POST works."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password}, format='json')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['token'], self.key)
+
+ def test_token_login_json_bad_creds(self):
+ """Ensure token login view using JSON POST fails if bad credentials are used."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': "badpass"}, format='json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_json_missing_fields(self):
+ """Ensure token login view using JSON POST fails if missing fields."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username}, format='json')
+ self.assertEqual(response.status_code, 400)
+
+ def test_token_login_form(self):
+ """Ensure token login view using form POST works."""
+ client = APIClient(enforce_csrf_checks=True)
+ response = client.post('/auth-token/',
+ {'username': self.username, 'password': self.password})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data['token'], self.key)
+
+
+class IncorrectCredentialsTests(TestCase):
+ def test_incorrect_credentials(self):
+ """
+ If a request contains bad authentication credentials, then
+ authentication should run and error, even if no permissions
+ are set on the view.
+ """
+ class IncorrectCredentialsAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('Bad credentials')
+
+ request = factory.get('/')
+ view = MockView.as_view(
+ authentication_classes=(IncorrectCredentialsAuth,),
+ permission_classes=()
+ )
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.data, {'detail': 'Bad credentials'})
+
+
+class FailingAuthAccessedInRenderer(TestCase):
+ def setUp(self):
+ class AuthAccessingRenderer(renderers.BaseRenderer):
+ media_type = 'text/plain'
+ format = 'txt'
+
+ def render(self, data, media_type=None, renderer_context=None):
+ request = renderer_context['request']
+ if request.user.is_authenticated():
+ return b'authenticated'
+ return b'not authenticated'
+
+ class FailingAuth(BaseAuthentication):
+ def authenticate(self, request):
+ raise exceptions.AuthenticationFailed('authentication failed')
+
+ class ExampleView(APIView):
+ authentication_classes = (FailingAuth,)
+ renderer_classes = (AuthAccessingRenderer,)
+
+ def get(self, request):
+ return Response({'foo': 'bar'})
+
+ self.view = ExampleView.as_view()
+
+ def test_failing_auth_accessed_in_renderer(self):
+ """
+ When authentication fails the renderer should still be able to access
+ `request.user` without raising an exception. Particularly relevant
+ to HTML responses that might reasonably access `request.user`.
+ """
+ request = factory.get('/')
+ response = self.view(request)
+ content = response.render().content
+ self.assertEqual(content, b'not authenticated')
diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py
new file mode 100644
index 00000000..bfc54b23
--- /dev/null
+++ b/tests/test_bound_fields.py
@@ -0,0 +1,69 @@
+from rest_framework import serializers
+
+
+class TestSimpleBoundField:
+ def test_empty_bound_field(self):
+ class ExampleSerializer(serializers.Serializer):
+ text = serializers.CharField(max_length=100)
+ amount = serializers.IntegerField()
+
+ serializer = ExampleSerializer()
+
+ assert serializer['text'].value == ''
+ assert serializer['text'].errors is None
+ assert serializer['text'].name == 'text'
+ assert serializer['amount'].value is None
+ assert serializer['amount'].errors is None
+ assert serializer['amount'].name == 'amount'
+
+ def test_populated_bound_field(self):
+ class ExampleSerializer(serializers.Serializer):
+ text = serializers.CharField(max_length=100)
+ amount = serializers.IntegerField()
+
+ serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123})
+ assert serializer.is_valid()
+ assert serializer['text'].value == 'abc'
+ assert serializer['text'].errors is None
+ assert serializer['text'].name == 'text'
+ assert serializer['amount'].value is 123
+ assert serializer['amount'].errors is None
+ assert serializer['amount'].name == 'amount'
+
+ def test_error_bound_field(self):
+ class ExampleSerializer(serializers.Serializer):
+ text = serializers.CharField(max_length=100)
+ amount = serializers.IntegerField()
+
+ serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123})
+ serializer.is_valid()
+
+ assert serializer['text'].value == 'x' * 1000
+ assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.']
+ assert serializer['text'].name == 'text'
+ assert serializer['amount'].value is 123
+ assert serializer['amount'].errors is None
+ assert serializer['amount'].name == 'amount'
+
+
+class TestNestedBoundField:
+ def test_nested_empty_bound_field(self):
+ class Nested(serializers.Serializer):
+ more_text = serializers.CharField(max_length=100)
+ amount = serializers.IntegerField()
+
+ class ExampleSerializer(serializers.Serializer):
+ text = serializers.CharField(max_length=100)
+ nested = Nested()
+
+ serializer = ExampleSerializer()
+
+ assert serializer['text'].value == ''
+ assert serializer['text'].errors is None
+ assert serializer['text'].name == 'text'
+ assert serializer['nested']['more_text'].value == ''
+ assert serializer['nested']['more_text'].errors is None
+ assert serializer['nested']['more_text'].name == 'nested.more_text'
+ assert serializer['nested']['amount'].value is None
+ assert serializer['nested']['amount'].errors is None
+ assert serializer['nested']['amount'].name == 'nested.amount'
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
new file mode 100644
index 00000000..195f0ba3
--- /dev/null
+++ b/tests/test_decorators.py
@@ -0,0 +1,157 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.authentication import BasicAuthentication
+from rest_framework.parsers import JSONParser
+from rest_framework.permissions import IsAuthenticated
+from rest_framework.response import Response
+from rest_framework.renderers import JSONRenderer
+from rest_framework.test import APIRequestFactory
+from rest_framework.throttling import UserRateThrottle
+from rest_framework.views import APIView
+from rest_framework.decorators import (
+ api_view,
+ renderer_classes,
+ parser_classes,
+ authentication_classes,
+ throttle_classes,
+ permission_classes,
+)
+
+
+class DecoratorTestCase(TestCase):
+
+ def setUp(self):
+ self.factory = APIRequestFactory()
+
+ def _finalize_response(self, request, response, *args, **kwargs):
+ response.request = request
+ return APIView.finalize_response(self, request, response, *args, **kwargs)
+
+ def test_api_view_incorrect(self):
+ """
+ If @api_view is not applied correct, we should raise an assertion.
+ """
+
+ @api_view
+ def view(request):
+ return Response()
+
+ request = self.factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ def test_api_view_incorrect_arguments(self):
+ """
+ If @api_view is missing arguments, we should raise an assertion.
+ """
+
+ with self.assertRaises(AssertionError):
+ @api_view('GET')
+ def view(request):
+ return Response()
+
+ def test_calling_method(self):
+
+ @api_view(['GET'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_calling_put_method(self):
+
+ @api_view(['GET', 'PUT'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.put('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_calling_patch_method(self):
+
+ @api_view(['GET', 'PATCH'])
+ def view(request):
+ return Response({})
+
+ request = self.factory.patch('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ request = self.factory.post('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+
+ def test_renderer_classes(self):
+
+ @api_view(['GET'])
+ @renderer_classes([JSONRenderer])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertTrue(isinstance(response.accepted_renderer, JSONRenderer))
+
+ def test_parser_classes(self):
+
+ @api_view(['GET'])
+ @parser_classes([JSONParser])
+ def view(request):
+ self.assertEqual(len(request.parsers), 1)
+ self.assertTrue(isinstance(request.parsers[0],
+ JSONParser))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_authentication_classes(self):
+
+ @api_view(['GET'])
+ @authentication_classes([BasicAuthentication])
+ def view(request):
+ self.assertEqual(len(request.authenticators), 1)
+ self.assertTrue(isinstance(request.authenticators[0],
+ BasicAuthentication))
+ return Response({})
+
+ request = self.factory.get('/')
+ view(request)
+
+ def test_permission_classes(self):
+
+ @api_view(['GET'])
+ @permission_classes([IsAuthenticated])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_throttle_classes(self):
+ class OncePerDayUserThrottle(UserRateThrottle):
+ rate = '1/day'
+
+ @api_view(['GET'])
+ @throttle_classes([OncePerDayUserThrottle])
+ def view(request):
+ return Response({})
+
+ request = self.factory.get('/')
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ response = view(request)
+ self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)
diff --git a/tests/test_description.py b/tests/test_description.py
new file mode 100644
index 00000000..78ce2350
--- /dev/null
+++ b/tests/test_description.py
@@ -0,0 +1,131 @@
+# -- coding: utf-8 --
+
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.utils.encoding import python_2_unicode_compatible, smart_text
+from rest_framework.compat import apply_markdown
+from rest_framework.views import APIView
+from .description import ViewWithNonASCIICharactersInDocstring
+from .description import UTF8_TEST_DOCSTRING
+
+# We check that docstrings get nicely un-indented.
+DESCRIPTION = """an example docstring
+====================
+
+* list
+* list
+
+another header
+--------------
+
+ code block
+
+indented
+
+# hash style header #"""
+
+# If markdown is installed we also test it's working
+# (and that our wrapped forces '=' to h2 and '-' to h3)
+
+# We support markdown < 2.1 and markdown >= 2.1
+MARKED_DOWN_lt_21 = """<h2>an example docstring</h2>
+<ul>
+<li>list</li>
+<li>list</li>
+</ul>
+<h3>another header</h3>
+<pre><code>code block
+</code></pre>
+<p>indented</p>
+<h2 id="hash_style_header">hash style header</h2>"""
+
+MARKED_DOWN_gte_21 = """<h2 id="an-example-docstring">an example docstring</h2>
+<ul>
+<li>list</li>
+<li>list</li>
+</ul>
+<h3 id="another-header">another header</h3>
+<pre><code>code block
+</code></pre>
+<p>indented</p>
+<h2 id="hash-style-header">hash style header</h2>"""
+
+
+class TestViewNamesAndDescriptions(TestCase):
+ def test_view_name_uses_class_name(self):
+ """
+ Ensure view names are based on the class name.
+ """
+ class MockView(APIView):
+ pass
+ self.assertEqual(MockView().get_view_name(), 'Mock')
+
+ def test_view_description_uses_docstring(self):
+ """Ensure view descriptions are based on the docstring."""
+ class MockView(APIView):
+ """an example docstring
+ ====================
+
+ * list
+ * list
+
+ another header
+ --------------
+
+ code block
+
+ indented
+
+ # hash style header #"""
+
+ self.assertEqual(MockView().get_view_description(), DESCRIPTION)
+
+ def test_view_description_supports_unicode(self):
+ """
+ Unicode in docstrings should be respected.
+ """
+
+ self.assertEqual(
+ ViewWithNonASCIICharactersInDocstring().get_view_description(),
+ smart_text(UTF8_TEST_DOCSTRING)
+ )
+
+ def test_view_description_can_be_empty(self):
+ """
+ Ensure that if a view has no docstring,
+ then it's description is the empty string.
+ """
+ class MockView(APIView):
+ pass
+ self.assertEqual(MockView().get_view_description(), '')
+
+ def test_view_description_can_be_promise(self):
+ """
+ Ensure a view may have a docstring that is actually a lazily evaluated
+ class that can be converted to a string.
+
+ See: https://github.com/tomchristie/django-rest-framework/issues/1708
+ """
+ # use a mock object instead of gettext_lazy to ensure that we can't end
+ # up with a test case string in our l10n catalog
+ @python_2_unicode_compatible
+ class MockLazyStr(object):
+ def __init__(self, string):
+ self.s = string
+
+ def __str__(self):
+ return self.s
+
+ class MockView(APIView):
+ __doc__ = MockLazyStr("a gettext string")
+
+ self.assertEqual(MockView().get_view_description(), 'a gettext string')
+
+ def test_markdown(self):
+ """
+ Ensure markdown to HTML works as expected.
+ """
+ if apply_markdown:
+ gte_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_gte_21
+ lt_21_match = apply_markdown(DESCRIPTION) == MARKED_DOWN_lt_21
+ self.assertTrue(gte_21_match or lt_21_match)
diff --git a/tests/test_fields.py b/tests/test_fields.py
new file mode 100644
index 00000000..1aa528da
--- /dev/null
+++ b/tests/test_fields.py
@@ -0,0 +1,1212 @@
+from decimal import Decimal
+from django.utils import timezone
+from rest_framework import serializers
+import datetime
+import django
+import pytest
+import uuid
+
+
+# Tests for field keyword arguments and core functionality.
+# ---------------------------------------------------------
+
+class TestEmpty:
+ """
+ Tests for `required`, `allow_null`, `allow_blank`, `default`.
+ """
+ def test_required(self):
+ """
+ By default a field must be included in the input.
+ """
+ field = serializers.IntegerField()
+ with pytest.raises(serializers.ValidationError) as exc_info:
+ field.run_validation()
+ assert exc_info.value.detail == ['This field is required.']
+
+ def test_not_required(self):
+ """
+ If `required=False` then a field may be omitted from the input.
+ """
+ field = serializers.IntegerField(required=False)
+ with pytest.raises(serializers.SkipField):
+ field.run_validation()
+
+ def test_disallow_null(self):
+ """
+ By default `None` is not a valid input.
+ """
+ field = serializers.IntegerField()
+ with pytest.raises(serializers.ValidationError) as exc_info:
+ field.run_validation(None)
+ assert exc_info.value.detail == ['This field may not be null.']
+
+ def test_allow_null(self):
+ """
+ If `allow_null=True` then `None` is a valid input.
+ """
+ field = serializers.IntegerField(allow_null=True)
+ output = field.run_validation(None)
+ assert output is None
+
+ def test_disallow_blank(self):
+ """
+ By default '' is not a valid input.
+ """
+ field = serializers.CharField()
+ with pytest.raises(serializers.ValidationError) as exc_info:
+ field.run_validation('')
+ assert exc_info.value.detail == ['This field may not be blank.']
+
+ def test_allow_blank(self):
+ """
+ If `allow_blank=True` then '' is a valid input.
+ """
+ field = serializers.CharField(allow_blank=True)
+ output = field.run_validation('')
+ assert output == ''
+
+ def test_default(self):
+ """
+ If `default` is set, then omitted values get the default input.
+ """
+ field = serializers.IntegerField(default=123)
+ output = field.run_validation()
+ assert output is 123
+
+
+class TestSource:
+ def test_source(self):
+ class ExampleSerializer(serializers.Serializer):
+ example_field = serializers.CharField(source='other')
+ serializer = ExampleSerializer(data={'example_field': 'abc'})
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'other': 'abc'}
+
+ def test_redundant_source(self):
+ class ExampleSerializer(serializers.Serializer):
+ example_field = serializers.CharField(source='example_field')
+ with pytest.raises(AssertionError) as exc_info:
+ ExampleSerializer().fields
+ assert str(exc_info.value) == (
+ "It is redundant to specify `source='example_field'` on field "
+ "'CharField' in serializer 'ExampleSerializer', because it is the "
+ "same as the field name. Remove the `source` keyword argument."
+ )
+
+ def test_callable_source(self):
+ class ExampleSerializer(serializers.Serializer):
+ example_field = serializers.CharField(source='example_callable')
+
+ class ExampleInstance(object):
+ def example_callable(self):
+ return 'example callable value'
+
+ serializer = ExampleSerializer(ExampleInstance())
+ assert serializer.data['example_field'] == 'example callable value'
+
+ def test_callable_source_raises(self):
+ class ExampleSerializer(serializers.Serializer):
+ example_field = serializers.CharField(source='example_callable', read_only=True)
+
+ class ExampleInstance(object):
+ def example_callable(self):
+ raise AttributeError('method call failed')
+
+ with pytest.raises(ValueError) as exc_info:
+ serializer = ExampleSerializer(ExampleInstance())
+ serializer.data.items()
+
+ assert 'method call failed' in str(exc_info.value)
+
+
+class TestReadOnly:
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ read_only = serializers.ReadOnlyField()
+ writable = serializers.IntegerField()
+ self.Serializer = TestSerializer
+
+ def test_validate_read_only(self):
+ """
+ Read-only serializers.should not be included in validation.
+ """
+ data = {'read_only': 123, 'writable': 456}
+ serializer = self.Serializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'writable': 456}
+
+ def test_serialize_read_only(self):
+ """
+ Read-only serializers.should be serialized.
+ """
+ instance = {'read_only': 123, 'writable': 456}
+ serializer = self.Serializer(instance)
+ assert serializer.data == {'read_only': 123, 'writable': 456}
+
+
+class TestWriteOnly:
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ write_only = serializers.IntegerField(write_only=True)
+ readable = serializers.IntegerField()
+ self.Serializer = TestSerializer
+
+ def test_validate_write_only(self):
+ """
+ Write-only serializers.should be included in validation.
+ """
+ data = {'write_only': 123, 'readable': 456}
+ serializer = self.Serializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'write_only': 123, 'readable': 456}
+
+ def test_serialize_write_only(self):
+ """
+ Write-only serializers.should not be serialized.
+ """
+ instance = {'write_only': 123, 'readable': 456}
+ serializer = self.Serializer(instance)
+ assert serializer.data == {'readable': 456}
+
+
+class TestInitial:
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ initial_field = serializers.IntegerField(initial=123)
+ blank_field = serializers.IntegerField()
+ self.serializer = TestSerializer()
+
+ def test_initial(self):
+ """
+ Initial values should be included when serializing a new representation.
+ """
+ assert self.serializer.data == {
+ 'initial_field': 123,
+ 'blank_field': None
+ }
+
+
+class TestLabel:
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ labeled = serializers.IntegerField(label='My label')
+ self.serializer = TestSerializer()
+
+ def test_label(self):
+ """
+ A field's label may be set with the `label` argument.
+ """
+ fields = self.serializer.fields
+ assert fields['labeled'].label == 'My label'
+
+
+class TestInvalidErrorKey:
+ def setup(self):
+ class ExampleField(serializers.Field):
+ def to_native(self, data):
+ self.fail('incorrect')
+ self.field = ExampleField()
+
+ def test_invalid_error_key(self):
+ """
+ If a field raises a validation error, but does not have a corresponding
+ error message, then raise an appropriate assertion error.
+ """
+ with pytest.raises(AssertionError) as exc_info:
+ self.field.to_native(123)
+ expected = (
+ 'ValidationError raised by `ExampleField`, but error key '
+ '`incorrect` does not exist in the `error_messages` dictionary.'
+ )
+ assert str(exc_info.value) == expected
+
+
+class TestBooleanHTMLInput:
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ archived = serializers.BooleanField()
+ self.Serializer = TestSerializer
+
+ def test_empty_html_checkbox(self):
+ """
+ HTML checkboxes do not send any value, but should be treated
+ as `False` by BooleanField.
+ """
+ # This class mocks up a dictionary like object, that behaves
+ # as if it was returned for multipart or urlencoded data.
+ class MockHTMLDict(dict):
+ getlist = None
+ serializer = self.Serializer(data=MockHTMLDict())
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'archived': False}
+
+
+class MockHTMLDict(dict):
+ """
+ This class mocks up a dictionary like object, that behaves
+ as if it was returned for multipart or urlencoded data.
+ """
+ getlist = None
+
+
+class TestHTMLInput:
+ def test_empty_html_charfield(self):
+ class TestSerializer(serializers.Serializer):
+ message = serializers.CharField(default='happy')
+
+ serializer = TestSerializer(data=MockHTMLDict())
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'message': 'happy'}
+
+ def test_empty_html_charfield_allow_null(self):
+ class TestSerializer(serializers.Serializer):
+ message = serializers.CharField(allow_null=True)
+
+ serializer = TestSerializer(data=MockHTMLDict({'message': ''}))
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'message': None}
+
+ def test_empty_html_datefield_allow_null(self):
+ class TestSerializer(serializers.Serializer):
+ expiry = serializers.DateField(allow_null=True)
+
+ serializer = TestSerializer(data=MockHTMLDict({'expiry': ''}))
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'expiry': None}
+
+ def test_empty_html_charfield_allow_null_allow_blank(self):
+ class TestSerializer(serializers.Serializer):
+ message = serializers.CharField(allow_null=True, allow_blank=True)
+
+ serializer = TestSerializer(data=MockHTMLDict({'message': ''}))
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'message': ''}
+
+ def test_empty_html_charfield_required_false(self):
+ class TestSerializer(serializers.Serializer):
+ message = serializers.CharField(required=False)
+
+ serializer = TestSerializer(data=MockHTMLDict())
+ assert serializer.is_valid()
+ assert serializer.validated_data == {}
+
+
+class TestCreateOnlyDefault:
+ def setup(self):
+ default = serializers.CreateOnlyDefault('2001-01-01')
+
+ class TestSerializer(serializers.Serializer):
+ published = serializers.HiddenField(default=default)
+ text = serializers.CharField()
+ self.Serializer = TestSerializer
+
+ def test_create_only_default_is_provided(self):
+ serializer = self.Serializer(data={'text': 'example'})
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'text': 'example', 'published': '2001-01-01'
+ }
+
+ def test_create_only_default_is_not_provided_on_update(self):
+ instance = {
+ 'text': 'example', 'published': '2001-01-01'
+ }
+ serializer = self.Serializer(instance, data={'text': 'example'})
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'text': 'example',
+ }
+
+ def test_create_only_default_callable_sets_context(self):
+ """
+ CreateOnlyDefault instances with a callable default should set_context
+ on the callable if possible
+ """
+ class TestCallableDefault:
+ def set_context(self, serializer_field):
+ self.field = serializer_field
+
+ def __call__(self):
+ return "success" if hasattr(self, 'field') else "failure"
+
+ class TestSerializer(serializers.Serializer):
+ context_set = serializers.CharField(default=serializers.CreateOnlyDefault(TestCallableDefault()))
+
+ serializer = TestSerializer(data={})
+ assert serializer.is_valid()
+ assert serializer.validated_data['context_set'] == 'success'
+
+
+# Tests for field input and output values.
+# ----------------------------------------
+
+def get_items(mapping_or_list_of_two_tuples):
+ # Tests accept either lists of two tuples, or dictionaries.
+ if isinstance(mapping_or_list_of_two_tuples, dict):
+ # {value: expected}
+ return mapping_or_list_of_two_tuples.items()
+ # [(value, expected), ...]
+ return mapping_or_list_of_two_tuples
+
+
+class FieldValues:
+ """
+ Base class for testing valid and invalid input values.
+ """
+ def test_valid_inputs(self):
+ """
+ Ensure that valid values return the expected validated data.
+ """
+ for input_value, expected_output in get_items(self.valid_inputs):
+ assert self.field.run_validation(input_value) == expected_output
+
+ def test_invalid_inputs(self):
+ """
+ Ensure that invalid values raise the expected validation error.
+ """
+ for input_value, expected_failure in get_items(self.invalid_inputs):
+ with pytest.raises(serializers.ValidationError) as exc_info:
+ self.field.run_validation(input_value)
+ assert exc_info.value.detail == expected_failure
+
+ def test_outputs(self):
+ for output_value, expected_output in get_items(self.outputs):
+ assert self.field.to_representation(output_value) == expected_output
+
+
+# Boolean types...
+
+class TestBooleanField(FieldValues):
+ """
+ Valid and invalid values for `BooleanField`.
+ """
+ valid_inputs = {
+ 'true': True,
+ 'false': False,
+ '1': True,
+ '0': False,
+ 1: True,
+ 0: False,
+ True: True,
+ False: False,
+ }
+ invalid_inputs = {
+ 'foo': ['"foo" is not a valid boolean.'],
+ None: ['This field may not be null.']
+ }
+ outputs = {
+ 'true': True,
+ 'false': False,
+ '1': True,
+ '0': False,
+ 1: True,
+ 0: False,
+ True: True,
+ False: False,
+ 'other': True
+ }
+ field = serializers.BooleanField()
+
+
+class TestNullBooleanField(FieldValues):
+ """
+ Valid and invalid values for `BooleanField`.
+ """
+ valid_inputs = {
+ 'true': True,
+ 'false': False,
+ 'null': None,
+ True: True,
+ False: False,
+ None: None
+ }
+ invalid_inputs = {
+ 'foo': ['"foo" is not a valid boolean.'],
+ }
+ outputs = {
+ 'true': True,
+ 'false': False,
+ 'null': None,
+ True: True,
+ False: False,
+ None: None,
+ 'other': True
+ }
+ field = serializers.NullBooleanField()
+
+
+# String types...
+
+class TestCharField(FieldValues):
+ """
+ Valid and invalid values for `CharField`.
+ """
+ valid_inputs = {
+ 1: '1',
+ 'abc': 'abc'
+ }
+ invalid_inputs = {
+ '': ['This field may not be blank.']
+ }
+ outputs = {
+ 1: '1',
+ 'abc': 'abc'
+ }
+ field = serializers.CharField()
+
+ def test_trim_whitespace_default(self):
+ field = serializers.CharField()
+ assert field.to_internal_value(' abc ') == 'abc'
+
+ def test_trim_whitespace_disabled(self):
+ field = serializers.CharField(trim_whitespace=False)
+ assert field.to_internal_value(' abc ') == ' abc '
+
+
+class TestEmailField(FieldValues):
+ """
+ Valid and invalid values for `EmailField`.
+ """
+ valid_inputs = {
+ 'example@example.com': 'example@example.com',
+ ' example@example.com ': 'example@example.com',
+ }
+ invalid_inputs = {
+ 'examplecom': ['Enter a valid email address.']
+ }
+ outputs = {}
+ field = serializers.EmailField()
+
+
+class TestRegexField(FieldValues):
+ """
+ Valid and invalid values for `RegexField`.
+ """
+ valid_inputs = {
+ 'a9': 'a9',
+ }
+ invalid_inputs = {
+ 'A9': ["This value does not match the required pattern."]
+ }
+ outputs = {}
+ field = serializers.RegexField(regex='[a-z][0-9]')
+
+
+class TestSlugField(FieldValues):
+ """
+ Valid and invalid values for `SlugField`.
+ """
+ valid_inputs = {
+ 'slug-99': 'slug-99',
+ }
+ invalid_inputs = {
+ 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.']
+ }
+ outputs = {}
+ field = serializers.SlugField()
+
+
+class TestURLField(FieldValues):
+ """
+ Valid and invalid values for `URLField`.
+ """
+ valid_inputs = {
+ 'http://example.com': 'http://example.com',
+ }
+ invalid_inputs = {
+ 'example.com': ['Enter a valid URL.']
+ }
+ outputs = {}
+ field = serializers.URLField()
+
+
+class TestUUIDField(FieldValues):
+ """
+ Valid and invalid values for `UUIDField`.
+ """
+ valid_inputs = {
+ '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'),
+ '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda')
+ }
+ invalid_inputs = {
+ '825d7aeb-05a9-45b5-a5b7': ['"825d7aeb-05a9-45b5-a5b7" is not a valid UUID.']
+ }
+ outputs = {
+ uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda'
+ }
+ field = serializers.UUIDField()
+
+
+# Number types...
+
+class TestIntegerField(FieldValues):
+ """
+ Valid and invalid values for `IntegerField`.
+ """
+ valid_inputs = {
+ '1': 1,
+ '0': 0,
+ 1: 1,
+ 0: 0,
+ 1.0: 1,
+ 0.0: 0
+ }
+ invalid_inputs = {
+ 'abc': ['A valid integer is required.']
+ }
+ outputs = {
+ '1': 1,
+ '0': 0,
+ 1: 1,
+ 0: 0,
+ 1.0: 1,
+ 0.0: 0
+ }
+ field = serializers.IntegerField()
+
+
+class TestMinMaxIntegerField(FieldValues):
+ """
+ Valid and invalid values for `IntegerField` with min and max limits.
+ """
+ valid_inputs = {
+ '1': 1,
+ '3': 3,
+ 1: 1,
+ 3: 3,
+ }
+ invalid_inputs = {
+ 0: ['Ensure this value is greater than or equal to 1.'],
+ 4: ['Ensure this value is less than or equal to 3.'],
+ '0': ['Ensure this value is greater than or equal to 1.'],
+ '4': ['Ensure this value is less than or equal to 3.'],
+ }
+ outputs = {}
+ field = serializers.IntegerField(min_value=1, max_value=3)
+
+
+class TestFloatField(FieldValues):
+ """
+ Valid and invalid values for `FloatField`.
+ """
+ valid_inputs = {
+ '1': 1.0,
+ '0': 0.0,
+ 1: 1.0,
+ 0: 0.0,
+ 1.0: 1.0,
+ 0.0: 0.0,
+ }
+ invalid_inputs = {
+ 'abc': ["A valid number is required."]
+ }
+ outputs = {
+ '1': 1.0,
+ '0': 0.0,
+ 1: 1.0,
+ 0: 0.0,
+ 1.0: 1.0,
+ 0.0: 0.0,
+ }
+ field = serializers.FloatField()
+
+
+class TestMinMaxFloatField(FieldValues):
+ """
+ Valid and invalid values for `FloatField` with min and max limits.
+ """
+ valid_inputs = {
+ '1': 1,
+ '3': 3,
+ 1: 1,
+ 3: 3,
+ 1.0: 1.0,
+ 3.0: 3.0,
+ }
+ invalid_inputs = {
+ 0.9: ['Ensure this value is greater than or equal to 1.'],
+ 3.1: ['Ensure this value is less than or equal to 3.'],
+ '0.0': ['Ensure this value is greater than or equal to 1.'],
+ '3.1': ['Ensure this value is less than or equal to 3.'],
+ }
+ outputs = {}
+ field = serializers.FloatField(min_value=1, max_value=3)
+
+
+class TestDecimalField(FieldValues):
+ """
+ Valid and invalid values for `DecimalField`.
+ """
+ valid_inputs = {
+ '12.3': Decimal('12.3'),
+ '0.1': Decimal('0.1'),
+ 10: Decimal('10'),
+ 0: Decimal('0'),
+ 12.3: Decimal('12.3'),
+ 0.1: Decimal('0.1'),
+ }
+ invalid_inputs = (
+ ('abc', ["A valid number is required."]),
+ (Decimal('Nan'), ["A valid number is required."]),
+ (Decimal('Inf'), ["A valid number is required."]),
+ ('12.345', ["Ensure that there are no more than 3 digits in total."]),
+ ('0.01', ["Ensure that there are no more than 1 decimal places."]),
+ (123, ["Ensure that there are no more than 2 digits before the decimal point."])
+ )
+ outputs = {
+ '1': '1.0',
+ '0': '0.0',
+ '1.09': '1.1',
+ '0.04': '0.0',
+ 1: '1.0',
+ 0: '0.0',
+ Decimal('1.0'): '1.0',
+ Decimal('0.0'): '0.0',
+ Decimal('1.09'): '1.1',
+ Decimal('0.04'): '0.0'
+ }
+ field = serializers.DecimalField(max_digits=3, decimal_places=1)
+
+
+class TestMinMaxDecimalField(FieldValues):
+ """
+ Valid and invalid values for `DecimalField` with min and max limits.
+ """
+ valid_inputs = {
+ '10.0': Decimal('10.0'),
+ '20.0': Decimal('20.0'),
+ }
+ invalid_inputs = {
+ '9.9': ['Ensure this value is greater than or equal to 10.'],
+ '20.1': ['Ensure this value is less than or equal to 20.'],
+ }
+ outputs = {}
+ field = serializers.DecimalField(
+ max_digits=3, decimal_places=1,
+ min_value=10, max_value=20
+ )
+
+
+class TestNoStringCoercionDecimalField(FieldValues):
+ """
+ Output values for `DecimalField` with `coerce_to_string=False`.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ 1.09: Decimal('1.1'),
+ 0.04: Decimal('0.0'),
+ '1.09': Decimal('1.1'),
+ '0.04': Decimal('0.0'),
+ Decimal('1.09'): Decimal('1.1'),
+ Decimal('0.04'): Decimal('0.0'),
+ }
+ field = serializers.DecimalField(
+ max_digits=3, decimal_places=1,
+ coerce_to_string=False
+ )
+
+
+# Date & time serializers...
+
+class TestDateField(FieldValues):
+ """
+ Valid and invalid values for `DateField`.
+ """
+ valid_inputs = {
+ '2001-01-01': datetime.date(2001, 1, 1),
+ datetime.date(2001, 1, 1): datetime.date(2001, 1, 1),
+ }
+ invalid_inputs = {
+ 'abc': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],
+ '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY[-MM[-DD]].'],
+ datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'],
+ }
+ outputs = {
+ datetime.date(2001, 1, 1): '2001-01-01'
+ }
+ field = serializers.DateField()
+
+
+class TestCustomInputFormatDateField(FieldValues):
+ """
+ Valid and invalid values for `DateField` with a cutom input format.
+ """
+ valid_inputs = {
+ '1 Jan 2001': datetime.date(2001, 1, 1),
+ }
+ invalid_inputs = {
+ '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.']
+ }
+ outputs = {}
+ field = serializers.DateField(input_formats=['%d %b %Y'])
+
+
+class TestCustomOutputFormatDateField(FieldValues):
+ """
+ Values for `DateField` with a custom output format.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ datetime.date(2001, 1, 1): '01 Jan 2001'
+ }
+ field = serializers.DateField(format='%d %b %Y')
+
+
+class TestNoOutputFormatDateField(FieldValues):
+ """
+ Values for `DateField` with no output format.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ datetime.date(2001, 1, 1): datetime.date(2001, 1, 1)
+ }
+ field = serializers.DateField(format=None)
+
+
+class TestDateTimeField(FieldValues):
+ """
+ Valid and invalid values for `DateTimeField`.
+ """
+ valid_inputs = {
+ '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
+ '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
+ '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
+ datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()),
+ # Django 1.4 does not support timezone string parsing.
+ '2001-01-01T14:00+01:00' if (django.VERSION > (1, 4)) else '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC())
+ }
+ invalid_inputs = {
+ 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
+ '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
+ datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'],
+ }
+ outputs = {
+ datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00',
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): '2001-01-01T13:00:00Z'
+ }
+ field = serializers.DateTimeField(default_timezone=timezone.UTC())
+
+
+class TestCustomInputFormatDateTimeField(FieldValues):
+ """
+ Valid and invalid values for `DateTimeField` with a cutom input format.
+ """
+ valid_inputs = {
+ '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=timezone.UTC()),
+ }
+ invalid_inputs = {
+ '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.']
+ }
+ outputs = {}
+ field = serializers.DateTimeField(default_timezone=timezone.UTC(), input_formats=['%I:%M%p, %d %b %Y'])
+
+
+class TestCustomOutputFormatDateTimeField(FieldValues):
+ """
+ Values for `DateTimeField` with a custom output format.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ datetime.datetime(2001, 1, 1, 13, 00): '01:00PM, 01 Jan 2001',
+ }
+ field = serializers.DateTimeField(format='%I:%M%p, %d %b %Y')
+
+
+class TestNoOutputFormatDateTimeField(FieldValues):
+ """
+ Values for `DateTimeField` with no output format.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00),
+ }
+ field = serializers.DateTimeField(format=None)
+
+
+class TestNaiveDateTimeField(FieldValues):
+ """
+ Valid and invalid values for `DateTimeField` with naive datetimes.
+ """
+ valid_inputs = {
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=timezone.UTC()): datetime.datetime(2001, 1, 1, 13, 00),
+ '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00),
+ }
+ invalid_inputs = {}
+ outputs = {}
+ field = serializers.DateTimeField(default_timezone=None)
+
+
+class TestTimeField(FieldValues):
+ """
+ Valid and invalid values for `TimeField`.
+ """
+ valid_inputs = {
+ '13:00': datetime.time(13, 00),
+ datetime.time(13, 00): datetime.time(13, 00),
+ }
+ invalid_inputs = {
+ 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
+ '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
+ }
+ outputs = {
+ datetime.time(13, 00): '13:00:00'
+ }
+ field = serializers.TimeField()
+
+
+class TestCustomInputFormatTimeField(FieldValues):
+ """
+ Valid and invalid values for `TimeField` with a custom input format.
+ """
+ valid_inputs = {
+ '1:00pm': datetime.time(13, 00),
+ }
+ invalid_inputs = {
+ '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'],
+ }
+ outputs = {}
+ field = serializers.TimeField(input_formats=['%I:%M%p'])
+
+
+class TestCustomOutputFormatTimeField(FieldValues):
+ """
+ Values for `TimeField` with a custom output format.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ datetime.time(13, 00): '01:00PM'
+ }
+ field = serializers.TimeField(format='%I:%M%p')
+
+
+class TestNoOutputFormatTimeField(FieldValues):
+ """
+ Values for `TimeField` with a no output format.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = {
+ datetime.time(13, 00): datetime.time(13, 00)
+ }
+ field = serializers.TimeField(format=None)
+
+
+# Choice types...
+
+class TestChoiceField(FieldValues):
+ """
+ Valid and invalid values for `ChoiceField`.
+ """
+ valid_inputs = {
+ 'poor': 'poor',
+ 'medium': 'medium',
+ 'good': 'good',
+ }
+ invalid_inputs = {
+ 'amazing': ['"amazing" is not a valid choice.']
+ }
+ outputs = {
+ 'good': 'good',
+ '': ''
+ }
+ field = serializers.ChoiceField(
+ choices=[
+ ('poor', 'Poor quality'),
+ ('medium', 'Medium quality'),
+ ('good', 'Good quality'),
+ ]
+ )
+
+ def test_allow_blank(self):
+ """
+ If `allow_blank=True` then '' is a valid input.
+ """
+ field = serializers.ChoiceField(
+ allow_blank=True,
+ choices=[
+ ('poor', 'Poor quality'),
+ ('medium', 'Medium quality'),
+ ('good', 'Good quality'),
+ ]
+ )
+ output = field.run_validation('')
+ assert output == ''
+
+
+class TestChoiceFieldWithType(FieldValues):
+ """
+ Valid and invalid values for a `Choice` field that uses an integer type,
+ instead of a char type.
+ """
+ valid_inputs = {
+ '1': 1,
+ 3: 3,
+ }
+ invalid_inputs = {
+ 5: ['"5" is not a valid choice.'],
+ 'abc': ['"abc" is not a valid choice.']
+ }
+ outputs = {
+ '1': 1,
+ 1: 1
+ }
+ field = serializers.ChoiceField(
+ choices=[
+ (1, 'Poor quality'),
+ (2, 'Medium quality'),
+ (3, 'Good quality'),
+ ]
+ )
+
+
+class TestChoiceFieldWithListChoices(FieldValues):
+ """
+ Valid and invalid values for a `Choice` field that uses a flat list for the
+ choices, rather than a list of pairs of (`value`, `description`).
+ """
+ valid_inputs = {
+ 'poor': 'poor',
+ 'medium': 'medium',
+ 'good': 'good',
+ }
+ invalid_inputs = {
+ 'awful': ['"awful" is not a valid choice.']
+ }
+ outputs = {
+ 'good': 'good'
+ }
+ field = serializers.ChoiceField(choices=('poor', 'medium', 'good'))
+
+
+class TestMultipleChoiceField(FieldValues):
+ """
+ Valid and invalid values for `MultipleChoiceField`.
+ """
+ valid_inputs = {
+ (): set(),
+ ('aircon',): set(['aircon']),
+ ('aircon', 'manual'): set(['aircon', 'manual']),
+ }
+ invalid_inputs = {
+ 'abc': ['Expected a list of items but got type "str".'],
+ ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.']
+ }
+ outputs = [
+ (['aircon', 'manual'], set(['aircon', 'manual']))
+ ]
+ field = serializers.MultipleChoiceField(
+ choices=[
+ ('aircon', 'AirCon'),
+ ('manual', 'Manual drive'),
+ ('diesel', 'Diesel'),
+ ]
+ )
+
+
+# File serializers...
+
+class MockFile:
+ def __init__(self, name='', size=0, url=''):
+ self.name = name
+ self.size = size
+ self.url = url
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, MockFile) and
+ self.name == other.name and
+ self.size == other.size and
+ self.url == other.url
+ )
+
+
+class TestFileField(FieldValues):
+ """
+ Values for `FileField`.
+ """
+ valid_inputs = [
+ (MockFile(name='example', size=10), MockFile(name='example', size=10))
+ ]
+ invalid_inputs = [
+ ('invalid', ['The submitted data was not a file. Check the encoding type on the form.']),
+ (MockFile(name='example.txt', size=0), ['The submitted file is empty.']),
+ (MockFile(name='', size=10), ['No filename could be determined.']),
+ (MockFile(name='x' * 100, size=10), ['Ensure this filename has at most 10 characters (it has 100).'])
+ ]
+ outputs = [
+ (MockFile(name='example.txt', url='/example.txt'), '/example.txt'),
+ ('', None)
+ ]
+ field = serializers.FileField(max_length=10)
+
+
+class TestFieldFieldWithName(FieldValues):
+ """
+ Values for `FileField` with a filename output instead of URLs.
+ """
+ valid_inputs = {}
+ invalid_inputs = {}
+ outputs = [
+ (MockFile(name='example.txt', url='/example.txt'), 'example.txt')
+ ]
+ field = serializers.FileField(use_url=False)
+
+
+# Stub out mock Django `forms.ImageField` class so we don't *actually*
+# call into it's regular validation, or require PIL for testing.
+class FailImageValidation(object):
+ def to_python(self, value):
+ raise serializers.ValidationError(self.error_messages['invalid_image'])
+
+
+class PassImageValidation(object):
+ def to_python(self, value):
+ return value
+
+
+class TestInvalidImageField(FieldValues):
+ """
+ Values for an invalid `ImageField`.
+ """
+ valid_inputs = {}
+ invalid_inputs = [
+ (MockFile(name='example.txt', size=10), ['Upload a valid image. The file you uploaded was either not an image or a corrupted image.'])
+ ]
+ outputs = {}
+ field = serializers.ImageField(_DjangoImageField=FailImageValidation)
+
+
+class TestValidImageField(FieldValues):
+ """
+ Values for an valid `ImageField`.
+ """
+ valid_inputs = [
+ (MockFile(name='example.txt', size=10), MockFile(name='example.txt', size=10))
+ ]
+ invalid_inputs = {}
+ outputs = {}
+ field = serializers.ImageField(_DjangoImageField=PassImageValidation)
+
+
+# Composite serializers...
+
+class TestListField(FieldValues):
+ """
+ Values for `ListField` with IntegerField as child.
+ """
+ valid_inputs = [
+ ([1, 2, 3], [1, 2, 3]),
+ (['1', '2', '3'], [1, 2, 3])
+ ]
+ invalid_inputs = [
+ ('not a list', ['Expected a list of items but got type "str".']),
+ ([1, 2, 'error'], ['A valid integer is required.'])
+ ]
+ outputs = [
+ ([1, 2, 3], [1, 2, 3]),
+ (['1', '2', '3'], [1, 2, 3])
+ ]
+ field = serializers.ListField(child=serializers.IntegerField())
+
+
+class TestUnvalidatedListField(FieldValues):
+ """
+ Values for `ListField` with no `child` argument.
+ """
+ valid_inputs = [
+ ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
+ ]
+ invalid_inputs = [
+ ('not a list', ['Expected a list of items but got type "str".']),
+ ]
+ outputs = [
+ ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
+ ]
+ field = serializers.ListField()
+
+
+class TestDictField(FieldValues):
+ """
+ Values for `ListField` with CharField as child.
+ """
+ valid_inputs = [
+ ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
+ ]
+ invalid_inputs = [
+ ({'a': 1, 'b': None}, ['This field may not be null.']),
+ ('not a dict', ['Expected a dictionary of items but got type "str".']),
+ ]
+ outputs = [
+ ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
+ ]
+ field = serializers.DictField(child=serializers.CharField())
+
+
+class TestUnvalidatedDictField(FieldValues):
+ """
+ Values for `ListField` with no `child` argument.
+ """
+ valid_inputs = [
+ ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}),
+ ]
+ invalid_inputs = [
+ ('not a dict', ['Expected a dictionary of items but got type "str".']),
+ ]
+ outputs = [
+ ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}),
+ ]
+ field = serializers.DictField()
+
+
+# Tests for FieldField.
+# ---------------------
+
+class MockRequest:
+ def build_absolute_uri(self, value):
+ return 'http://example.com' + value
+
+
+class TestFileFieldContext:
+ def test_fully_qualified_when_request_in_context(self):
+ field = serializers.FileField(max_length=10)
+ field._context = {'request': MockRequest()}
+ obj = MockFile(name='example.txt', url='/example.txt')
+ value = field.to_representation(obj)
+ assert value == 'http://example.com/example.txt'
+
+
+# Tests for SerializerMethodField.
+# --------------------------------
+
+class TestSerializerMethodField:
+ def test_serializer_method_field(self):
+ class ExampleSerializer(serializers.Serializer):
+ example_field = serializers.SerializerMethodField()
+
+ def get_example_field(self, obj):
+ return 'ran get_example_field(%d)' % obj['example_field']
+
+ serializer = ExampleSerializer({'example_field': 123})
+ assert serializer.data == {
+ 'example_field': 'ran get_example_field(123)'
+ }
+
+ def test_redundant_method_name(self):
+ class ExampleSerializer(serializers.Serializer):
+ example_field = serializers.SerializerMethodField('get_example_field')
+
+ with pytest.raises(AssertionError) as exc_info:
+ ExampleSerializer().fields
+ assert str(exc_info.value) == (
+ "It is redundant to specify `get_example_field` on "
+ "SerializerMethodField 'example_field' in serializer "
+ "'ExampleSerializer', because it is the same as the default "
+ "method name. Remove the `method_name` argument."
+ )
diff --git a/tests/test_filters.py b/tests/test_filters.py
new file mode 100644
index 00000000..e7cb0c79
--- /dev/null
+++ b/tests/test_filters.py
@@ -0,0 +1,823 @@
+from __future__ import unicode_literals
+import datetime
+from decimal import Decimal
+from django.db import models
+from django.conf.urls import patterns, url
+from django.core.urlresolvers import reverse
+from django.test import TestCase
+from django.test.utils import override_settings
+from django.utils import unittest
+from django.utils.dateparse import parse_date
+from django.utils.six.moves import reload_module
+from rest_framework import generics, serializers, status, filters
+from rest_framework.compat import django_filters
+from rest_framework.test import APIRequestFactory
+from .models import BaseFilterableItem, FilterableItem, BasicModel
+
+
+factory = APIRequestFactory()
+
+
+if django_filters:
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+ # Basic filter on a list view.
+ class FilterFieldsRootView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These class are used to test a filter class.
+ class SeveralFieldsFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+ decimal = django_filters.NumberFilter(lookup_type='lt')
+ date = django_filters.DateFilter(lookup_type='gt')
+
+ class Meta:
+ model = FilterableItem
+ fields = ['text', 'decimal', 'date']
+
+ class FilterClassRootView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test a misconfigured filter class.
+ class MisconfiguredFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter(lookup_type='icontains')
+
+ class Meta:
+ model = BasicModel
+ fields = ['text']
+
+ class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_class = MisconfiguredFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class FilterClassDetailView(generics.RetrieveAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # These classes are used to test base model filter support
+ class BaseFilterableItemFilter(django_filters.FilterSet):
+ text = django_filters.CharFilter()
+
+ class Meta:
+ model = BaseFilterableItem
+
+ class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_class = BaseFilterableItemFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ # Regression test for #814
+ class FilterFieldsQuerysetView(generics.ListCreateAPIView):
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
+ filter_fields = ['decimal', 'date']
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ class GetQuerysetView(generics.ListCreateAPIView):
+ serializer_class = FilterableItemSerializer
+ filter_class = SeveralFieldsFilter
+ filter_backends = (filters.DjangoFilterBackend,)
+
+ def get_queryset(self):
+ return FilterableItem.objects.all()
+
+ urlpatterns = patterns(
+ '',
+ url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
+ url(r'^$', FilterClassRootView.as_view(), name='root-view'),
+ url(r'^get-queryset/$', GetQuerysetView.as_view(),
+ name='get-queryset-view'),
+ )
+
+
+class CommonFilteringTestCase(TestCase):
+ def _serialize_object(self, obj):
+ return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}
+
+ def setUp(self):
+ """
+ Create 10 FilterableItem instances.
+ """
+ base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
+ for i in range(10):
+ text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
+ decimal = base_data[1] + i
+ date = base_data[2] - datetime.timedelta(days=i * 2)
+ FilterableItem(text=text, decimal=decimal, date=date).save()
+
+ self.objects = FilterableItem.objects
+ self.data = [
+ self._serialize_object(obj)
+ for obj in self.objects.all()
+ ]
+
+
+class IntegrationTestFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered list views.
+ """
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_fields_root_view(self):
+ """
+ GET requests to paginated ListCreateAPIView should return paginated results.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter works.
+ search_date = datetime.date(2012, 9, 22)
+ request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if parse_date(f['date']) == search_date]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_filter_with_queryset(self):
+ """
+ Regression test for #814.
+ """
+ view = FilterFieldsQuerysetView.as_view()
+
+ # Tests that the decimal filter works.
+ search_decimal = Decimal('2.25')
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if Decimal(f['decimal']) == search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_filter_with_get_queryset_only(self):
+ """
+ Regression test for #834.
+ """
+ view = GetQuerysetView.as_view()
+ request = factory.get('/get-queryset/')
+ view(request).render()
+ # Used to raise "issubclass() arg 2 must be a class or tuple of classes"
+ # here when neither `model' nor `queryset' was specified.
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_class_root_view(self):
+ """
+ GET requests to filtered ListCreateAPIView that have a filter_class set
+ should return filtered results.
+ """
+ view = FilterClassRootView.as_view()
+
+ # Basic test with no filter.
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ # Tests that the decimal filter set with 'lt' in the filter class works.
+ search_decimal = Decimal('4.25')
+ request = factory.get('/', {'decimal': '%s' % search_decimal})
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if Decimal(f['decimal']) < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the date filter set with 'gt' in the filter class works.
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if parse_date(f['date']) > search_date]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that the text filter set with 'icontains' in the filter class works.
+ search_text = 'ff'
+ request = factory.get('/', {'text': '%s' % search_text})
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if search_text in f['text'].lower()]
+ self.assertEqual(response.data, expected_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ request = factory.get('/', {
+ 'decimal': '%s' % (search_decimal,),
+ 'date': '%s' % (search_date,)
+ })
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ expected_data = [f for f in self.data if parse_date(f['date']) > search_date and
+ Decimal(f['decimal']) < search_decimal]
+ self.assertEqual(response.data, expected_data)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_incorrectly_configured_filter(self):
+ """
+ An error should be displayed when the filter class is misconfigured.
+ """
+ view = IncorrectlyConfiguredRootView.as_view()
+
+ request = factory.get('/')
+ self.assertRaises(AssertionError, view, request)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_base_model_filter(self):
+ """
+ The `get_filter_class` model checks should allow base model filters.
+ """
+ view = BaseFilterableItemFilterRootView.as_view()
+
+ request = factory.get('/?text=aaa')
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_unknown_filter(self):
+ """
+ GET requests with filters that aren't configured should return 200.
+ """
+ view = FilterFieldsRootView.as_view()
+
+ search_integer = 10
+ request = factory.get('/', {'integer': '%s' % search_integer})
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class IntegrationTestDetailFiltering(CommonFilteringTestCase):
+ """
+ Integration tests for filtered detail views.
+ """
+ urls = 'tests.test_filters'
+
+ def _get_url(self, item):
+ return reverse('detail-view', kwargs=dict(pk=item.pk))
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_get_filtered_detail_view(self):
+ """
+ GET requests to filtered RetrieveAPIView that have a filter_class set
+ should return filtered results.
+ """
+ item = self.objects.all()[0]
+ data = self._serialize_object(item)
+
+ # Basic test with no filter.
+ response = self.client.get(self._get_url(item))
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, data)
+
+ # Tests that the decimal filter set that should fail.
+ search_decimal = Decimal('4.25')
+ high_item = self.objects.filter(decimal__gt=search_decimal)[0]
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(high_item)),
+ {'decimal': '{param}'.format(param=search_decimal)})
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Tests that the decimal filter set that should succeed.
+ search_decimal = Decimal('4.25')
+ low_item = self.objects.filter(decimal__lt=search_decimal)[0]
+ low_item_data = self._serialize_object(low_item)
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(low_item)),
+ {'decimal': '{param}'.format(param=search_decimal)})
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, low_item_data)
+
+ # Tests that multiple filters works.
+ search_decimal = Decimal('5.25')
+ search_date = datetime.date(2012, 10, 2)
+ valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
+ valid_item_data = self._serialize_object(valid_item)
+ response = self.client.get(
+ '{url}'.format(url=self._get_url(valid_item)), {
+ 'decimal': '{decimal}'.format(decimal=search_decimal),
+ 'date': '{date}'.format(date=search_date)
+ })
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, valid_item_data)
+
+
+class SearchFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class SearchFilterSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SearchFilterModel
+
+
+class SearchFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # z abc
+ # zz bcd
+ # zzz cde
+ # ...
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModel(title=title, text=text).save()
+
+ def test_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'search': 'b'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_exact_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'search': 'zzz'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'zzz', 'text': 'cde'}
+ ]
+ )
+
+ def test_startswith_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', '^text')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'search': 'b'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ def test_search_with_nonstandard_search_param(self):
+ with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
+ reload_module(filters)
+
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('title', 'text')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'query': 'b'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'z', 'text': 'abc'},
+ {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ ]
+ )
+
+ reload_module(filters)
+
+
+class AttributeModel(models.Model):
+ label = models.CharField(max_length=32)
+
+
+class SearchFilterModelM2M(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+ attributes = models.ManyToManyField(AttributeModel)
+
+
+class SearchFilterM2MSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SearchFilterModelM2M
+
+
+class SearchFilterM2MTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text/attributes is:
+ #
+ # z abc [1, 2, 3]
+ # zz bcd [1, 2, 3]
+ # zzz cde [1, 2, 3]
+ # ...
+ for idx in range(3):
+ label = 'w' * (idx + 1)
+ AttributeModel(label=label)
+
+ for idx in range(10):
+ title = 'z' * (idx + 1)
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ SearchFilterModelM2M(title=title, text=text).save()
+ SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
+
+ def test_m2m_search(self):
+ class SearchListView(generics.ListAPIView):
+ queryset = SearchFilterModelM2M.objects.all()
+ serializer_class = SearchFilterM2MSerializer
+ filter_backends = (filters.SearchFilter,)
+ search_fields = ('=title', 'text', 'attributes__label')
+
+ view = SearchListView.as_view()
+ request = factory.get('/', {'search': 'zz'})
+ response = view(request)
+ self.assertEqual(len(response.data), 1)
+
+
+class OrderingFilterModel(models.Model):
+ title = models.CharField(max_length=20)
+ text = models.CharField(max_length=100)
+
+
+class OrderingFilterRelatedModel(models.Model):
+ related_object = models.ForeignKey(OrderingFilterModel,
+ related_name="relateds")
+
+
+class OrderingFilterSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OrderingFilterModel
+
+
+class DjangoFilterOrderingModel(models.Model):
+ date = models.DateField()
+ text = models.CharField(max_length=10)
+
+ class Meta:
+ ordering = ['-date']
+
+
+class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = DjangoFilterOrderingModel
+
+
+class DjangoFilterOrderingTests(TestCase):
+ def setUp(self):
+ data = [{
+ 'date': datetime.date(2012, 10, 8),
+ 'text': 'abc'
+ }, {
+ 'date': datetime.date(2013, 10, 8),
+ 'text': 'bcd'
+ }, {
+ 'date': datetime.date(2014, 10, 8),
+ 'text': 'cde'
+ }]
+
+ for d in data:
+ DjangoFilterOrderingModel.objects.create(**d)
+
+ @unittest.skipUnless(django_filters, 'django-filter not installed')
+ def test_default_ordering(self):
+ class DjangoFilterOrderingView(generics.ListAPIView):
+ serializer_class = DjangoFilterOrderingSerializer
+ queryset = DjangoFilterOrderingModel.objects.all()
+ filter_backends = (filters.DjangoFilterBackend,)
+ filter_fields = ['text']
+ ordering = ('-date',)
+
+ view = DjangoFilterOrderingView.as_view()
+ request = factory.get('/')
+ response = view(request)
+
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'date': '2014-10-08', 'text': 'cde'},
+ {'id': 2, 'date': '2013-10-08', 'text': 'bcd'},
+ {'id': 1, 'date': '2012-10-08', 'text': 'abc'}
+ ]
+ )
+
+
+class OrderingFilterTests(TestCase):
+ def setUp(self):
+ # Sequence of title/text is:
+ #
+ # zyx abc
+ # yxw bcd
+ # xwv cde
+ for idx in range(3):
+ title = (
+ chr(ord('z') - idx) +
+ chr(ord('y') - idx) +
+ chr(ord('x') - idx)
+ )
+ text = (
+ chr(idx + ord('a')) +
+ chr(idx + ord('b')) +
+ chr(idx + ord('c'))
+ )
+ OrderingFilterModel(title=title, text=text).save()
+
+ def test_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'ordering': 'text'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ def test_reverse_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'ordering': '-text'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_incorrectfield_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'ordering': 'foobar'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering(self):
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ oredering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_default_ordering_using_string(self):
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('')
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ ]
+ )
+
+ def test_ordering_by_aggregate_field(self):
+ # create some related models to aggregate order by
+ num_objs = [2, 5, 3]
+ for obj, num_relateds in zip(OrderingFilterModel.objects.all(),
+ num_objs):
+ for _ in range(num_relateds):
+ new_related = OrderingFilterRelatedModel(
+ related_object=obj
+ )
+ new_related.save()
+
+ class OrderingListView(generics.ListAPIView):
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = 'title'
+ ordering_fields = '__all__'
+ queryset = OrderingFilterModel.objects.all().annotate(
+ models.Count("relateds"))
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'ordering': 'relateds__count'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ ]
+ )
+
+ def test_ordering_with_nonstandard_ordering_param(self):
+ with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
+ reload_module(filters)
+
+ class OrderingListView(generics.ListAPIView):
+ queryset = OrderingFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
+ filter_backends = (filters.OrderingFilter,)
+ ordering = ('title',)
+ ordering_fields = ('text',)
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'order': 'text'})
+ response = view(request)
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ ]
+ )
+
+ reload_module(filters)
+
+
+class SensitiveOrderingFilterModel(models.Model):
+ username = models.CharField(max_length=20)
+ password = models.CharField(max_length=100)
+
+
+# Three different styles of serializer.
+# All should allow ordering by username, but not by password.
+class SensitiveDataSerializer1(serializers.ModelSerializer):
+ username = serializers.CharField()
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username')
+
+
+class SensitiveDataSerializer2(serializers.ModelSerializer):
+ username = serializers.CharField()
+ password = serializers.CharField(write_only=True)
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'username', 'password')
+
+
+class SensitiveDataSerializer3(serializers.ModelSerializer):
+ user = serializers.CharField(source='username')
+
+ class Meta:
+ model = SensitiveOrderingFilterModel
+ fields = ('id', 'user')
+
+
+class SensitiveOrderingFilterTests(TestCase):
+ def setUp(self):
+ for idx in range(3):
+ username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
+ password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
+ SensitiveOrderingFilterModel(username=username, password=password).save()
+
+ def test_order_by_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'ordering': '-username'})
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: Inverse username ordering correctly applied.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 3, username_field: 'userC'},
+ {'id': 2, username_field: 'userB'},
+ {'id': 1, username_field: 'userA'},
+ ]
+ )
+
+ def test_cannot_order_by_non_serializer_fields(self):
+ for serializer_cls in [
+ SensitiveDataSerializer1,
+ SensitiveDataSerializer2,
+ SensitiveDataSerializer3
+ ]:
+ class OrderingListView(generics.ListAPIView):
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ filter_backends = (filters.OrderingFilter,)
+ serializer_class = serializer_cls
+
+ view = OrderingListView.as_view()
+ request = factory.get('/', {'ordering': 'password'})
+ response = view(request)
+
+ if serializer_cls == SensitiveDataSerializer3:
+ username_field = 'user'
+ else:
+ username_field = 'username'
+
+ # Note: The passwords are not in order. Default ordering is used.
+ self.assertEqual(
+ response.data,
+ [
+ {'id': 1, username_field: 'userA'}, # PassB
+ {'id': 2, username_field: 'userB'}, # PassC
+ {'id': 3, username_field: 'userC'}, # PassA
+ ]
+ )
diff --git a/tests/test_generics.py b/tests/test_generics.py
new file mode 100644
index 00000000..88e792ce
--- /dev/null
+++ b/tests/test_generics.py
@@ -0,0 +1,506 @@
+from __future__ import unicode_literals
+import django
+from django.db import models
+from django.shortcuts import get_object_or_404
+from django.test import TestCase
+from django.utils import six
+from rest_framework import generics, renderers, serializers, status
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel, RESTFrameworkModel
+from tests.models import ForeignKeySource, ForeignKeyTarget
+
+factory = APIRequestFactory()
+
+
+# Models
+class SlugBasedModel(RESTFrameworkModel):
+ text = models.CharField(max_length=100)
+ slug = models.SlugField(max_length=32)
+
+
+# Model for regression test for #285
+class Comment(RESTFrameworkModel):
+ email = models.EmailField()
+ content = models.CharField(max_length=200)
+ created = models.DateTimeField(auto_now_add=True)
+
+
+# Serializers
+class BasicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
+class ForeignKeySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
+class SlugSerializer(serializers.ModelSerializer):
+ slug = serializers.ReadOnlyField()
+
+ class Meta:
+ model = SlugBasedModel
+ fields = ('text', 'slug')
+
+
+# Views
+class RootView(generics.ListCreateAPIView):
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ queryset = BasicModel.objects.exclude(text='filtered out')
+ serializer_class = BasicSerializer
+
+
+class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
+ queryset = ForeignKeySource.objects.all()
+ serializer_class = ForeignKeySerializer
+
+
+class SlugBasedInstanceView(InstanceView):
+ """
+ A model with a slug-field.
+ """
+ queryset = SlugBasedModel.objects.all()
+ serializer_class = SlugSerializer
+ lookup_field = 'slug'
+
+
+# Tests
+class TestRootView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = RootView.as_view()
+
+ def test_get_root_view(self):
+ """
+ GET requests to ListCreateAPIView should return list of objects.
+ """
+ request = factory.get('/')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data)
+
+ def test_post_root_view(self):
+ """
+ POST requests to ListCreateAPIView should create a new object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEqual(created.text, 'foobar')
+
+ def test_put_root_view(self):
+ """
+ PUT requests to ListCreateAPIView should not be allowed
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'})
+
+ def test_delete_root_view(self):
+ """
+ DELETE requests to ListCreateAPIView should not be allowed
+ """
+ request = factory.delete('/')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'})
+
+ def test_post_cannot_set_id(self):
+ """
+ POST requests to create a new object should not be able to set the id.
+ """
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ self.assertEqual(response.data, {'id': 4, 'text': 'foobar'})
+ created = self.objects.get(id=4)
+ self.assertEqual(created.text, 'foobar')
+
+
+EXPECTED_QUERIES_FOR_PUT = 3 if django.VERSION < (1, 6) else 2
+
+
+class TestInstanceView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz', 'filtered out']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects.exclude(text='filtered out')
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+ self.view = InstanceView.as_view()
+ self.slug_based_view = SlugBasedInstanceView.as_view()
+
+ def test_get_instance_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+ def test_post_instance_view(self):
+ """
+ POST requests to RetrieveUpdateDestroyAPIView should not be allowed
+ """
+ data = {'text': 'foobar'}
+ request = factory.post('/', data, format='json')
+ with self.assertNumQueries(0):
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
+ self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'})
+
+ def test_put_instance_view(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
+ response = self.view(request, pk='1').render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(dict(response.data), {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_patch_instance_view(self):
+ """
+ PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/1', data, format='json')
+
+ with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_delete_instance_view(self):
+ """
+ DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
+ """
+ request = factory.delete('/1')
+ with self.assertNumQueries(2):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+ self.assertEqual(response.content, six.b(''))
+ ids = [obj.id for obj in self.objects.all()]
+ self.assertEqual(ids, [2, 3])
+
+ def test_get_instance_view_incorrect_arg(self):
+ """
+ GET requests with an incorrect pk type, should raise 404, not 500.
+ Regression test for #890.
+ """
+ request = factory.get('/a')
+ with self.assertNumQueries(0):
+ response = self.view(request, pk='a').render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_put_cannot_set_id(self):
+ """
+ PUT requests to create a new object should not be able to set the id.
+ """
+ data = {'id': 999, 'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
+ updated = self.objects.get(id=1)
+ self.assertEqual(updated.text, 'foobar')
+
+ def test_put_to_deleted_instance(self):
+ """
+ PUT requests to RetrieveUpdateDestroyAPIView should return 404 if
+ an object does not currently exist.
+ """
+ self.objects.get(id=1).delete()
+ data = {'text': 'foobar'}
+ request = factory.put('/1', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_put_to_filtered_out_instance(self):
+ """
+ PUT requests to an URL of instance which is filtered out should not be
+ able to create new objects.
+ """
+ data = {'text': 'foo'}
+ filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
+ request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
+ response = self.view(request, pk=filtered_out_pk).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_patch_cannot_create_an_object(self):
+ """
+ PATCH requests should not be able to create objects.
+ """
+ data = {'text': 'foobar'}
+ request = factory.patch('/999', data, format='json')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=999).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertFalse(self.objects.filter(id=999).exists())
+
+
+class TestFKInstanceView(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ t = ForeignKeyTarget(name=item)
+ t.save()
+ ForeignKeySource(name='source_' + item, target=t).save()
+
+ self.objects = ForeignKeySource.objects
+ self.data = [
+ {'id': obj.id, 'name': obj.name}
+ for obj in self.objects.all()
+ ]
+ self.view = FKInstanceView.as_view()
+
+
+class TestOverriddenGetObject(TestCase):
+ """
+ Test cases for a RetrieveUpdateDestroyAPIView that does NOT use the
+ queryset/model mechanism but instead overrides get_object()
+ """
+
+ def setUp(self):
+ """
+ Create 3 BasicModel instances.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
+ """
+ Example detail view for override of get_object().
+ """
+ serializer_class = BasicSerializer
+
+ def get_object(self):
+ pk = int(self.kwargs['pk'])
+ return get_object_or_404(BasicModel.objects.all(), id=pk)
+
+ self.view = OverriddenGetObjectView.as_view()
+
+ def test_overridden_get_object_view(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object.
+ """
+ request = factory.get('/1')
+ with self.assertNumQueries(1):
+ response = self.view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, self.data[0])
+
+
+# Regression test for #285
+
+class CommentSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = Comment
+ exclude = ('created',)
+
+
+class CommentView(generics.ListCreateAPIView):
+ serializer_class = CommentSerializer
+ model = Comment
+
+
+class TestCreateModelWithAutoNowAddField(TestCase):
+ def setUp(self):
+ self.objects = Comment.objects
+ self.view = CommentView.as_view()
+
+ def test_create_model_with_auto_now_add_field(self):
+ """
+ Regression test for #285
+
+ https://github.com/tomchristie/django-rest-framework/issues/285
+ """
+ data = {'email': 'foobar@example.com', 'content': 'foobar'}
+ request = factory.post('/', data, format='json')
+ response = self.view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+ created = self.objects.get(id=1)
+ self.assertEqual(created.content, 'foobar')
+
+
+# Test for particularly ugly regression with m2m in browsable API
+class ClassB(models.Model):
+ name = models.CharField(max_length=255)
+
+
+class ClassA(models.Model):
+ name = models.CharField(max_length=255)
+ children = models.ManyToManyField(ClassB, blank=True, null=True)
+
+
+class ClassASerializer(serializers.ModelSerializer):
+ children = serializers.PrimaryKeyRelatedField(
+ many=True, queryset=ClassB.objects.all()
+ )
+
+ class Meta:
+ model = ClassA
+
+
+class ExampleView(generics.ListCreateAPIView):
+ serializer_class = ClassASerializer
+ queryset = ClassA.objects.all()
+
+
+class TestM2MBrowsableAPI(TestCase):
+ def test_m2m_in_browsable_api(self):
+ """
+ Test for particularly ugly regression with m2m in browsable API
+ """
+ request = factory.get('/', HTTP_ACCEPT='text/html')
+ view = ExampleView().as_view()
+ response = view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+
+class InclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='foo')
+
+
+class ExclusiveFilterBackend(object):
+ def filter_queryset(self, request, queryset, view):
+ return queryset.filter(text='other')
+
+
+class TwoFieldModel(models.Model):
+ field_a = models.CharField(max_length=100)
+ field_b = models.CharField(max_length=100)
+
+
+class DynamicSerializerView(generics.ListCreateAPIView):
+ queryset = TwoFieldModel.objects.all()
+ renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
+
+ def get_serializer_class(self):
+ if self.request.method == 'POST':
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ fields = ('field_b',)
+ else:
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ return DynamicSerializer
+
+
+class TestFilterBackendAppliedToViews(TestCase):
+ def setUp(self):
+ """
+ Create 3 BasicModel instances to filter on.
+ """
+ items = ['foo', 'bar', 'baz']
+ for item in items:
+ BasicModel(text=item).save()
+ self.objects = BasicModel.objects
+ self.data = [
+ {'id': obj.id, 'text': obj.text}
+ for obj in self.objects.all()
+ ]
+
+ def test_get_root_view_filters_by_name_with_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(len(response.data), 1)
+ self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
+
+ def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
+ """
+ GET requests to ListCreateAPIView should return empty list when all models are filtered out.
+ """
+ root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/')
+ response = root_view(request).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, [])
+
+ def test_get_instance_view_filters_out_name_with_filter_backend(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
+ """
+ instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.data, {'detail': 'Not found.'})
+
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ """
+ GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
+ """
+ instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
+ request = factory.get('/1')
+ response = instance_view(request, pk=1).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'id': 1, 'text': 'foo'})
+
+ def test_dynamic_serializer_form_in_browsable_api(self):
+ """
+ GET requests to ListCreateAPIView should return filtered list.
+ """
+ view = DynamicSerializerView.as_view()
+ request = factory.get('/')
+ response = view(request).render()
+ self.assertContains(response, 'field_b')
+ self.assertNotContains(response, 'field_a')
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
new file mode 100644
index 00000000..a33b832f
--- /dev/null
+++ b/tests/test_htmlrenderer.py
@@ -0,0 +1,127 @@
+from __future__ import unicode_literals
+from django.core.exceptions import PermissionDenied
+from django.conf.urls import patterns, url
+from django.http import Http404
+from django.test import TestCase
+from django.template import TemplateDoesNotExist, Template
+from django.utils import six
+from rest_framework import status
+from rest_framework.decorators import api_view, renderer_classes
+from rest_framework.renderers import TemplateHTMLRenderer
+from rest_framework.response import Response
+import django.template.loader
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def example(request):
+ """
+ A view that can returns an HTML representation.
+ """
+ data = {'object': 'foobar'}
+ return Response(data, template_name='example.html')
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def permission_denied(request):
+ raise PermissionDenied()
+
+
+@api_view(('GET',))
+@renderer_classes((TemplateHTMLRenderer,))
+def not_found(request):
+ raise Http404()
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^$', example),
+ url(r'^permission_denied$', permission_denied),
+ url(r'^not_found$', not_found),
+)
+
+
+class TemplateHTMLRendererTests(TestCase):
+ urls = 'tests.test_htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name, dirs=None):
+ if template_name == 'example.html':
+ return Template("example: {{ object }}")
+ raise TemplateDoesNotExist(template_name)
+
+ def select_template(template_name_list, dirs=None, using=None):
+ if template_name_list == ['example.html']:
+ return Template("example: {{ object }}")
+ raise TemplateDoesNotExist(template_name_list[0])
+
+ django.template.loader.get_template = get_template
+ django.template.loader.select_template = select_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_simple_html_view(self):
+ response = self.client.get('/')
+ self.assertContains(response, "example: foobar")
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_not_found_html_view(self):
+ response = self.client.get('/not_found')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertEqual(response.content, six.b("404 Not Found"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_permission_denied_html_view(self):
+ response = self.client.get('/permission_denied')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(response.content, six.b("403 Forbidden"))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+
+class TemplateHTMLRendererExceptionTests(TestCase):
+ urls = 'tests.test_htmlrenderer'
+
+ def setUp(self):
+ """
+ Monkeypatch get_template
+ """
+ self.get_template = django.template.loader.get_template
+
+ def get_template(template_name):
+ if template_name == '404.html':
+ return Template("404: {{ detail }}")
+ if template_name == '403.html':
+ return Template("403: {{ detail }}")
+ raise TemplateDoesNotExist(template_name)
+
+ django.template.loader.get_template = get_template
+
+ def tearDown(self):
+ """
+ Revert monkeypatching
+ """
+ django.template.loader.get_template = self.get_template
+
+ def test_not_found_html_view_with_template(self):
+ response = self.client.get('/not_found')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+ self.assertTrue(response.content in (
+ six.b("404: Not found"), six.b("404 Not Found")))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+
+ def test_permission_denied_html_view_with_template(self):
+ response = self.client.get('/permission_denied')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertTrue(response.content in (
+ six.b("403: Permission denied"), six.b("403 Forbidden")))
+ self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
new file mode 100644
index 00000000..3a435f02
--- /dev/null
+++ b/tests/test_metadata.py
@@ -0,0 +1,209 @@
+from __future__ import unicode_literals
+from rest_framework import exceptions, serializers, status, views, versioning
+from rest_framework.request import Request
+from rest_framework.renderers import BrowsableAPIRenderer
+from rest_framework.test import APIRequestFactory
+
+request = Request(APIRequestFactory().options('/'))
+
+
+class TestMetadata:
+ def test_metadata(self):
+ """
+ OPTIONS requests to views should return a valid 200 response.
+ """
+ class ExampleView(views.APIView):
+ """Example view."""
+ pass
+
+ view = ExampleView.as_view()
+ response = view(request=request)
+ expected = {
+ 'name': 'Example',
+ 'description': 'Example view.',
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ]
+ }
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == expected
+
+ def test_none_metadata(self):
+ """
+ OPTIONS requests to views where `metadata_class = None` should raise
+ a MethodNotAllowed exception, which will result in an HTTP 405 response.
+ """
+ class ExampleView(views.APIView):
+ metadata_class = None
+
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
+ assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}
+
+ def test_actions(self):
+ """
+ On generic views OPTIONS should return an 'actions' key with metadata
+ on the fields that may be supplied to PUT and POST requests.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
+ integer_field = serializers.IntegerField(
+ min_value=1, max_value=1000
+ )
+ char_field = serializers.CharField(
+ required=False, min_length=3, max_length=40
+ )
+
+ class ExampleView(views.APIView):
+ """Example view."""
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ return ExampleSerializer()
+
+ view = ExampleView.as_view()
+ response = view(request=request)
+ expected = {
+ 'name': 'Example',
+ 'description': 'Example view.',
+ 'renders': [
+ 'application/json',
+ 'text/html'
+ ],
+ 'parses': [
+ 'application/json',
+ 'application/x-www-form-urlencoded',
+ 'multipart/form-data'
+ ],
+ 'actions': {
+ 'POST': {
+ 'choice_field': {
+ 'type': 'choice',
+ 'required': True,
+ 'read_only': False,
+ 'label': 'Choice field',
+ 'choices': [
+ {'display_name': 'red', 'value': 'red'},
+ {'display_name': 'green', 'value': 'green'},
+ {'display_name': 'blue', 'value': 'blue'}
+ ]
+ },
+ 'integer_field': {
+ 'type': 'integer',
+ 'required': True,
+ 'read_only': False,
+ 'label': 'Integer field',
+ 'min_value': 1,
+ 'max_value': 1000,
+
+ },
+ 'char_field': {
+ 'type': 'string',
+ 'required': False,
+ 'read_only': False,
+ 'label': 'Char field',
+ 'min_length': 3,
+ 'max_length': 40
+ }
+ }
+ }
+ }
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == expected
+
+ def test_global_permissions(self):
+ """
+ If a user does not have global permissions on an action, then any
+ metadata associated with it should not be included in OPTION responses.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
+ integer_field = serializers.IntegerField(max_value=10)
+ char_field = serializers.CharField(required=False)
+
+ class ExampleView(views.APIView):
+ """Example view."""
+ def post(self, request):
+ pass
+
+ def put(self, request):
+ pass
+
+ def get_serializer(self):
+ return ExampleSerializer()
+
+ def check_permissions(self, request):
+ if request.method == 'POST':
+ raise exceptions.PermissionDenied()
+
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_200_OK
+ assert list(response.data['actions'].keys()) == ['PUT']
+
+ def test_object_permissions(self):
+ """
+ If a user does not have object permissions on an action, then any
+ metadata associated with it should not be included in OPTION responses.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
+ integer_field = serializers.IntegerField(max_value=10)
+ char_field = serializers.CharField(required=False)
+
+ class ExampleView(views.APIView):
+ """Example view."""
+ def post(self, request):
+ pass
+
+ def put(self, request):
+ pass
+
+ def get_serializer(self):
+ return ExampleSerializer()
+
+ def get_object(self):
+ if self.request.method == 'PUT':
+ raise exceptions.PermissionDenied()
+
+ view = ExampleView.as_view()
+ response = view(request=request)
+ assert response.status_code == status.HTTP_200_OK
+ assert list(response.data['actions'].keys()) == ['POST']
+
+ def test_bug_2455_clone_request(self):
+ class ExampleView(views.APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ assert hasattr(self.request, 'version')
+ return serializers.Serializer()
+
+ view = ExampleView.as_view()
+ view(request=request)
+
+ def test_bug_2477_clone_request(self):
+ class ExampleView(views.APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+
+ def post(self, request):
+ pass
+
+ def get_serializer(self):
+ assert hasattr(self.request, 'versioning_scheme')
+ return serializers.Serializer()
+
+ scheme = versioning.QueryParameterVersioning
+ view = ExampleView.as_view(versioning_class=scheme)
+ view(request=request)
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
new file mode 100644
index 00000000..4c099fca
--- /dev/null
+++ b/tests/test_middleware.py
@@ -0,0 +1,37 @@
+
+from django.conf.urls import patterns, url
+from django.contrib.auth.models import User
+from rest_framework.authentication import TokenAuthentication
+from rest_framework.authtoken.models import Token
+from rest_framework.test import APITestCase
+from rest_framework.views import APIView
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^$', APIView.as_view(authentication_classes=(TokenAuthentication,))),
+)
+
+
+class MyMiddleware(object):
+
+ def process_response(self, request, response):
+ assert hasattr(request, 'user'), '`user` is not set on request'
+ assert request.user.is_authenticated(), '`user` is not authenticated'
+ return response
+
+
+class TestMiddleware(APITestCase):
+
+ urls = 'tests.test_middleware'
+
+ def test_middleware_can_access_user_when_processing_response(self):
+ user = User.objects.create_user('john', 'john@example.com', 'password')
+ key = 'abcd1234'
+ Token.objects.create(key=key, user=user)
+
+ with self.settings(
+ MIDDLEWARE_CLASSES=('tests.test_middleware.MyMiddleware',)
+ ):
+ auth = 'Token ' + key
+ self.client.get('/', HTTP_AUTHORIZATION=auth)
diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py
new file mode 100644
index 00000000..bce2008a
--- /dev/null
+++ b/tests/test_model_serializer.py
@@ -0,0 +1,641 @@
+"""
+The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially
+shortcuts for automatically creating serializers based on a given model class.
+
+These tests deal with ensuring that we correctly map the model fields onto
+an appropriate set of serializer fields for each case.
+"""
+from __future__ import unicode_literals
+from django.core.exceptions import ImproperlyConfigured
+from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator
+from django.db import models
+from django.test import TestCase
+from django.utils import six
+from rest_framework import serializers
+from rest_framework.compat import unicode_repr
+
+
+def dedent(blocktext):
+ return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
+
+
+# Tests for regular field mappings.
+# ---------------------------------
+
+class CustomField(models.Field):
+ """
+ A custom model field simply for testing purposes.
+ """
+ pass
+
+
+class OneFieldModel(models.Model):
+ char_field = models.CharField(max_length=100)
+
+
+class RegularFieldsModel(models.Model):
+ """
+ A model class for testing regular flat fields.
+ """
+ auto_field = models.AutoField(primary_key=True)
+ big_integer_field = models.BigIntegerField()
+ boolean_field = models.BooleanField(default=False)
+ char_field = models.CharField(max_length=100)
+ comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=100)
+ date_field = models.DateField()
+ datetime_field = models.DateTimeField()
+ decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
+ email_field = models.EmailField(max_length=100)
+ float_field = models.FloatField()
+ integer_field = models.IntegerField()
+ null_boolean_field = models.NullBooleanField()
+ positive_integer_field = models.PositiveIntegerField()
+ positive_small_integer_field = models.PositiveSmallIntegerField()
+ slug_field = models.SlugField(max_length=100)
+ small_integer_field = models.SmallIntegerField()
+ text_field = models.TextField()
+ time_field = models.TimeField()
+ url_field = models.URLField(max_length=100)
+ custom_field = CustomField()
+
+ def method(self):
+ return 'method'
+
+
+COLOR_CHOICES = (('red', 'Red'), ('blue', 'Blue'), ('green', 'Green'))
+
+
+class FieldOptionsModel(models.Model):
+ value_limit_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(10)])
+ length_limit_field = models.CharField(validators=[MinLengthValidator(3)], max_length=12)
+ blank_field = models.CharField(blank=True, max_length=10)
+ null_field = models.IntegerField(null=True)
+ default_field = models.IntegerField(default=0)
+ descriptive_field = models.IntegerField(help_text='Some help text', verbose_name='A label')
+ choices_field = models.CharField(max_length=100, choices=COLOR_CHOICES)
+
+
+class TestModelSerializer(TestCase):
+ def test_create_method(self):
+ class TestSerializer(serializers.ModelSerializer):
+ non_model_field = serializers.CharField()
+
+ class Meta:
+ model = OneFieldModel
+ fields = ('char_field', 'non_model_field')
+
+ serializer = TestSerializer(data={
+ 'char_field': 'foo',
+ 'non_model_field': 'bar',
+ })
+ serializer.is_valid()
+ with self.assertRaises(TypeError) as excinfo:
+ serializer.save()
+ msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
+ assert str(excinfo.exception).startswith(msginitial)
+
+
+class TestRegularFieldMappings(TestCase):
+ def test_regular_fields(self):
+ """
+ Model fields should map to their equivelent serializer fields.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+
+ expected = dedent("""
+ TestSerializer():
+ auto_field = IntegerField(read_only=True)
+ big_integer_field = IntegerField()
+ boolean_field = BooleanField(required=False)
+ char_field = CharField(max_length=100)
+ comma_separated_integer_field = CharField(max_length=100, validators=[<django.core.validators.RegexValidator object>])
+ date_field = DateField()
+ datetime_field = DateTimeField()
+ decimal_field = DecimalField(decimal_places=1, max_digits=3)
+ email_field = EmailField(max_length=100)
+ float_field = FloatField()
+ integer_field = IntegerField()
+ null_boolean_field = NullBooleanField(required=False)
+ positive_integer_field = IntegerField()
+ positive_small_integer_field = IntegerField()
+ slug_field = SlugField(max_length=100)
+ small_integer_field = IntegerField()
+ text_field = CharField(style={'base_template': 'textarea.html'})
+ time_field = TimeField()
+ url_field = URLField(max_length=100)
+ custom_field = ModelField(model_field=<tests.test_model_serializer.CustomField: custom_field>)
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_field_options(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FieldOptionsModel
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ value_limit_field = IntegerField(max_value=10, min_value=1)
+ length_limit_field = CharField(max_length=12, min_length=3)
+ blank_field = CharField(allow_blank=True, max_length=10, required=False)
+ null_field = IntegerField(allow_null=True, required=False)
+ default_field = IntegerField(required=False)
+ descriptive_field = IntegerField(help_text='Some help text', label='A label')
+ choices_field = ChoiceField(choices=[('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')])
+ """)
+ if six.PY2:
+ # This particular case is too awkward to resolve fully across
+ # both py2 and py3.
+ expected = expected.replace(
+ "('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')",
+ "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')"
+ )
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_method_field(self):
+ """
+ Properties and methods on the model should be allowed as `Meta.fields`
+ values, and should map to `ReadOnlyField`.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'method')
+
+ expected = dedent("""
+ TestSerializer():
+ auto_field = IntegerField(read_only=True)
+ method = ReadOnlyField()
+ """)
+ self.assertEqual(repr(TestSerializer()), expected)
+
+ def test_pk_fields(self):
+ """
+ Both `pk` and the actual primary key name are valid in `Meta.fields`.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('pk', 'auto_field')
+
+ expected = dedent("""
+ TestSerializer():
+ pk = IntegerField(label='Auto field', read_only=True)
+ auto_field = IntegerField(read_only=True)
+ """)
+ self.assertEqual(repr(TestSerializer()), expected)
+
+ def test_extra_field_kwargs(self):
+ """
+ Ensure `extra_kwargs` are passed to generated fields.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'char_field')
+ extra_kwargs = {'char_field': {'default': 'extra'}}
+
+ expected = dedent("""
+ TestSerializer():
+ auto_field = IntegerField(read_only=True)
+ char_field = CharField(default='extra', max_length=100)
+ """)
+ self.assertEqual(repr(TestSerializer()), expected)
+
+ def test_invalid_field(self):
+ """
+ Field names that do not map to a model field or relationship should
+ raise a configuration errror.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field', 'invalid')
+
+ with self.assertRaises(ImproperlyConfigured) as excinfo:
+ TestSerializer().fields
+ expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'
+ assert str(excinfo.exception) == expected
+
+ def test_missing_field(self):
+ """
+ Fields that have been declared on the serializer class must be included
+ in the `Meta.fields` if it exists.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ missing = serializers.ReadOnlyField()
+
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field',)
+
+ with self.assertRaises(AssertionError) as excinfo:
+ TestSerializer().fields
+ expected = (
+ "The field 'missing' was declared on serializer TestSerializer, "
+ "but has not been included in the 'fields' option."
+ )
+ assert str(excinfo.exception) == expected
+
+ def test_missing_superclass_field(self):
+ """
+ Fields that have been declared on a parent of the serializer class may
+ be excluded from the `Meta.fields` option.
+ """
+ class TestSerializer(serializers.ModelSerializer):
+ missing = serializers.ReadOnlyField()
+
+ class Meta:
+ model = RegularFieldsModel
+
+ class ChildSerializer(TestSerializer):
+ missing = serializers.ReadOnlyField()
+
+ class Meta:
+ model = RegularFieldsModel
+ fields = ('auto_field',)
+
+ ChildSerializer().fields
+
+
+# Tests for relational field mappings.
+# ------------------------------------
+
+class ForeignKeyTargetModel(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ManyToManyTargetModel(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class OneToOneTargetModel(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class ThroughTargetModel(models.Model):
+ name = models.CharField(max_length=100)
+
+
+class Supplementary(models.Model):
+ extra = models.IntegerField()
+ forwards = models.ForeignKey('ThroughTargetModel')
+ backwards = models.ForeignKey('RelationalModel')
+
+
+class RelationalModel(models.Model):
+ foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='reverse_foreign_key')
+ many_to_many = models.ManyToManyField(ManyToManyTargetModel, related_name='reverse_many_to_many')
+ one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one')
+ through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through')
+
+
+class TestRelationalFieldMappings(TestCase):
+ def test_pk_relations(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all())
+ one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>])
+ many_to_many = PrimaryKeyRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all())
+ through = PrimaryKeyRelatedField(many=True, read_only=True)
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_nested_relations(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ foreign_key = NestedSerializer(read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ one_to_one = NestedSerializer(read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ many_to_many = NestedSerializer(many=True, read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ through = NestedSerializer(many=True, read_only=True):
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_hyperlinked_relations(self):
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+
+ expected = dedent("""
+ TestSerializer():
+ url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
+ foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail')
+ one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[<UniqueValidator(queryset=RelationalModel.objects.all())>], view_name='onetoonetargetmodel-detail')
+ many_to_many = HyperlinkedRelatedField(many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
+ through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_nested_hyperlinked_relations(self):
+ class TestSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = RelationalModel
+ depth = 1
+
+ expected = dedent("""
+ TestSerializer():
+ url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
+ foreign_key = NestedSerializer(read_only=True):
+ url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
+ name = CharField(max_length=100)
+ one_to_one = NestedSerializer(read_only=True):
+ url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
+ name = CharField(max_length=100)
+ many_to_many = NestedSerializer(many=True, read_only=True):
+ url = HyperlinkedIdentityField(view_name='manytomanytargetmodel-detail')
+ name = CharField(max_length=100)
+ through = NestedSerializer(many=True, read_only=True):
+ url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
+ name = CharField(max_length=100)
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_pk_reverse_foreign_key(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeyTargetModel
+ fields = ('id', 'name', 'reverse_foreign_key')
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_pk_reverse_one_to_one(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTargetModel
+ fields = ('id', 'name', 'reverse_one_to_one')
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_pk_reverse_many_to_many(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyTargetModel
+ fields = ('id', 'name', 'reverse_many_to_many')
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+ def test_pk_reverse_through(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ThroughTargetModel
+ fields = ('id', 'name', 'reverse_through')
+
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ name = CharField(max_length=100)
+ reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
+ """)
+ self.assertEqual(unicode_repr(TestSerializer()), expected)
+
+
+class TestIntegration(TestCase):
+ def setUp(self):
+ self.foreign_key_target = ForeignKeyTargetModel.objects.create(
+ name='foreign_key'
+ )
+ self.one_to_one_target = OneToOneTargetModel.objects.create(
+ name='one_to_one'
+ )
+ self.many_to_many_targets = [
+ ManyToManyTargetModel.objects.create(
+ name='many_to_many (%d)' % idx
+ ) for idx in range(3)
+ ]
+ self.instance = RelationalModel.objects.create(
+ foreign_key=self.foreign_key_target,
+ one_to_one=self.one_to_one_target,
+ )
+ self.instance.many_to_many = self.many_to_many_targets
+ self.instance.save()
+
+ def test_pk_retrival(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+
+ serializer = TestSerializer(self.instance)
+ expected = {
+ 'id': self.instance.pk,
+ 'foreign_key': self.foreign_key_target.pk,
+ 'one_to_one': self.one_to_one_target.pk,
+ 'many_to_many': [item.pk for item in self.many_to_many_targets],
+ 'through': []
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_pk_create(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+
+ new_foreign_key = ForeignKeyTargetModel.objects.create(
+ name='foreign_key'
+ )
+ new_one_to_one = OneToOneTargetModel.objects.create(
+ name='one_to_one'
+ )
+ new_many_to_many = [
+ ManyToManyTargetModel.objects.create(
+ name='new many_to_many (%d)' % idx
+ ) for idx in range(3)
+ ]
+ data = {
+ 'foreign_key': new_foreign_key.pk,
+ 'one_to_one': new_one_to_one.pk,
+ 'many_to_many': [item.pk for item in new_many_to_many],
+ }
+
+ # Serializer should validate okay.
+ serializer = TestSerializer(data=data)
+ assert serializer.is_valid()
+
+ # Creating the instance, relationship attributes should be set.
+ instance = serializer.save()
+ assert instance.foreign_key.pk == new_foreign_key.pk
+ assert instance.one_to_one.pk == new_one_to_one.pk
+ assert [
+ item.pk for item in instance.many_to_many.all()
+ ] == [
+ item.pk for item in new_many_to_many
+ ]
+ assert list(instance.through.all()) == []
+
+ # Representation should be correct.
+ expected = {
+ 'id': instance.pk,
+ 'foreign_key': new_foreign_key.pk,
+ 'one_to_one': new_one_to_one.pk,
+ 'many_to_many': [item.pk for item in new_many_to_many],
+ 'through': []
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_pk_update(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = RelationalModel
+
+ new_foreign_key = ForeignKeyTargetModel.objects.create(
+ name='foreign_key'
+ )
+ new_one_to_one = OneToOneTargetModel.objects.create(
+ name='one_to_one'
+ )
+ new_many_to_many = [
+ ManyToManyTargetModel.objects.create(
+ name='new many_to_many (%d)' % idx
+ ) for idx in range(3)
+ ]
+ data = {
+ 'foreign_key': new_foreign_key.pk,
+ 'one_to_one': new_one_to_one.pk,
+ 'many_to_many': [item.pk for item in new_many_to_many],
+ }
+
+ # Serializer should validate okay.
+ serializer = TestSerializer(self.instance, data=data)
+ assert serializer.is_valid()
+
+ # Creating the instance, relationship attributes should be set.
+ instance = serializer.save()
+ assert instance.foreign_key.pk == new_foreign_key.pk
+ assert instance.one_to_one.pk == new_one_to_one.pk
+ assert [
+ item.pk for item in instance.many_to_many.all()
+ ] == [
+ item.pk for item in new_many_to_many
+ ]
+ assert list(instance.through.all()) == []
+
+ # Representation should be correct.
+ expected = {
+ 'id': self.instance.pk,
+ 'foreign_key': new_foreign_key.pk,
+ 'one_to_one': new_one_to_one.pk,
+ 'many_to_many': [item.pk for item in new_many_to_many],
+ 'through': []
+ }
+ self.assertEqual(serializer.data, expected)
+
+
+# Tests for bulk create using `ListSerializer`.
+
+class BulkCreateModel(models.Model):
+ name = models.CharField(max_length=10)
+
+
+class TestBulkCreate(TestCase):
+ def test_bulk_create(self):
+ class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BulkCreateModel
+ fields = ('name',)
+
+ class BulkCreateSerializer(serializers.ListSerializer):
+ child = BasicModelSerializer()
+
+ data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}]
+ serializer = BulkCreateSerializer(data=data)
+ assert serializer.is_valid()
+
+ # Objects are returned by save().
+ instances = serializer.save()
+ assert len(instances) == 3
+ assert [item.name for item in instances] == ['a', 'b', 'c']
+
+ # Objects have been created in the database.
+ assert BulkCreateModel.objects.count() == 3
+ assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c']
+
+ # Serializer returns correct data.
+ assert serializer.data == data
+
+
+class TestMetaClassModel(models.Model):
+ text = models.CharField(max_length=100)
+
+
+class TestSerializerMetaClass(TestCase):
+ def test_meta_class_fields_option(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TestMetaClassModel
+ fields = 'text'
+
+ with self.assertRaises(TypeError) as result:
+ ExampleSerializer().fields
+
+ exception = result.exception
+ assert str(exception).startswith(
+ "The `fields` option must be a list or tuple"
+ )
+
+ def test_meta_class_exclude_option(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TestMetaClassModel
+ exclude = 'text'
+
+ with self.assertRaises(TypeError) as result:
+ ExampleSerializer().fields
+
+ exception = result.exception
+ assert str(exception).startswith(
+ "The `exclude` option must be a list or tuple"
+ )
+
+ def test_meta_class_fields_and_exclude_options(self):
+ class ExampleSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TestMetaClassModel
+ fields = ('text',)
+ exclude = ('text',)
+
+ with self.assertRaises(AssertionError) as result:
+ ExampleSerializer().fields
+
+ exception = result.exception
+ self.assertEqual(
+ str(exception),
+ "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
+ )
diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py
new file mode 100644
index 00000000..15627e1d
--- /dev/null
+++ b/tests/test_multitable_inheritance.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import RESTFrameworkModel
+
+
+# Models
+class ParentModel(RESTFrameworkModel):
+ name1 = models.CharField(max_length=100)
+
+
+class ChildModel(ParentModel):
+ name2 = models.CharField(max_length=100)
+
+
+class AssociatedModel(RESTFrameworkModel):
+ ref = models.OneToOneField(ParentModel, primary_key=True)
+ name = models.CharField(max_length=100)
+
+
+# Serializers
+class DerivedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ChildModel
+
+
+class AssociatedModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AssociatedModel
+
+
+# Tests
+class InheritedModelSerializationTests(TestCase):
+
+ def test_multitable_inherited_model_fields_as_expected(self):
+ """
+ Assert that the parent pointer field is not included in the fields
+ serialized fields
+ """
+ child = ChildModel(name1='parent name', name2='child name')
+ serializer = DerivedModelSerializer(child)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name1', 'name2', 'id']))
+
+ def test_onetoone_primary_key_model_fields_as_expected(self):
+ """
+ Assert that a model with a onetoone field that is the primary key is
+ not treated like a derived model
+ """
+ parent = ParentModel.objects.create(name1='parent name')
+ associate = AssociatedModel.objects.create(name='hello', ref=parent)
+ serializer = AssociatedModelSerializer(associate)
+ self.assertEqual(set(serializer.data.keys()),
+ set(['name', 'ref']))
+
+ def test_data_is_valid_without_parent_ptr(self):
+ """
+ Assert that the pointer to the parent table is not a required field
+ for input data
+ """
+ data = {
+ 'name1': 'parent name',
+ 'name2': 'child name',
+ }
+ serializer = DerivedModelSerializer(data=data)
+ self.assertEqual(serializer.is_valid(), True)
diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py
new file mode 100644
index 00000000..04b89eb6
--- /dev/null
+++ b/tests/test_negotiation.py
@@ -0,0 +1,45 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.negotiation import DefaultContentNegotiation
+from rest_framework.request import Request
+from rest_framework.renderers import BaseRenderer
+from rest_framework.test import APIRequestFactory
+
+
+factory = APIRequestFactory()
+
+
+class MockJSONRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+class MockHTMLRenderer(BaseRenderer):
+ media_type = 'text/html'
+
+
+class NoCharsetSpecifiedRenderer(BaseRenderer):
+ media_type = 'my/media'
+
+
+class TestAcceptedMediaType(TestCase):
+ def setUp(self):
+ self.renderers = [MockJSONRenderer(), MockHTMLRenderer()]
+ self.negotiator = DefaultContentNegotiation()
+
+ def select_renderer(self, request):
+ return self.negotiator.select_renderer(request, self.renderers)
+
+ def test_client_without_accept_use_renderer(self):
+ request = Request(factory.get('/'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json')
+
+ def test_client_underspecifies_accept_use_renderer(self):
+ request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json')
+
+ def test_client_overspecifies_accept_use_client(self):
+ request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
+ accepted_renderer, accepted_media_type = self.select_renderer(request)
+ self.assertEqual(accepted_media_type, 'application/json; indent=8')
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
new file mode 100644
index 00000000..6b39a6f2
--- /dev/null
+++ b/tests/test_pagination.py
@@ -0,0 +1,671 @@
+# coding: utf-8
+from __future__ import unicode_literals
+from rest_framework import exceptions, generics, pagination, serializers, status, filters
+from rest_framework.request import Request
+from rest_framework.pagination import PageLink, PAGE_BREAK
+from rest_framework.test import APIRequestFactory
+import pytest
+
+factory = APIRequestFactory()
+
+
+class TestPaginationIntegration:
+ """
+ Integration tests.
+ """
+
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
+
+ class EvenItemsOnly(filters.BaseFilterBackend):
+ def filter_queryset(self, request, queryset, view):
+ return [item for item in queryset if item % 2 == 0]
+
+ class BasicPagination(pagination.PageNumberPagination):
+ page_size = 5
+ page_size_query_param = 'page_size'
+ max_page_size = 20
+
+ self.view = generics.ListAPIView.as_view(
+ serializer_class=PassThroughSerializer,
+ queryset=range(1, 101),
+ filter_backends=[EvenItemsOnly],
+ pagination_class=BasicPagination
+ )
+
+ def test_filtered_items_are_paginated(self):
+ request = factory.get('/', {'page': 2})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [12, 14, 16, 18, 20],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page=3',
+ 'count': 50
+ }
+
+ def test_setting_page_size(self):
+ """
+ When 'paginate_by_param' is set, the client may choose a page size.
+ """
+ request = factory.get('/', {'page_size': 10})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=10',
+ 'count': 50
+ }
+
+ def test_setting_page_size_over_maximum(self):
+ """
+ When page_size parameter exceeds maxiumum allowable,
+ then it should be capped to the maxiumum.
+ """
+ request = factory.get('/', {'page_size': 1000})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [
+ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
+ 22, 24, 26, 28, 30, 32, 34, 36, 38, 40
+ ],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=1000',
+ 'count': 50
+ }
+
+ def test_setting_page_size_to_zero(self):
+ """
+ When page_size parameter is invalid it should return to the default.
+ """
+ request = factory.get('/', {'page_size': 0})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [2, 4, 6, 8, 10],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2&page_size=0',
+ 'count': 50
+ }
+
+ def test_additional_query_params_are_preserved(self):
+ request = factory.get('/', {'page': 2, 'filter': 'even'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [12, 14, 16, 18, 20],
+ 'previous': 'http://testserver/?filter=even',
+ 'next': 'http://testserver/?filter=even&page=3',
+ 'count': 50
+ }
+
+ def test_404_not_found_for_zero_page(self):
+ request = factory.get('/', {'page': '0'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "0": That page number is less than 1.'
+ }
+
+ def test_404_not_found_for_invalid_page(self):
+ request = factory.get('/', {'page': 'invalid'})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+ assert response.data == {
+ 'detail': 'Invalid page "invalid": That page number is not an integer.'
+ }
+
+
+class TestPaginationDisabledIntegration:
+ """
+ Integration tests for disabled pagination.
+ """
+
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
+
+ self.view = generics.ListAPIView.as_view(
+ serializer_class=PassThroughSerializer,
+ queryset=range(1, 101),
+ pagination_class=None
+ )
+
+ def test_unpaginated_list(self):
+ request = factory.get('/', {'page': 2})
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == list(range(1, 101))
+
+
+class TestDeprecatedStylePagination:
+ """
+ Integration tests for deprecated style of setting pagination
+ attributes on the view.
+ """
+
+ def setup(self):
+ class PassThroughSerializer(serializers.BaseSerializer):
+ def to_representation(self, item):
+ return item
+
+ class ExampleView(generics.ListAPIView):
+ serializer_class = PassThroughSerializer
+ queryset = range(1, 101)
+ pagination_class = pagination.PageNumberPagination
+ paginate_by = 20
+ page_query_param = 'page_number'
+
+ self.view = ExampleView.as_view()
+
+ def test_paginate_by_attribute_on_view(self):
+ request = factory.get('/?page_number=2')
+ response = self.view(request)
+ assert response.status_code == status.HTTP_200_OK
+ assert response.data == {
+ 'results': [
+ 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+ 31, 32, 33, 34, 35, 36, 37, 38, 39, 40
+ ],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page_number=3',
+ 'count': 100
+ }
+
+
+class TestPageNumberPagination:
+ """
+ Unit tests for `pagination.PageNumberPagination`.
+ """
+
+ def setup(self):
+ class ExamplePagination(pagination.PageNumberPagination):
+ page_size = 5
+ self.pagination = ExamplePagination()
+ self.queryset = range(1, 101)
+
+ def paginate_queryset(self, request):
+ return list(self.pagination.paginate_queryset(self.queryset, request))
+
+ def get_paginated_content(self, queryset):
+ response = self.pagination.get_paginated_response(queryset)
+ return response.data
+
+ def get_html_context(self):
+ return self.pagination.get_html_context()
+
+ def test_no_page_number(self):
+ request = Request(factory.get('/'))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [1, 2, 3, 4, 5]
+ assert content == {
+ 'results': [1, 2, 3, 4, 5],
+ 'previous': None,
+ 'next': 'http://testserver/?page=2',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': None,
+ 'next_url': 'http://testserver/?page=2',
+ 'page_links': [
+ PageLink('http://testserver/', 1, True, False),
+ PageLink('http://testserver/?page=2', 2, False, False),
+ PageLink('http://testserver/?page=3', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=20', 20, False, False),
+ ]
+ }
+ assert self.pagination.display_page_controls
+ assert isinstance(self.pagination.to_html(), type(''))
+
+ def test_second_page(self):
+ request = Request(factory.get('/', {'page': 2}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [6, 7, 8, 9, 10]
+ assert content == {
+ 'results': [6, 7, 8, 9, 10],
+ 'previous': 'http://testserver/',
+ 'next': 'http://testserver/?page=3',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/',
+ 'next_url': 'http://testserver/?page=3',
+ 'page_links': [
+ PageLink('http://testserver/', 1, False, False),
+ PageLink('http://testserver/?page=2', 2, True, False),
+ PageLink('http://testserver/?page=3', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=20', 20, False, False),
+ ]
+ }
+
+ def test_last_page(self):
+ request = Request(factory.get('/', {'page': 'last'}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [96, 97, 98, 99, 100]
+ assert content == {
+ 'results': [96, 97, 98, 99, 100],
+ 'previous': 'http://testserver/?page=19',
+ 'next': None,
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?page=19',
+ 'next_url': None,
+ 'page_links': [
+ PageLink('http://testserver/', 1, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?page=18', 18, False, False),
+ PageLink('http://testserver/?page=19', 19, False, False),
+ PageLink('http://testserver/?page=20', 20, True, False),
+ ]
+ }
+
+ def test_invalid_page(self):
+ request = Request(factory.get('/', {'page': 'invalid'}))
+ with pytest.raises(exceptions.NotFound):
+ self.paginate_queryset(request)
+
+
+class TestLimitOffset:
+ """
+ Unit tests for `pagination.LimitOffsetPagination`.
+ """
+
+ def setup(self):
+ class ExamplePagination(pagination.LimitOffsetPagination):
+ default_limit = 10
+ self.pagination = ExamplePagination()
+ self.queryset = range(1, 101)
+
+ def paginate_queryset(self, request):
+ return list(self.pagination.paginate_queryset(self.queryset, request))
+
+ def get_paginated_content(self, queryset):
+ response = self.pagination.get_paginated_response(queryset)
+ return response.data
+
+ def get_html_context(self):
+ return self.pagination.get_html_context()
+
+ def test_no_offset(self):
+ request = Request(factory.get('/', {'limit': 5}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [1, 2, 3, 4, 5]
+ assert content == {
+ 'results': [1, 2, 3, 4, 5],
+ 'previous': None,
+ 'next': 'http://testserver/?limit=5&offset=5',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': None,
+ 'next_url': 'http://testserver/?limit=5&offset=5',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, True, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+ assert self.pagination.display_page_controls
+ assert isinstance(self.pagination.to_html(), type(''))
+
+ def test_single_offset(self):
+ """
+ When the offset is not a multiple of the limit we get some edge cases:
+ * The first page should still be offset zero.
+ * We may end up displaying an extra page in the pagination control.
+ """
+ request = Request(factory.get('/', {'limit': 5, 'offset': 1}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [2, 3, 4, 5, 6]
+ assert content == {
+ 'results': [2, 3, 4, 5, 6],
+ 'previous': 'http://testserver/?limit=5',
+ 'next': 'http://testserver/?limit=5&offset=6',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5',
+ 'next_url': 'http://testserver/?limit=5&offset=6',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=1', 2, True, False),
+ PageLink('http://testserver/?limit=5&offset=6', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=96', 21, False, False),
+ ]
+ }
+
+ def test_first_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 5}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [6, 7, 8, 9, 10]
+ assert content == {
+ 'results': [6, 7, 8, 9, 10],
+ 'previous': 'http://testserver/?limit=5',
+ 'next': 'http://testserver/?limit=5&offset=10',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5',
+ 'next_url': 'http://testserver/?limit=5&offset=10',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, True, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+
+ def test_middle_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 10}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [11, 12, 13, 14, 15]
+ assert content == {
+ 'results': [11, 12, 13, 14, 15],
+ 'previous': 'http://testserver/?limit=5&offset=5',
+ 'next': 'http://testserver/?limit=5&offset=15',
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5&offset=5',
+ 'next_url': 'http://testserver/?limit=5&offset=15',
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
+ PageLink('http://testserver/?limit=5&offset=10', 3, True, False),
+ PageLink('http://testserver/?limit=5&offset=15', 4, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
+ ]
+ }
+
+ def test_ending_offset(self):
+ request = Request(factory.get('/', {'limit': 5, 'offset': 95}))
+ queryset = self.paginate_queryset(request)
+ content = self.get_paginated_content(queryset)
+ context = self.get_html_context()
+ assert queryset == [96, 97, 98, 99, 100]
+ assert content == {
+ 'results': [96, 97, 98, 99, 100],
+ 'previous': 'http://testserver/?limit=5&offset=90',
+ 'next': None,
+ 'count': 100
+ }
+ assert context == {
+ 'previous_url': 'http://testserver/?limit=5&offset=90',
+ 'next_url': None,
+ 'page_links': [
+ PageLink('http://testserver/?limit=5', 1, False, False),
+ PAGE_BREAK,
+ PageLink('http://testserver/?limit=5&offset=85', 18, False, False),
+ PageLink('http://testserver/?limit=5&offset=90', 19, False, False),
+ PageLink('http://testserver/?limit=5&offset=95', 20, True, False),
+ ]
+ }
+
+ def test_invalid_offset(self):
+ """
+ An invalid offset query param should be treated as 0.
+ """
+ request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'}))
+ queryset = self.paginate_queryset(request)
+ assert queryset == [1, 2, 3, 4, 5]
+
+ def test_invalid_limit(self):
+ """
+ An invalid limit query param should be ignored in favor of the default.
+ """
+ request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0}))
+ queryset = self.paginate_queryset(request)
+ assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+
+class TestCursorPagination:
+ """
+ Unit tests for `pagination.CursorPagination`.
+ """
+
+ def setup(self):
+ class MockObject(object):
+ def __init__(self, idx):
+ self.created = idx
+
+ class MockQuerySet(object):
+ def __init__(self, items):
+ self.items = items
+
+ def filter(self, created__gt=None, created__lt=None):
+ if created__gt is not None:
+ return MockQuerySet([
+ item for item in self.items
+ if item.created > int(created__gt)
+ ])
+
+ assert created__lt is not None
+ return MockQuerySet([
+ item for item in self.items
+ if item.created < int(created__lt)
+ ])
+
+ def order_by(self, *ordering):
+ if ordering[0].startswith('-'):
+ return MockQuerySet(list(reversed(self.items)))
+ return self
+
+ def __getitem__(self, sliced):
+ return self.items[sliced]
+
+ class ExamplePagination(pagination.CursorPagination):
+ page_size = 5
+ ordering = 'created'
+
+ self.pagination = ExamplePagination()
+ self.queryset = MockQuerySet([
+ MockObject(idx) for idx in [
+ 1, 1, 1, 1, 1,
+ 1, 2, 3, 4, 4,
+ 4, 4, 5, 6, 7,
+ 7, 7, 7, 7, 7,
+ 7, 7, 7, 8, 9,
+ 9, 9, 9, 9, 9
+ ]
+ ])
+
+ def get_pages(self, url):
+ """
+ Given a URL return a tuple of:
+
+ (previous page, current page, next page, previous url, next url)
+ """
+ request = Request(factory.get(url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ current = [item.created for item in queryset]
+
+ next_url = self.pagination.get_next_link()
+ previous_url = self.pagination.get_previous_link()
+
+ if next_url is not None:
+ request = Request(factory.get(next_url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ next = [item.created for item in queryset]
+ else:
+ next = None
+
+ if previous_url is not None:
+ request = Request(factory.get(previous_url))
+ queryset = self.pagination.paginate_queryset(self.queryset, request)
+ previous = [item.created for item in queryset]
+ else:
+ previous = None
+
+ return (previous, current, next, previous_url, next_url)
+
+ def test_invalid_cursor(self):
+ request = Request(factory.get('/', {'cursor': '123'}))
+ with pytest.raises(exceptions.NotFound):
+ self.pagination.paginate_queryset(self.queryset, request)
+
+ def test_use_with_ordering_filter(self):
+ class MockView:
+ filter_backends = (filters.OrderingFilter,)
+ ordering_fields = ['username', 'created']
+ ordering = 'created'
+
+ request = Request(factory.get('/', {'ordering': 'username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('username',)
+
+ request = Request(factory.get('/', {'ordering': '-username'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('-username',)
+
+ request = Request(factory.get('/', {'ordering': 'invalid'}))
+ ordering = self.pagination.get_ordering(request, [], MockView())
+ assert ordering == ('created',)
+
+ def test_cursor_pagination(self):
+ (previous, current, next, previous_url, next_url) = self.get_pages('/')
+
+ assert previous is None
+ assert current == [1, 1, 1, 1, 1]
+ assert next == [1, 2, 3, 4, 4]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+
+ assert previous == [1, 1, 1, 1, 1]
+ assert current == [1, 2, 3, 4, 4]
+ assert next == [4, 4, 5, 6, 7]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+
+ assert previous == [1, 2, 3, 4, 4]
+ assert current == [4, 4, 5, 6, 7]
+ assert next == [7, 7, 7, 7, 7]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+
+ assert previous == [4, 4, 4, 5, 6] # Paging artifact
+ assert current == [7, 7, 7, 7, 7]
+ assert next == [7, 7, 7, 8, 9]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+
+ assert previous == [7, 7, 7, 7, 7]
+ assert current == [7, 7, 7, 8, 9]
+ assert next == [9, 9, 9, 9, 9]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(next_url)
+
+ assert previous == [7, 7, 7, 8, 9]
+ assert current == [9, 9, 9, 9, 9]
+ assert next is None
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+
+ assert previous == [7, 7, 7, 7, 7]
+ assert current == [7, 7, 7, 8, 9]
+ assert next == [9, 9, 9, 9, 9]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+
+ assert previous == [4, 4, 5, 6, 7]
+ assert current == [7, 7, 7, 7, 7]
+ assert next == [8, 9, 9, 9, 9] # Paging artifact
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+
+ assert previous == [1, 2, 3, 4, 4]
+ assert current == [4, 4, 5, 6, 7]
+ assert next == [7, 7, 7, 7, 7]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+
+ assert previous == [1, 1, 1, 1, 1]
+ assert current == [1, 2, 3, 4, 4]
+ assert next == [4, 4, 5, 6, 7]
+
+ (previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
+
+ assert previous is None
+ assert current == [1, 1, 1, 1, 1]
+ assert next == [1, 2, 3, 4, 4]
+
+ assert isinstance(self.pagination.to_html(), type(''))
+
+
+def test_get_displayed_page_numbers():
+ """
+ Test our contextual page display function.
+
+ This determines which pages to display in a pagination control,
+ given the current page and the last page.
+ """
+ displayed_page_numbers = pagination._get_displayed_page_numbers
+
+ # At five pages or less, all pages are displayed, always.
+ assert displayed_page_numbers(1, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(2, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(3, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(4, 5) == [1, 2, 3, 4, 5]
+ assert displayed_page_numbers(5, 5) == [1, 2, 3, 4, 5]
+
+ # Between six and either pages we may have a single page break.
+ assert displayed_page_numbers(1, 6) == [1, 2, 3, None, 6]
+ assert displayed_page_numbers(2, 6) == [1, 2, 3, None, 6]
+ assert displayed_page_numbers(3, 6) == [1, 2, 3, 4, 5, 6]
+ assert displayed_page_numbers(4, 6) == [1, 2, 3, 4, 5, 6]
+ assert displayed_page_numbers(5, 6) == [1, None, 4, 5, 6]
+ assert displayed_page_numbers(6, 6) == [1, None, 4, 5, 6]
+
+ assert displayed_page_numbers(1, 7) == [1, 2, 3, None, 7]
+ assert displayed_page_numbers(2, 7) == [1, 2, 3, None, 7]
+ assert displayed_page_numbers(3, 7) == [1, 2, 3, 4, None, 7]
+ assert displayed_page_numbers(4, 7) == [1, 2, 3, 4, 5, 6, 7]
+ assert displayed_page_numbers(5, 7) == [1, None, 4, 5, 6, 7]
+ assert displayed_page_numbers(6, 7) == [1, None, 5, 6, 7]
+ assert displayed_page_numbers(7, 7) == [1, None, 5, 6, 7]
+
+ assert displayed_page_numbers(1, 8) == [1, 2, 3, None, 8]
+ assert displayed_page_numbers(2, 8) == [1, 2, 3, None, 8]
+ assert displayed_page_numbers(3, 8) == [1, 2, 3, 4, None, 8]
+ assert displayed_page_numbers(4, 8) == [1, 2, 3, 4, 5, None, 8]
+ assert displayed_page_numbers(5, 8) == [1, None, 4, 5, 6, 7, 8]
+ assert displayed_page_numbers(6, 8) == [1, None, 5, 6, 7, 8]
+ assert displayed_page_numbers(7, 8) == [1, None, 6, 7, 8]
+ assert displayed_page_numbers(8, 8) == [1, None, 6, 7, 8]
+
+ # At nine or more pages we may have two page breaks, one on each side.
+ assert displayed_page_numbers(1, 9) == [1, 2, 3, None, 9]
+ assert displayed_page_numbers(2, 9) == [1, 2, 3, None, 9]
+ assert displayed_page_numbers(3, 9) == [1, 2, 3, 4, None, 9]
+ assert displayed_page_numbers(4, 9) == [1, 2, 3, 4, 5, None, 9]
+ assert displayed_page_numbers(5, 9) == [1, None, 4, 5, 6, None, 9]
+ assert displayed_page_numbers(6, 9) == [1, None, 5, 6, 7, 8, 9]
+ assert displayed_page_numbers(7, 9) == [1, None, 6, 7, 8, 9]
+ assert displayed_page_numbers(8, 9) == [1, None, 7, 8, 9]
+ assert displayed_page_numbers(9, 9) == [1, None, 7, 8, 9]
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
new file mode 100644
index 00000000..fe6aec19
--- /dev/null
+++ b/tests/test_parsers.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import unicode_literals
+from django import forms
+from django.core.files.uploadhandler import MemoryFileUploadHandler
+from django.test import TestCase
+from django.utils.six.moves import StringIO
+from rest_framework.exceptions import ParseError
+from rest_framework.parsers import FormParser, FileUploadParser
+
+
+class Form(forms.Form):
+ field1 = forms.CharField(max_length=3)
+ field2 = forms.CharField()
+
+
+class TestFormParser(TestCase):
+ def setUp(self):
+ self.string = "field1=abc&field2=defghijk"
+
+ def test_parse(self):
+ """ Make sure the `QueryDict` works OK """
+ parser = FormParser()
+
+ stream = StringIO(self.string)
+ data = parser.parse(stream)
+
+ self.assertEqual(Form(data).is_valid(), True)
+
+
+class TestFileUploadParser(TestCase):
+ def setUp(self):
+ class MockRequest(object):
+ pass
+ from io import BytesIO
+ self.stream = BytesIO(
+ "Test text file".encode('utf-8')
+ )
+ request = MockRequest()
+ request.upload_handlers = (MemoryFileUploadHandler(),)
+ request.META = {
+ 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
+ 'HTTP_CONTENT_LENGTH': 14,
+ }
+ self.parser_context = {'request': request, 'kwargs': {}}
+
+ def test_parse(self):
+ """
+ Parse raw file upload.
+ """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ data_and_files = parser.parse(self.stream, None, self.parser_context)
+ file_obj = data_and_files.files['file']
+ self.assertEqual(file_obj._size, 14)
+
+ def test_parse_missing_filename(self):
+ """
+ Parse raw file upload when filename is missing.
+ """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ with self.assertRaises(ParseError):
+ parser.parse(self.stream, None, self.parser_context)
+
+ def test_parse_missing_filename_multiple_upload_handlers(self):
+ """
+ Parse raw file upload with multiple handlers when filename is missing.
+ Regression test for #2109.
+ """
+ parser = FileUploadParser()
+ self.stream.seek(0)
+ self.parser_context['request'].upload_handlers = (
+ MemoryFileUploadHandler(),
+ MemoryFileUploadHandler()
+ )
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ with self.assertRaises(ParseError):
+ parser.parse(self.stream, None, self.parser_context)
+
+ def test_get_filename(self):
+ parser = FileUploadParser()
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'file.txt')
+
+ def test_get_encoded_filename(self):
+ parser = FileUploadParser()
+
+ self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'ÀĥƦ.txt')
+
+ self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'ÀĥƦ.txt')
+
+ self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
+ filename = parser.get_filename(self.stream, None, self.parser_context)
+ self.assertEqual(filename, 'ÀĥƦ.txt')
+
+ def __replace_content_disposition(self, disposition):
+ self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
diff --git a/tests/test_permissions.py b/tests/test_permissions.py
new file mode 100644
index 00000000..97bac33d
--- /dev/null
+++ b/tests/test_permissions.py
@@ -0,0 +1,312 @@
+from __future__ import unicode_literals
+from django.contrib.auth.models import User, Permission, Group
+from django.db import models
+from django.test import TestCase
+from django.utils import unittest
+from rest_framework import generics, serializers, status, permissions, authentication, HTTP_HEADER_ENCODING
+from rest_framework.compat import guardian, get_model_name
+from rest_framework.filters import DjangoObjectPermissionsFilter
+from rest_framework.test import APIRequestFactory
+from tests.models import BasicModel
+import base64
+
+factory = APIRequestFactory()
+
+
+class BasicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
+class RootView(generics.ListCreateAPIView):
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+
+class InstanceView(generics.RetrieveUpdateDestroyAPIView):
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [permissions.DjangoModelPermissions]
+
+root_view = RootView.as_view()
+instance_view = InstanceView.as_view()
+
+
+def basic_auth_header(username, password):
+ credentials = ('%s:%s' % (username, password))
+ base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
+ return 'Basic %s' % base64_credentials
+
+
+class ModelPermissionsIntegrationTests(TestCase):
+ def setUp(self):
+ User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
+ user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='add_basicmodel'),
+ Permission.objects.get(codename='change_basicmodel'),
+ Permission.objects.get(codename='delete_basicmodel')
+ ]
+ user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
+ user.user_permissions = [
+ Permission.objects.get(codename='change_basicmodel'),
+ ]
+
+ self.permitted_credentials = basic_auth_header('permitted', 'password')
+ self.disallowed_credentials = basic_auth_header('disallowed', 'password')
+ self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+
+ BasicModel(text='foo').save()
+
+ def test_has_create_permissions(self):
+ request = factory.post('/', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_201_CREATED)
+
+ def test_has_put_permissions(self):
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_has_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_does_not_have_create_permissions(self):
+ request = factory.post('/', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_put_permissions(self):
+ request = factory.put('/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_does_not_have_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk=1)
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ def test_options_permitted(self):
+ request = factory.options(
+ '/',
+ HTTP_AUTHORIZATION=self.permitted_credentials
+ )
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['POST'])
+
+ request = factory.options(
+ '/1',
+ HTTP_AUTHORIZATION=self.permitted_credentials
+ )
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+ def test_options_disallowed(self):
+ request = factory.options(
+ '/',
+ HTTP_AUTHORIZATION=self.disallowed_credentials
+ )
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options(
+ '/1',
+ HTTP_AUTHORIZATION=self.disallowed_credentials
+ )
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ def test_options_updateonly(self):
+ request = factory.options(
+ '/',
+ HTTP_AUTHORIZATION=self.updateonly_credentials
+ )
+ response = root_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertNotIn('actions', response.data)
+
+ request = factory.options(
+ '/1',
+ HTTP_AUTHORIZATION=self.updateonly_credentials
+ )
+ response = instance_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn('actions', response.data)
+ self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
+
+
+class BasicPermModel(models.Model):
+ text = models.CharField(max_length=100)
+
+ class Meta:
+ app_label = 'tests'
+ permissions = (
+ ('view_basicpermmodel', 'Can view basic perm model'),
+ # add, change, delete built in to django
+ )
+
+
+class BasicPermSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicPermModel
+
+
+# Custom object-level permission, that includes 'view' permissions
+class ViewObjectPermissions(permissions.DjangoObjectPermissions):
+ perms_map = {
+ 'GET': ['%(app_label)s.view_%(model_name)s'],
+ 'OPTIONS': ['%(app_label)s.view_%(model_name)s'],
+ 'HEAD': ['%(app_label)s.view_%(model_name)s'],
+ 'POST': ['%(app_label)s.add_%(model_name)s'],
+ 'PUT': ['%(app_label)s.change_%(model_name)s'],
+ 'PATCH': ['%(app_label)s.change_%(model_name)s'],
+ 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
+ }
+
+
+class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
+ queryset = BasicPermModel.objects.all()
+ serializer_class = BasicPermSerializer
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [ViewObjectPermissions]
+
+object_permissions_view = ObjectPermissionInstanceView.as_view()
+
+
+class ObjectPermissionListView(generics.ListAPIView):
+ queryset = BasicPermModel.objects.all()
+ serializer_class = BasicPermSerializer
+ authentication_classes = [authentication.BasicAuthentication]
+ permission_classes = [ViewObjectPermissions]
+
+object_permissions_list_view = ObjectPermissionListView.as_view()
+
+
+@unittest.skipUnless(guardian, 'django-guardian not installed')
+class ObjectPermissionsIntegrationTests(TestCase):
+ """
+ Integration tests for the object level permissions API.
+ """
+ def setUp(self):
+ from guardian.shortcuts import assign_perm
+
+ # create users
+ create = User.objects.create_user
+ users = {
+ 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
+ 'readonly': create('readonly', 'readonly@example.com', 'password'),
+ 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
+ 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
+ }
+
+ # give everyone model level permissions, as we are not testing those
+ everyone = Group.objects.create(name='everyone')
+ model_name = get_model_name(BasicPermModel)
+ app_label = BasicPermModel._meta.app_label
+ f = '{0}_{1}'.format
+ perms = {
+ 'view': f('view', model_name),
+ 'change': f('change', model_name),
+ 'delete': f('delete', model_name)
+ }
+ for perm in perms.values():
+ perm = '{0}.{1}'.format(app_label, perm)
+ assign_perm(perm, everyone)
+ everyone.user_set.add(*users.values())
+
+ # appropriate object level permissions
+ readers = Group.objects.create(name='readers')
+ writers = Group.objects.create(name='writers')
+ deleters = Group.objects.create(name='deleters')
+
+ model = BasicPermModel.objects.create(text='foo')
+
+ assign_perm(perms['view'], readers, model)
+ assign_perm(perms['change'], writers, model)
+ assign_perm(perms['delete'], deleters, model)
+
+ readers.user_set.add(users['fullaccess'], users['readonly'])
+ writers.user_set.add(users['fullaccess'], users['writeonly'])
+ deleters.user_set.add(users['fullaccess'], users['deleteonly'])
+
+ self.credentials = {}
+ for user in users.values():
+ self.credentials[user.username] = basic_auth_header(user.username, 'password')
+
+ # Delete
+ def test_can_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_cannot_delete_permissions(self):
+ request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+
+ # Update
+ def test_can_update_permissions(self):
+ request = factory.patch(
+ '/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['writeonly']
+ )
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data.get('text'), 'foobar')
+
+ def test_cannot_update_permissions(self):
+ request = factory.patch(
+ '/1', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['deleteonly']
+ )
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ def test_cannot_update_permissions_non_existing(self):
+ request = factory.patch(
+ '/999', {'text': 'foobar'}, format='json',
+ HTTP_AUTHORIZATION=self.credentials['deleteonly']
+ )
+ response = object_permissions_view(request, pk='999')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Read
+ def test_can_read_permissions(self):
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_cannot_read_permissions(self):
+ request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ response = object_permissions_view(request, pk='1')
+ self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
+
+ # Read list
+ def test_can_read_list_permissions(self):
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ response = object_permissions_list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data[0].get('id'), 1)
+
+ def test_cannot_read_list_permissions(self):
+ request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
+ object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ response = object_permissions_list_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertListEqual(response.data, [])
diff --git a/tests/test_relations.py b/tests/test_relations.py
new file mode 100644
index 00000000..fbe176e2
--- /dev/null
+++ b/tests/test_relations.py
@@ -0,0 +1,169 @@
+from .utils import mock_reverse, fail_reverse, BadType, MockObject, MockQueryset
+from django.core.exceptions import ImproperlyConfigured
+from django.utils.datastructures import MultiValueDict
+from rest_framework import serializers
+from rest_framework.fields import empty
+from rest_framework.test import APISimpleTestCase
+import pytest
+
+
+class TestStringRelatedField(APISimpleTestCase):
+ def setUp(self):
+ self.instance = MockObject(pk=1, name='foo')
+ self.field = serializers.StringRelatedField()
+
+ def test_string_related_representation(self):
+ representation = self.field.to_representation(self.instance)
+ assert representation == '<MockObject name=foo, pk=1>'
+
+
+class TestPrimaryKeyRelatedField(APISimpleTestCase):
+ def setUp(self):
+ self.queryset = MockQueryset([
+ MockObject(pk=1, name='foo'),
+ MockObject(pk=2, name='bar'),
+ MockObject(pk=3, name='baz')
+ ])
+ self.instance = self.queryset.items[2]
+ self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset)
+
+ def test_pk_related_lookup_exists(self):
+ instance = self.field.to_internal_value(self.instance.pk)
+ assert instance is self.instance
+
+ def test_pk_related_lookup_does_not_exist(self):
+ with pytest.raises(serializers.ValidationError) as excinfo:
+ self.field.to_internal_value(4)
+ msg = excinfo.value.detail[0]
+ assert msg == 'Invalid pk "4" - object does not exist.'
+
+ def test_pk_related_lookup_invalid_type(self):
+ with pytest.raises(serializers.ValidationError) as excinfo:
+ self.field.to_internal_value(BadType())
+ msg = excinfo.value.detail[0]
+ assert msg == 'Incorrect type. Expected pk value, received BadType.'
+
+ def test_pk_representation(self):
+ representation = self.field.to_representation(self.instance)
+ assert representation == self.instance.pk
+
+
+class TestHyperlinkedIdentityField(APISimpleTestCase):
+ def setUp(self):
+ self.instance = MockObject(pk=1, name='foo')
+ self.field = serializers.HyperlinkedIdentityField(view_name='example')
+ self.field.reverse = mock_reverse
+ self.field._context = {'request': True}
+
+ def test_representation(self):
+ representation = self.field.to_representation(self.instance)
+ assert representation == 'http://example.org/example/1/'
+
+ def test_representation_unsaved_object(self):
+ representation = self.field.to_representation(MockObject(pk=None))
+ assert representation is None
+
+ def test_representation_with_format(self):
+ self.field._context['format'] = 'xml'
+ representation = self.field.to_representation(self.instance)
+ assert representation == 'http://example.org/example/1.xml/'
+
+ def test_improperly_configured(self):
+ """
+ If a matching view cannot be reversed with the given instance,
+ the the user has misconfigured something, as the URL conf and the
+ hyperlinked field do not match.
+ """
+ self.field.reverse = fail_reverse
+ with pytest.raises(ImproperlyConfigured):
+ self.field.to_representation(self.instance)
+
+
+class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
+ """
+ Tests for a hyperlinked identity field that has a `format` set,
+ which enforces that alternate formats are never linked too.
+
+ Eg. If your API includes some endpoints that accept both `.xml` and `.json`,
+ but other endpoints that only accept `.json`, we allow for hyperlinked
+ relationships that enforce only a single suffix type.
+ """
+
+ def setUp(self):
+ self.instance = MockObject(pk=1, name='foo')
+ self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json')
+ self.field.reverse = mock_reverse
+ self.field._context = {'request': True}
+
+ def test_representation(self):
+ representation = self.field.to_representation(self.instance)
+ assert representation == 'http://example.org/example/1/'
+
+ def test_representation_with_format(self):
+ self.field._context['format'] = 'xml'
+ representation = self.field.to_representation(self.instance)
+ assert representation == 'http://example.org/example/1.json/'
+
+
+class TestSlugRelatedField(APISimpleTestCase):
+ def setUp(self):
+ self.queryset = MockQueryset([
+ MockObject(pk=1, name='foo'),
+ MockObject(pk=2, name='bar'),
+ MockObject(pk=3, name='baz')
+ ])
+ self.instance = self.queryset.items[2]
+ self.field = serializers.SlugRelatedField(
+ slug_field='name', queryset=self.queryset
+ )
+
+ def test_slug_related_lookup_exists(self):
+ instance = self.field.to_internal_value(self.instance.name)
+ assert instance is self.instance
+
+ def test_slug_related_lookup_does_not_exist(self):
+ with pytest.raises(serializers.ValidationError) as excinfo:
+ self.field.to_internal_value('doesnotexist')
+ msg = excinfo.value.detail[0]
+ assert msg == 'Object with name=doesnotexist does not exist.'
+
+ def test_slug_related_lookup_invalid_type(self):
+ with pytest.raises(serializers.ValidationError) as excinfo:
+ self.field.to_internal_value(BadType())
+ msg = excinfo.value.detail[0]
+ assert msg == 'Invalid value.'
+
+ def test_representation(self):
+ representation = self.field.to_representation(self.instance)
+ assert representation == self.instance.name
+
+
+class TestManyRelatedField(APISimpleTestCase):
+ def setUp(self):
+ self.instance = MockObject(pk=1, name='foo')
+ self.field = serializers.StringRelatedField(many=True)
+ self.field.field_name = 'foo'
+
+ def test_get_value_regular_dictionary_full(self):
+ assert 'bar' == self.field.get_value({'foo': 'bar'})
+ assert empty == self.field.get_value({'baz': 'bar'})
+
+ def test_get_value_regular_dictionary_partial(self):
+ setattr(self.field.root, 'partial', True)
+ assert 'bar' == self.field.get_value({'foo': 'bar'})
+ assert empty == self.field.get_value({'baz': 'bar'})
+
+ def test_get_value_multi_dictionary_full(self):
+ mvd = MultiValueDict({'foo': ['bar1', 'bar2']})
+ assert ['bar1', 'bar2'] == self.field.get_value(mvd)
+
+ mvd = MultiValueDict({'baz': ['bar1', 'bar2']})
+ assert [] == self.field.get_value(mvd)
+
+ def test_get_value_multi_dictionary_partial(self):
+ setattr(self.field.root, 'partial', True)
+ mvd = MultiValueDict({'foo': ['bar1', 'bar2']})
+ assert ['bar1', 'bar2'] == self.field.get_value(mvd)
+
+ mvd = MultiValueDict({'baz': ['bar1', 'bar2']})
+ assert empty == self.field.get_value(mvd)
diff --git a/tests/test_relations_generic.py b/tests/test_relations_generic.py
new file mode 100644
index 00000000..b600b333
--- /dev/null
+++ b/tests/test_relations_generic.py
@@ -0,0 +1,104 @@
+from __future__ import unicode_literals
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.contenttypes.generic import GenericRelation, GenericForeignKey
+from django.db import models
+from django.test import TestCase
+from django.utils.encoding import python_2_unicode_compatible
+from rest_framework import serializers
+
+
+@python_2_unicode_compatible
+class Tag(models.Model):
+ """
+ Tags have a descriptive slug, and are attached to an arbitrary object.
+ """
+ tag = models.SlugField()
+ content_type = models.ForeignKey(ContentType)
+ object_id = models.PositiveIntegerField()
+ tagged_item = GenericForeignKey('content_type', 'object_id')
+
+ def __str__(self):
+ return self.tag
+
+
+@python_2_unicode_compatible
+class Bookmark(models.Model):
+ """
+ A URL bookmark that may have multiple tags attached.
+ """
+ url = models.URLField()
+ tags = GenericRelation(Tag)
+
+ def __str__(self):
+ return 'Bookmark: %s' % self.url
+
+
+@python_2_unicode_compatible
+class Note(models.Model):
+ """
+ A textual note that may have multiple tags attached.
+ """
+ text = models.TextField()
+ tags = GenericRelation(Tag)
+
+ def __str__(self):
+ return 'Note: %s' % self.text
+
+
+class TestGenericRelations(TestCase):
+ def setUp(self):
+ self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
+ Tag.objects.create(tagged_item=self.bookmark, tag='django')
+ Tag.objects.create(tagged_item=self.bookmark, tag='python')
+ self.note = Note.objects.create(text='Remember the milk')
+ Tag.objects.create(tagged_item=self.note, tag='reminder')
+
+ def test_generic_relation(self):
+ """
+ Test a relationship that spans a GenericRelation field.
+ IE. A reverse generic relationship.
+ """
+
+ class BookmarkSerializer(serializers.ModelSerializer):
+ tags = serializers.StringRelatedField(many=True)
+
+ class Meta:
+ model = Bookmark
+ fields = ('tags', 'url')
+
+ serializer = BookmarkSerializer(self.bookmark)
+ expected = {
+ 'tags': ['django', 'python'],
+ 'url': 'https://www.djangoproject.com/'
+ }
+ self.assertEqual(serializer.data, expected)
+
+ def test_generic_fk(self):
+ """
+ Test a relationship that spans a GenericForeignKey field.
+ IE. A forward generic relationship.
+ """
+
+ class TagSerializer(serializers.ModelSerializer):
+ tagged_item = serializers.StringRelatedField()
+
+ class Meta:
+ model = Tag
+ fields = ('tag', 'tagged_item')
+
+ serializer = TagSerializer(Tag.objects.all(), many=True)
+ expected = [
+ {
+ 'tag': 'django',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'python',
+ 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ },
+ {
+ 'tag': 'reminder',
+ 'tagged_item': 'Note: Remember the milk'
+ }
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py
new file mode 100644
index 00000000..33b09713
--- /dev/null
+++ b/tests/test_relations_hyperlink.py
@@ -0,0 +1,444 @@
+from __future__ import unicode_literals
+from django.conf.urls import url
+from django.test import TestCase
+from rest_framework import serializers
+from rest_framework.test import APIRequestFactory
+from tests.models import (
+ ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource
+)
+
+factory = APIRequestFactory()
+request = factory.get('/') # Just to ensure we have a request in the serializer context
+
+
+def dummy_view(request, pk):
+ pass
+
+
+urlpatterns = [
+ url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
+ url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
+ url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
+ url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
+ url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
+ url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
+ url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
+ url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
+]
+
+
+# ManyToMany
+class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManyTarget
+ fields = ('url', 'name', 'sources')
+
+
+class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManySource
+ fields = ('url', 'name', 'targets')
+
+
+# ForeignKey
+class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeyTarget
+ fields = ('url', 'name', 'sources')
+
+
+class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+ fields = ('url', 'name', 'target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+ fields = ('url', 'name', 'target')
+
+
+# Nullable OneToOne
+class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('url', 'name', 'nullable_source')
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class HyperlinkedManyToManyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ with self.assertNumQueries(4):
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_retrieve_prefetch_related(self):
+ queryset = ManyToManySource.objects.all().prefetch_related('targets')
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ with self.assertNumQueries(2):
+ serializer.data
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+ ]
+ with self.assertNumQueries(4):
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ serializer = ManyToManySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
+ {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
+ {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
+ {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
+ {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class HyperlinkedForeignKeyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ ]
+ with self.assertNumQueries(1):
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ with self.assertNumQueries(3):
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
+ serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
+ {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field may not be null.']})
+
+
+class HyperlinkedNullableForeignKeyTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
+ expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
+ {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
+ {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class HyperlinkedNullableOneToOneTests(TestCase):
+ urls = 'tests.test_relations_hyperlink'
+
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
+ expected = [
+ {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
+ {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py
new file mode 100644
index 00000000..ca43272b
--- /dev/null
+++ b/tests/test_relations_pk.py
@@ -0,0 +1,450 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.utils import six
+from rest_framework import serializers
+from tests.models import (
+ ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource,
+ NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource,
+)
+
+
+# ManyToMany
+class ManyToManyTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManyTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ManyToManySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ManyToManySource
+ fields = ('id', 'name', 'targets')
+
+
+# ForeignKey
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeyTarget
+ fields = ('id', 'name', 'sources')
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+ fields = ('id', 'name', 'target')
+
+
+# Nullable ForeignKey
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+ fields = ('id', 'name', 'target')
+
+
+# Nullable OneToOne
+class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OneToOneTarget
+ fields = ('id', 'name', 'nullable_source')
+
+
+# TODO: Add test that .data cannot be accessed prior to .is_valid
+
+class PKManyToManyTests(TestCase):
+ def setUp(self):
+ for idx in range(1, 4):
+ target = ManyToManyTarget(name='target-%d' % idx)
+ target.save()
+ source = ManyToManySource(name='source-%d' % idx)
+ source.save()
+ for target in ManyToManyTarget.objects.all():
+ source.targets.add(target)
+
+ def test_many_to_many_retrieve(self):
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ ]
+ with self.assertNumQueries(4):
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_retrieve_prefetch_related(self):
+ queryset = ManyToManySource.objects.all().prefetch_related('targets')
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ with self.assertNumQueries(2):
+ serializer.data
+
+ def test_reverse_many_to_many_retrieve(self):
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
+ ]
+ with self.assertNumQueries(4):
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_update(self):
+ data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
+ instance = ManyToManySource.objects.get(pk=1)
+ serializer = ManyToManySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_update(self):
+ data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ instance = ManyToManyTarget.objects.get(pk=1)
+ serializer = ManyToManyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 1 is updated, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_create(self):
+ data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
+ serializer = ManyToManySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ManyToManySource.objects.all()
+ serializer = ManyToManySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'targets': [1]},
+ {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
+ {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
+ {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_many_to_many_unsaved(self):
+ source = ManyToManySource(name='source-unsaved')
+
+ serializer = ManyToManySourceSerializer(source)
+
+ expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
+ # no query if source hasn't been created yet
+ with self.assertNumQueries(0):
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_many_to_many_create(self):
+ data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ serializer = ManyToManyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-4')
+
+ # Ensure target 4 is added, and everything else is as expected
+ queryset = ManyToManyTarget.objects.all()
+ serializer = ManyToManyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
+ {'id': 3, 'name': 'target-3', 'sources': [3]},
+ {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class PKForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
+ ]
+ with self.assertNumQueries(1):
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ with self.assertNumQueries(3):
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve_prefetch_related(self):
+ queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ with self.assertNumQueries(2):
+ serializer.data
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 2}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 2},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 2}
+ serializer = ForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': 1},
+ {'id': 4, 'name': 'source-4', 'target': 2},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': [2]},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field may not be null.']})
+
+ def test_foreign_key_with_unsaved(self):
+ source = ForeignKeySource(name='source-unsaved')
+ expected = {'id': None, 'name': 'source-unsaved', 'target': None}
+
+ serializer = ForeignKeySourceSerializer(source)
+
+ # no query if source hasn't been created yet
+ with self.assertNumQueries(0):
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_with_empty(self):
+ """
+ Regression test for #1072
+
+ https://github.com/tomchristie/django-rest-framework/issues/1072
+ """
+ serializer = NullableForeignKeySourceSerializer()
+ self.assertEqual(serializer.data['target'], None)
+
+
+class PKNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 1},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 1},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+
+class PKNullableOneToOneTests(TestCase):
+ def setUp(self):
+ target = OneToOneTarget(name='target-1')
+ target.save()
+ new_target = OneToOneTarget(name='target-2')
+ new_target.save()
+ source = NullableOneToOneSource(name='source-1', target=new_target)
+ source.save()
+
+ def test_reverse_foreign_key_retrieve_with_null(self):
+ queryset = OneToOneTarget.objects.all()
+ serializer = NullableOneToOneTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'nullable_source': None},
+ {'id': 2, 'name': 'target-2', 'nullable_source': 1},
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py
new file mode 100644
index 00000000..cd2cb1ed
--- /dev/null
+++ b/tests/test_relations_slug.py
@@ -0,0 +1,281 @@
+from django.test import TestCase
+from rest_framework import serializers
+from tests.models import NullableForeignKeySource, ForeignKeySource, ForeignKeyTarget
+
+
+class ForeignKeyTargetSerializer(serializers.ModelSerializer):
+ sources = serializers.SlugRelatedField(
+ slug_field='name',
+ queryset=ForeignKeySource.objects.all(),
+ many=True
+ )
+
+ class Meta:
+ model = ForeignKeyTarget
+
+
+class ForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(
+ slug_field='name',
+ queryset=ForeignKeyTarget.objects.all()
+ )
+
+ class Meta:
+ model = ForeignKeySource
+
+
+class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
+ target = serializers.SlugRelatedField(
+ slug_field='name',
+ queryset=ForeignKeyTarget.objects.all(),
+ allow_null=True
+ )
+
+ class Meta:
+ model = NullableForeignKeySource
+
+
+# TODO: M2M Tests, FKTests (Non-nullable), One2One
+class SlugForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ new_target = ForeignKeyTarget(name='target-2')
+ new_target.save()
+ for idx in range(1, 4):
+ source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve(self):
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ with self.assertNumQueries(4):
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_retrieve_select_related(self):
+ queryset = ForeignKeySource.objects.all().select_related('target')
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ with self.assertNumQueries(1):
+ serializer.data
+
+ def test_reverse_foreign_key_retrieve(self):
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_retrieve_prefetch_related(self):
+ queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ with self.assertNumQueries(2):
+ serializer.data
+
+ def test_foreign_key_update(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-2'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_incorrect_type(self):
+ data = {'id': 1, 'name': 'source-1', 'target': 123}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
+
+ def test_reverse_foreign_key_update(self):
+ data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ instance = ForeignKeyTarget.objects.get(pk=2)
+ serializer = ForeignKeyTargetSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ # We shouldn't have saved anything to the db yet since save
+ # hasn't been called.
+ queryset = ForeignKeyTarget.objects.all()
+ new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ ]
+ self.assertEqual(new_serializer.data, expected)
+
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure target 2 is update, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create(self):
+ data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ serializer = ForeignKeySourceSerializer(data=data)
+ serializer.is_valid()
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is added, and everything else is as expected
+ queryset = ForeignKeySource.objects.all()
+ serializer = ForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': 'target-1'},
+ {'id': 4, 'name': 'source-4', 'target': 'target-2'},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_reverse_foreign_key_create(self):
+ data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ serializer = ForeignKeyTargetSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'target-3')
+
+ # Ensure target 3 is added, and everything else is as expected
+ queryset = ForeignKeyTarget.objects.all()
+ serializer = ForeignKeyTargetSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
+ {'id': 2, 'name': 'target-2', 'sources': []},
+ {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_invalid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = ForeignKeySource.objects.get(pk=1)
+ serializer = ForeignKeySourceSerializer(instance, data=data)
+ self.assertFalse(serializer.is_valid())
+ self.assertEqual(serializer.errors, {'target': ['This field may not be null.']})
+
+
+class SlugNullableForeignKeyTests(TestCase):
+ def setUp(self):
+ target = ForeignKeyTarget(name='target-1')
+ target.save()
+ for idx in range(1, 4):
+ if idx == 3:
+ target = None
+ source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source.save()
+
+ def test_foreign_key_retrieve_with_null(self):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_null(self):
+ data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_create_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 4, 'name': 'source-4', 'target': ''}
+ expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ serializer = NullableForeignKeySourceSerializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ obj = serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+ self.assertEqual(obj.name, 'source-4')
+
+ # Ensure source 4 is created, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': 'target-1'},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None},
+ {'id': 4, 'name': 'source-4', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_null(self):
+ data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
+
+ def test_foreign_key_update_with_valid_emptystring(self):
+ """
+ The emptystring should be interpreted as null in the context
+ of relationships.
+ """
+ data = {'id': 1, 'name': 'source-1', 'target': ''}
+ expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ instance = NullableForeignKeySource.objects.get(pk=1)
+ serializer = NullableForeignKeySourceSerializer(instance, data=data)
+ self.assertTrue(serializer.is_valid())
+ serializer.save()
+ self.assertEqual(serializer.data, expected_data)
+
+ # Ensure source 1 is updated, and everything else is as expected
+ queryset = NullableForeignKeySource.objects.all()
+ serializer = NullableForeignKeySourceSerializer(queryset, many=True)
+ expected = [
+ {'id': 1, 'name': 'source-1', 'target': None},
+ {'id': 2, 'name': 'source-2', 'target': 'target-1'},
+ {'id': 3, 'name': 'source-3', 'target': None}
+ ]
+ self.assertEqual(serializer.data, expected)
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
new file mode 100644
index 00000000..cb76f683
--- /dev/null
+++ b/tests/test_renderers.py
@@ -0,0 +1,473 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
+from django.core.cache import cache
+from django.db import models
+from django.test import TestCase
+from django.utils import six
+from django.utils.translation import ugettext_lazy as _
+from rest_framework import status, permissions
+from rest_framework.compat import OrderedDict
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework import serializers
+from rest_framework.renderers import (
+ BaseRenderer, JSONRenderer, BrowsableAPIRenderer, HTMLFormRenderer
+)
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from collections import MutableMapping
+import json
+import re
+
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+
+def RENDERER_A_SERIALIZER(x):
+ return ('Renderer A: %s' % x).encode('ascii')
+
+
+def RENDERER_B_SERIALIZER(x):
+ return ('Renderer B: %s' % x).encode('ascii')
+
+
+expected_results = [
+ ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1,2,3]') # Generator
+]
+
+
+class DummyTestModel(models.Model):
+ name = models.CharField(max_length=42, default='')
+
+
+class BasicRendererTests(TestCase):
+ def test_expected_results(self):
+ for value, renderer_cls, expected in expected_results:
+ output = renderer_cls().render(value)
+ self.assertEqual(output, expected)
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB)
+
+ def get(self, request, **kwargs):
+ response = Response(DUMMYCONTENT, status=DUMMYSTATUS)
+ return response
+
+
+class MockGETView(APIView):
+ def get(self, request, **kwargs):
+ return Response({'foo': ['bar', 'baz']})
+
+
+class MockPOSTView(APIView):
+ def post(self, request, **kwargs):
+ return Response({'foo': request.DATA})
+
+
+class EmptyGETView(APIView):
+ renderer_classes = (JSONRenderer,)
+
+ def get(self, request, **kwargs):
+ return Response(status=status.HTTP_204_NO_CONTENT)
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+urlpatterns = patterns(
+ '',
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r'^cache$', MockGETView.as_view()),
+ url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^empty$', EmptyGETView.as_view()),
+ url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+class POSTDeniedPermission(permissions.BasePermission):
+ def has_permission(self, request, view):
+ return request.method != 'POST'
+
+
+class POSTDeniedView(APIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = (POSTDeniedPermission,)
+
+ def get(self, request):
+ return Response()
+
+ def post(self, request):
+ return Response()
+
+ def put(self, request):
+ return Response()
+
+ def patch(self, request):
+ return Response()
+
+
+class DocumentingRendererTests(TestCase):
+ def test_only_permitted_forms_are_displayed(self):
+ view = POSTDeniedView.as_view()
+ request = APIRequestFactory().get('/')
+ response = view(request).render()
+ self.assertNotContains(response, '>POST<')
+ self.assertContains(response, '>PUT<')
+ self.assertContains(response, '>PATCH<')
+
+
+class RendererEndToEndTests(TestCase):
+ """
+ End-to-end testing of renderers using an RendererMixin on a generic view.
+ """
+
+ urls = 'tests.test_renderers'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, six.b(''))
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
+ """If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
+ resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ param = '?%s=%s' % (
+ api_settings.URL_FORMAT_OVERRIDE,
+ RendererB.format
+ )
+ resp = self.client.get('/' + param,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_parse_error_renderers_browsable_api(self):
+ """Invalid data should still render the browsable API correctly."""
+ resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
+
+ def test_204_no_content_responses_have_no_content_type_set(self):
+ """
+ Regression test for #1196
+
+ https://github.com/tomchristie/django-rest-framework/issues/1196
+ """
+ resp = self.client.get('/empty')
+ self.assertEqual(resp.get('Content-Type', None), None)
+ self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
+
+ def test_contains_headers_of_api_response(self):
+ """
+ Issue #1437
+
+ Test we display the headers of the API response and not those from the
+ HTML response
+ """
+ resp = self.client.get('/html1')
+ self.assertContains(resp, '>GET, HEAD, OPTIONS<')
+ self.assertContains(resp, '>application/json<')
+ self.assertNotContains(resp, '>text/html; charset=utf-8<')
+
+
+_flat_repr = '{"foo":["bar","baz"]}'
+_indented_repr = '{\n "foo": [\n "bar",\n "baz"\n ]\n}'
+
+
+def strip_trailing_whitespace(content):
+ """
+ Seems to be some inconsistencies re. trailing whitespace with
+ different versions of the json lib.
+ """
+ return re.sub(' +\n', '\n', content)
+
+
+class JSONRendererTests(TestCase):
+ """
+ Tests specific to the JSON Renderer
+ """
+
+ def test_render_lazy_strings(self):
+ """
+ JSONRenderer should deal with lazy translated strings.
+ """
+ ret = JSONRenderer().render(_('test'))
+ self.assertEqual(ret, b'"test"')
+
+ def test_render_queryset_values(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [{'id': o.id, 'name': o.name}])
+
+ def test_render_queryset_values_list(self):
+ o = DummyTestModel.objects.create(name='dummy')
+ qs = DummyTestModel.objects.values_list('id', 'name')
+ ret = JSONRenderer().render(qs)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, [[o.id, o.name]])
+
+ def test_render_dict_abc_obj(self):
+ class Dict(MutableMapping):
+ def __init__(self):
+ self._dict = dict()
+
+ def __getitem__(self, key):
+ return self._dict.__getitem__(key)
+
+ def __setitem__(self, key, value):
+ return self._dict.__setitem__(key, value)
+
+ def __delitem__(self, key):
+ return self._dict.__delitem__(key)
+
+ def __iter__(self):
+ return self._dict.__iter__()
+
+ def __len__(self):
+ return self._dict.__len__()
+
+ def keys(self):
+ return self._dict.keys()
+
+ x = Dict()
+ x['key'] = 'string value'
+ x[2] = 3
+ ret = JSONRenderer().render(x)
+ data = json.loads(ret.decode('utf-8'))
+ self.assertEquals(data, {'key': 'string value', '2': 3})
+
+ def test_render_obj_with_getitem(self):
+ class DictLike(object):
+ def __init__(self):
+ self._dict = {}
+
+ def set(self, value):
+ self._dict = dict(value)
+
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ x = DictLike()
+ x.set({'a': 1, 'b': 'string'})
+ with self.assertRaises(TypeError):
+ JSONRenderer().render(x)
+
+ def test_without_content_type_args(self):
+ """
+ Test basic JSON rendering.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ # Fix failing test case which depends on version of JSON library.
+ self.assertEqual(content.decode('utf-8'), _flat_repr)
+
+ def test_with_content_type_args(self):
+ """
+ Test JSON rendering with additional content type arguments supplied.
+ """
+ obj = {'foo': ['bar', 'baz']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json; indent=2')
+ self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
+
+
+class UnicodeJSONRendererTests(TestCase):
+ """
+ Tests specific for the Unicode JSON Renderer
+ """
+ def test_proper_encoding(self):
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8'))
+
+ def test_u2028_u2029(self):
+ # The \u2028 and \u2029 characters should be escaped,
+ # even when the non-escaping unicode representation is used.
+ # Regression test for #2169
+ obj = {'should_escape': '\u2028\u2029'}
+ renderer = JSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode('utf-8'))
+
+
+class AsciiJSONRendererTests(TestCase):
+ """
+ Tests specific for the Unicode JSON Renderer
+ """
+ def test_proper_encoding(self):
+ class AsciiJSONRenderer(JSONRenderer):
+ ensure_ascii = True
+ obj = {'countries': ['United Kingdom', 'France', 'España']}
+ renderer = AsciiJSONRenderer()
+ content = renderer.render(obj, 'application/json')
+ self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8'))
+
+
+# Tests for caching issue, #346
+class CacheRenderTest(TestCase):
+ """
+ Tests specific to caching responses
+ """
+
+ urls = 'tests.test_renderers'
+
+ def test_head_caching(self):
+ """
+ Test caching of HEAD requests
+ """
+ response = self.client.head('/cache')
+ cache.set('key', response)
+ cached_response = cache.get('key')
+ assert isinstance(cached_response, Response)
+ assert cached_response.content == response.content
+ assert cached_response.status_code == response.status_code
+
+ def test_get_caching(self):
+ """
+ Test caching of GET requests
+ """
+ response = self.client.get('/cache')
+ cache.set('key', response)
+ cached_response = cache.get('key')
+ assert isinstance(cached_response, Response)
+ assert cached_response.content == response.content
+ assert cached_response.status_code == response.status_code
+
+
+class TestJSONIndentationStyles:
+ def test_indented(self):
+ renderer = JSONRenderer()
+ data = OrderedDict([('a', 1), ('b', 2)])
+ assert renderer.render(data) == b'{"a":1,"b":2}'
+
+ def test_compact(self):
+ renderer = JSONRenderer()
+ data = OrderedDict([('a', 1), ('b', 2)])
+ context = {'indent': 4}
+ assert (
+ renderer.render(data, renderer_context=context) ==
+ b'{\n "a": 1,\n "b": 2\n}'
+ )
+
+ def test_long_form(self):
+ renderer = JSONRenderer()
+ renderer.compact = False
+ data = OrderedDict([('a', 1), ('b', 2)])
+ assert renderer.render(data) == b'{"a": 1, "b": 2}'
+
+
+class TestHiddenFieldHTMLFormRenderer(TestCase):
+ def test_hidden_field_rendering(self):
+ class TestSerializer(serializers.Serializer):
+ published = serializers.HiddenField(default=True)
+
+ serializer = TestSerializer(data={})
+ serializer.is_valid()
+ renderer = HTMLFormRenderer()
+ field = serializer['published']
+ rendered = renderer.render_field(field, {})
+ assert rendered == ''
diff --git a/tests/test_request.py b/tests/test_request.py
new file mode 100644
index 00000000..c274ab69
--- /dev/null
+++ b/tests/test_request.py
@@ -0,0 +1,278 @@
+"""
+Tests for content parsing, and form-overloaded content parsing.
+"""
+from __future__ import unicode_literals
+from django.conf.urls import patterns
+from django.contrib.auth.models import User
+from django.contrib.auth import authenticate, login, logout
+from django.contrib.sessions.middleware import SessionMiddleware
+from django.core.handlers.wsgi import WSGIRequest
+from django.test import TestCase
+from django.utils import six
+from rest_framework import status
+from rest_framework.authentication import SessionAuthentication
+from rest_framework.parsers import (
+ BaseParser,
+ FormParser,
+ MultiPartParser,
+ JSONParser
+)
+from rest_framework.request import Request, Empty
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory, APIClient
+from rest_framework.views import APIView
+from io import BytesIO
+import json
+
+
+factory = APIRequestFactory()
+
+
+class PlainTextParser(BaseParser):
+ media_type = 'text/plain'
+
+ def parse(self, stream, media_type=None, parser_context=None):
+ """
+ Returns a 2-tuple of `(data, files)`.
+
+ `data` will simply be a string representing the body of the request.
+ `files` will always be `None`.
+ """
+ return stream.read()
+
+
+class TestMethodOverloading(TestCase):
+ def test_method(self):
+ """
+ Request methods should be same as underlying request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.method, 'GET')
+ request = Request(factory.post('/'))
+ self.assertEqual(request.method, 'POST')
+
+ def test_overloaded_method(self):
+ """
+ POST requests can be overloaded to another method by setting a
+ reserved form field
+ """
+ request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
+ self.assertEqual(request.method, 'DELETE')
+
+ def test_x_http_method_override_header(self):
+ """
+ POST requests can also be overloaded to another method by setting
+ the X-HTTP-Method-Override header.
+ """
+ request = Request(factory.post('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
+ request = Request(factory.get('/', {'foo': 'bar'}, HTTP_X_HTTP_METHOD_OVERRIDE='DELETE'))
+ self.assertEqual(request.method, 'DELETE')
+
+
+class TestContentParsing(TestCase):
+ def test_standard_behaviour_determines_no_content_GET(self):
+ """
+ Ensure request.DATA returns empty QueryDict for GET request.
+ """
+ request = Request(factory.get('/'))
+ self.assertEqual(request.DATA, {})
+
+ def test_standard_behaviour_determines_no_content_HEAD(self):
+ """
+ Ensure request.DATA returns empty QueryDict for HEAD request.
+ """
+ request = Request(factory.head('/'))
+ self.assertEqual(request.DATA, {})
+
+ def test_request_DATA_with_form_content(self):
+ """
+ Ensure request.DATA returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
+
+ def test_request_DATA_with_text_content(self):
+ """
+ Ensure request.DATA returns content for POST request with
+ non-form content.
+ """
+ content = six.b('qwerty')
+ content_type = 'text/plain'
+ request = Request(factory.post('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(),)
+ self.assertEqual(request.DATA, content)
+
+ def test_request_POST_with_form_content(self):
+ """
+ Ensure request.POST returns content for POST request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.post('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.POST.items()), list(data.items()))
+
+ def test_standard_behaviour_determines_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with form content.
+ """
+ data = {'qwerty': 'uiop'}
+ request = Request(factory.put('/', data))
+ request.parsers = (FormParser(), MultiPartParser())
+ self.assertEqual(list(request.DATA.items()), list(data.items()))
+
+ def test_standard_behaviour_determines_non_form_content_PUT(self):
+ """
+ Ensure request.DATA returns content for PUT request with
+ non-form content.
+ """
+ content = six.b('qwerty')
+ content_type = 'text/plain'
+ request = Request(factory.put('/', content, content_type=content_type))
+ request.parsers = (PlainTextParser(), )
+ self.assertEqual(request.DATA, content)
+
+ def test_overloaded_behaviour_allows_content_tunnelling(self):
+ """
+ Ensure request.DATA returns content for overloaded POST request.
+ """
+ json_data = {'foobar': 'qwerty'}
+ content = json.dumps(json_data)
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = Request(factory.post('/', form_data))
+ request.parsers = (JSONParser(), )
+ self.assertEqual(request.DATA, json_data)
+
+ def test_form_POST_unicode(self):
+ """
+ JSON POST via default web interface with unicode data
+ """
+ # Note: environ and other variables here have simplified content compared to real Request
+ CONTENT = b'_content_type=application%2Fjson&_content=%7B%22request%22%3A+4%2C+%22firm%22%3A+1%2C+%22text%22%3A+%22%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%21%22%7D'
+ environ = {
+ 'REQUEST_METHOD': 'POST',
+ 'CONTENT_TYPE': 'application/x-www-form-urlencoded',
+ 'CONTENT_LENGTH': len(CONTENT),
+ 'wsgi.input': BytesIO(CONTENT),
+ }
+ wsgi_request = WSGIRequest(environ=environ)
+ wsgi_request._load_post_and_files()
+ parsers = (JSONParser(), FormParser(), MultiPartParser())
+ parser_context = {
+ 'encoding': 'utf-8',
+ 'kwargs': {},
+ 'args': (),
+ }
+ request = Request(wsgi_request, parsers=parsers, parser_context=parser_context)
+ method = request.method
+ self.assertEqual(method, 'POST')
+ self.assertEqual(request._content_type, 'application/json')
+ self.assertEqual(request._stream.getvalue(), b'{"request": 4, "firm": 1, "text": "\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82!"}')
+ self.assertEqual(request._data, Empty)
+ self.assertEqual(request._files, Empty)
+
+
+class MockView(APIView):
+ authentication_classes = (SessionAuthentication,)
+
+ def post(self, request):
+ if request.POST.get('example') is not None:
+ return Response(status=status.HTTP_200_OK)
+
+ return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+
+urlpatterns = patterns(
+ '',
+ (r'^$', MockView.as_view()),
+)
+
+
+class TestContentParsingWithAuthentication(TestCase):
+ urls = 'tests.test_request'
+
+ def setUp(self):
+ self.csrf_client = APIClient(enforce_csrf_checks=True)
+ self.username = 'john'
+ self.email = 'lennon@thebeatles.com'
+ self.password = 'password'
+ self.user = User.objects.create_user(self.username, self.email, self.password)
+
+ def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
+ """
+ Ensures request.POST exists after SessionAuthentication when user
+ doesn't log in.
+ """
+ content = {'example': 'example'}
+
+ response = self.client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+ response = self.csrf_client.post('/', content)
+ self.assertEqual(status.HTTP_200_OK, response.status_code)
+
+
+class TestUserSetter(TestCase):
+
+ def setUp(self):
+ # Pass request object through session middleware so session is
+ # available to login and logout functions
+ self.wrapped_request = factory.get('/')
+ self.request = Request(self.wrapped_request)
+ SessionMiddleware().process_request(self.request)
+
+ User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
+ self.user = authenticate(username='ringo', password='yellow')
+
+ def test_user_can_be_set(self):
+ self.request.user = self.user
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_login(self):
+ login(self.request, self.user)
+ self.assertEqual(self.request.user, self.user)
+
+ def test_user_can_logout(self):
+ self.request.user = self.user
+ self.assertFalse(self.request.user.is_anonymous())
+ logout(self.request)
+ self.assertTrue(self.request.user.is_anonymous())
+
+ def test_logged_in_user_is_set_on_wrapped_request(self):
+ login(self.request, self.user)
+ self.assertEqual(self.wrapped_request.user, self.user)
+
+ def test_calling_user_fails_when_attribute_error_is_raised(self):
+ """
+ This proves that when an AttributeError is raised inside of the request.user
+ property, that we can handle this and report the true, underlying error.
+ """
+ class AuthRaisesAttributeError(object):
+ def authenticate(self, request):
+ import rest_framework
+ rest_framework.MISSPELLED_NAME_THAT_DOESNT_EXIST
+
+ self.request = Request(factory.get('/'), authenticators=(AuthRaisesAttributeError(),))
+ SessionMiddleware().process_request(self.request)
+
+ login(self.request, self.user)
+ try:
+ self.request.user
+ except AttributeError as error:
+ self.assertEqual(str(error), "'module' object has no attribute 'MISSPELLED_NAME_THAT_DOESNT_EXIST'")
+ else:
+ assert False, 'AttributeError not raised'
+
+
+class TestAuthSetter(TestCase):
+ def test_auth_can_be_set(self):
+ request = Request(factory.get('/'))
+ request.auth = 'DUMMY'
+ self.assertEqual(request.auth, 'DUMMY')
diff --git a/tests/test_response.py b/tests/test_response.py
new file mode 100644
index 00000000..4a9deaa2
--- /dev/null
+++ b/tests/test_response.py
@@ -0,0 +1,292 @@
+from __future__ import unicode_literals
+from django.conf.urls import patterns, url, include
+from django.test import TestCase
+from django.utils import six
+from tests.models import BasicModel
+from rest_framework.response import Response
+from rest_framework.views import APIView
+from rest_framework import generics
+from rest_framework import routers
+from rest_framework import serializers
+from rest_framework import status
+from rest_framework.renderers import (
+ BaseRenderer,
+ JSONRenderer,
+ BrowsableAPIRenderer
+)
+from rest_framework import viewsets
+from rest_framework.settings import api_settings
+
+
+# Serializer used to test BasicModel
+class BasicModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
+class MockPickleRenderer(BaseRenderer):
+ media_type = 'application/pickle'
+
+
+class MockJsonRenderer(BaseRenderer):
+ media_type = 'application/json'
+
+
+class MockTextMediaRenderer(BaseRenderer):
+ media_type = 'text/html'
+
+DUMMYSTATUS = status.HTTP_200_OK
+DUMMYCONTENT = 'dummycontent'
+
+
+def RENDERER_A_SERIALIZER(x):
+ return ('Renderer A: %s' % x).encode('ascii')
+
+
+def RENDERER_B_SERIALIZER(x):
+ return ('Renderer B: %s' % x).encode('ascii')
+
+
+class RendererA(BaseRenderer):
+ media_type = 'mock/renderera'
+ format = "formata"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_A_SERIALIZER(data)
+
+
+class RendererB(BaseRenderer):
+ media_type = 'mock/rendererb'
+ format = "formatb"
+
+ def render(self, data, media_type=None, renderer_context=None):
+ return RENDERER_B_SERIALIZER(data)
+
+
+class RendererC(RendererB):
+ media_type = 'mock/rendererc'
+ format = 'formatc'
+ charset = "rendererc"
+
+
+class MockView(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS)
+
+
+class MockViewSettingContentType(APIView):
+ renderer_classes = (RendererA, RendererB, RendererC)
+
+ def get(self, request, **kwargs):
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
+
+
+class HTMLView(APIView):
+ renderer_classes = (BrowsableAPIRenderer, )
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLView1(APIView):
+ renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
+
+ def get(self, request, **kwargs):
+ return Response('text')
+
+
+class HTMLNewModelViewSet(viewsets.ModelViewSet):
+ serializer_class = BasicModelSerializer
+ queryset = BasicModel.objects.all()
+
+
+class HTMLNewModelView(generics.ListCreateAPIView):
+ renderer_classes = (BrowsableAPIRenderer,)
+ permission_classes = []
+ serializer_class = BasicModelSerializer
+ queryset = BasicModel.objects.all()
+
+
+new_model_viewset_router = routers.DefaultRouter()
+new_model_viewset_router.register(r'', HTMLNewModelViewSet)
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r'^html$', HTMLView.as_view()),
+ url(r'^html1$', HTMLView1.as_view()),
+ url(r'^html_new_model$', HTMLNewModelView.as_view()),
+ url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
+ url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
+)
+
+
+# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
+class RendererIntegrationTests(TestCase):
+ """
+ End-to-end testing of renderers using an ResponseMixin on a generic view.
+ """
+
+ urls = 'tests.test_response'
+
+ def test_default_renderer_serializes_content(self):
+ """If the Accept header is not set the default renderer should serialize the response."""
+ resp = self.client.get('/')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_head_method_serializes_no_content(self):
+ """No response must be included in HEAD requests."""
+ resp = self.client.head('/')
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, six.b(''))
+
+ def test_default_renderer_serializes_content_on_accept_any(self):
+ """If the Accept header is set to */* the default renderer should serialize the response."""
+ resp = self.client.get('/', HTTP_ACCEPT='*/*')
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for the default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_non_default_case(self):
+ """If the Accept header is set the specified renderer should serialize the response.
+ (In this case we check that works for a non-default renderer)"""
+ resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_accept_query(self):
+ """The '_accept' query string should behave in the same way as the Accept header."""
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ RendererB.media_type
+ )
+ resp = self.client.get('/' + param)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_query(self):
+ """If a 'format' query is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_serializes_content_on_format_kwargs(self):
+ """If a 'format' keyword arg is specified, the renderer with the matching
+ format attribute should serialize the response."""
+ resp = self.client.get('/something.formatb')
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+ def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
+ """If both a 'format' query and a matching Accept header specified,
+ the renderer with the matching format attribute should serialize the response."""
+ resp = self.client.get('/?format=%s' % RendererB.format,
+ HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
+ self.assertEqual(resp.status_code, DUMMYSTATUS)
+
+
+class Issue122Tests(TestCase):
+ """
+ Tests that covers #122.
+ """
+ urls = 'tests.test_response'
+
+ def test_only_html_renderer(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html')
+
+ def test_html_renderer_is_first(self):
+ """
+ Test if no infinite recursion occurs.
+ """
+ self.client.get('/html1')
+
+
+class Issue467Tests(TestCase):
+ """
+ Tests for #467
+ """
+
+ urls = 'tests.test_response'
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ # self.assertContains(resp, 'Text comes here')
+ # self.assertContains(resp, 'Text description.')
+
+
+class Issue807Tests(TestCase):
+ """
+ Covers #807
+ """
+
+ urls = 'tests.test_response'
+
+ def test_does_not_append_charset_by_default(self):
+ """
+ Renderers don't include a charset unless set explicitly.
+ """
+ headers = {"HTTP_ACCEPT": RendererA.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
+ """
+ If renderer class has charset attribute declared, it gets appended
+ to Response's Content-Type
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/', **headers)
+ expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
+ self.assertEqual(expected, resp['Content-Type'])
+
+ def test_content_type_set_explicitly_on_response(self):
+ """
+ The content type may be set explicitly on the response.
+ """
+ headers = {"HTTP_ACCEPT": RendererC.media_type}
+ resp = self.client.get('/setbyview', **headers)
+ self.assertEqual('setbyview', resp['Content-Type'])
+
+ def test_viewset_label_help_text(self):
+ param = '?%s=%s' % (
+ api_settings.URL_ACCEPT_OVERRIDE,
+ 'text/html'
+ )
+ resp = self.client.get('/html_new_model_viewset/' + param)
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ # self.assertContains(resp, 'Text comes here')
+ # self.assertContains(resp, 'Text description.')
+
+ def test_form_has_label_and_help_text(self):
+ resp = self.client.get('/html_new_model')
+ self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ # self.assertContains(resp, 'Text comes here')
+ # self.assertContains(resp, 'Text description.')
diff --git a/tests/test_reverse.py b/tests/test_reverse.py
new file mode 100644
index 00000000..675a9d5a
--- /dev/null
+++ b/tests/test_reverse.py
@@ -0,0 +1,28 @@
+from __future__ import unicode_literals
+from django.conf.urls import patterns, url
+from django.test import TestCase
+from rest_framework.reverse import reverse
+from rest_framework.test import APIRequestFactory
+
+factory = APIRequestFactory()
+
+
+def null_view(request):
+ pass
+
+urlpatterns = patterns(
+ '',
+ url(r'^view$', null_view, name='view'),
+)
+
+
+class ReverseTests(TestCase):
+ """
+ Tests for fully qualified URLs when using `reverse`.
+ """
+ urls = 'tests.test_reverse'
+
+ def test_reversed_urls_are_fully_qualified(self):
+ request = factory.get('/view')
+ url = reverse('view', request=request)
+ self.assertEqual(url, 'http://testserver/view')
diff --git a/tests/test_routers.py b/tests/test_routers.py
new file mode 100644
index 00000000..08c58ec7
--- /dev/null
+++ b/tests/test_routers.py
@@ -0,0 +1,348 @@
+from __future__ import unicode_literals
+from django.conf.urls import url, include
+from django.db import models
+from django.test import TestCase
+from django.core.exceptions import ImproperlyConfigured
+from rest_framework import serializers, viewsets, permissions
+from rest_framework.decorators import detail_route, list_route
+from rest_framework.response import Response
+from rest_framework.routers import SimpleRouter, DefaultRouter
+from rest_framework.test import APIRequestFactory
+from collections import namedtuple
+
+factory = APIRequestFactory()
+
+
+class RouterTestModel(models.Model):
+ uuid = models.CharField(max_length=20)
+ text = models.CharField(max_length=200)
+
+
+class NoteSerializer(serializers.HyperlinkedModelSerializer):
+ url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
+
+ class Meta:
+ model = RouterTestModel
+ fields = ('url', 'uuid', 'text')
+
+
+class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ serializer_class = NoteSerializer
+ lookup_field = 'uuid'
+
+
+class MockViewSet(viewsets.ModelViewSet):
+ queryset = None
+ serializer_class = None
+
+
+notes_router = SimpleRouter()
+notes_router.register(r'notes', NoteViewSet)
+
+namespaced_router = DefaultRouter()
+namespaced_router.register(r'example', MockViewSet, base_name='example')
+
+urlpatterns = [
+ url(r'^non-namespaced/', include(namespaced_router.urls)),
+ url(r'^namespaced/', include(namespaced_router.urls, namespace='example')),
+ url(r'^example/', include(notes_router.urls)),
+]
+
+
+class BasicViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @detail_route(methods=['post'])
+ def action1(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @detail_route(methods=['post'])
+ def action2(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @detail_route(methods=['post', 'delete'])
+ def action3(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @detail_route()
+ def link1(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @detail_route()
+ def link2(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class TestSimpleRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def test_link_and_action_decorator(self):
+ routes = self.router.get_routes(BasicViewSet)
+ decorator_routes = routes[2:]
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate(['action1', 'action2', 'action3', 'link1', 'link2']):
+ route = decorator_routes[i]
+ # check url listing
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(endpoint))
+ # check method to function mapping
+ if endpoint == 'action3':
+ methods_map = ['post', 'delete']
+ elif endpoint.startswith('action'):
+ methods_map = ['post']
+ else:
+ methods_map = ['get']
+ for method in methods_map:
+ self.assertEqual(route.mapping[method], endpoint)
+
+
+class TestRootView(TestCase):
+ urls = 'tests.test_routers'
+
+ def test_retrieve_namespaced_root(self):
+ response = self.client.get('/namespaced/')
+ self.assertEqual(
+ response.data,
+ {
+ "example": "http://testserver/namespaced/example/",
+ }
+ )
+
+ def test_retrieve_non_namespaced_root(self):
+ response = self.client.get('/non-namespaced/')
+ self.assertEqual(
+ response.data,
+ {
+ "example": "http://testserver/non-namespaced/example/",
+ }
+ )
+
+
+class TestCustomLookupFields(TestCase):
+ """
+ Ensure that custom lookup fields are correctly routed.
+ """
+ urls = 'tests.test_routers'
+
+ def setUp(self):
+ RouterTestModel.objects.create(uuid='123', text='foo bar')
+
+ def test_custom_lookup_field_route(self):
+ detail_route = notes_router.urls[-1]
+ detail_url_pattern = detail_route.regex.pattern
+ self.assertIn('<uuid>', detail_url_pattern)
+
+ def test_retrieve_lookup_field_list_view(self):
+ response = self.client.get('/example/notes/')
+ self.assertEqual(
+ response.data,
+ [{
+ "url": "http://testserver/example/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }]
+ )
+
+ def test_retrieve_lookup_field_detail_view(self):
+ response = self.client.get('/example/notes/123/')
+ self.assertEqual(
+ response.data,
+ {
+ "url": "http://testserver/example/notes/123/",
+ "uuid": "123", "text": "foo bar"
+ }
+ )
+
+
+class TestLookupValueRegex(TestCase):
+ """
+ Ensure the router honors lookup_value_regex when applied
+ to the viewset.
+ """
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+ lookup_field = 'uuid'
+ lookup_value_regex = '[0-9a-f]{32}'
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_limited_by_lookup_value_regex(self):
+ expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlashIncluded(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+
+ self.router = SimpleRouter()
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_have_trailing_slash_by_default(self):
+ expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestTrailingSlashRemoved(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+
+ self.router = SimpleRouter(trailing_slash=False)
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_urls_can_have_trailing_slash_removed(self):
+ expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
+ for idx in range(len(expected)):
+ self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
+
+
+class TestNameableRoot(TestCase):
+ def setUp(self):
+ class NoteViewSet(viewsets.ModelViewSet):
+ queryset = RouterTestModel.objects.all()
+
+ self.router = DefaultRouter()
+ self.router.root_view_name = 'nameable-root'
+ self.router.register(r'notes', NoteViewSet)
+ self.urls = self.router.urls
+
+ def test_router_has_custom_name(self):
+ expected = 'nameable-root'
+ self.assertEqual(expected, self.urls[0].name)
+
+
+class TestActionKeywordArgs(TestCase):
+ """
+ Ensure keyword arguments passed in the `@action` decorator
+ are properly handled. Refs #940.
+ """
+
+ def setUp(self):
+ class TestViewSet(viewsets.ModelViewSet):
+ permission_classes = []
+
+ @detail_route(methods=['post'], permission_classes=[permissions.AllowAny])
+ def custom(self, request, *args, **kwargs):
+ return Response({
+ 'permission_classes': self.permission_classes
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+ self.view = self.router.urls[-1].callback
+
+ def test_action_kwargs(self):
+ request = factory.post('/test/0/custom/')
+ response = self.view(request)
+ self.assertEqual(
+ response.data,
+ {'permission_classes': [permissions.AllowAny]}
+ )
+
+
+class TestActionAppliedToExistingRoute(TestCase):
+ """
+ Ensure `@detail_route` decorator raises an except when applied
+ to an existing route
+ """
+
+ def test_exception_raised_when_action_applied_to_existing_route(self):
+ class TestViewSet(viewsets.ModelViewSet):
+
+ @detail_route(methods=['post'])
+ def retrieve(self, request, *args, **kwargs):
+ return Response({
+ 'hello': 'world'
+ })
+
+ self.router = SimpleRouter()
+ self.router.register(r'test', TestViewSet, base_name='test')
+
+ with self.assertRaises(ImproperlyConfigured):
+ self.router.urls
+
+
+class DynamicListAndDetailViewSet(viewsets.ViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'method': 'list'})
+
+ @list_route(methods=['post'])
+ def list_route_post(self, request, *args, **kwargs):
+ return Response({'method': 'action1'})
+
+ @detail_route(methods=['post'])
+ def detail_route_post(self, request, *args, **kwargs):
+ return Response({'method': 'action2'})
+
+ @list_route()
+ def list_route_get(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @detail_route()
+ def detail_route_get(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+ @list_route(url_path="list_custom-route")
+ def list_custom_route_get(self, request, *args, **kwargs):
+ return Response({'method': 'link1'})
+
+ @detail_route(url_path="detail_custom-route")
+ def detail_custom_route_get(self, request, *args, **kwargs):
+ return Response({'method': 'link2'})
+
+
+class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
+ pass
+
+
+class TestDynamicListAndDetailRouter(TestCase):
+ def setUp(self):
+ self.router = SimpleRouter()
+
+ def _test_list_and_detail_route_decorators(self, viewset):
+ routes = self.router.get_routes(viewset)
+ decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
+
+ MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
+ # Make sure all these endpoints exist and none have been clobbered
+ for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
+ MethodNamesMap('list_route_get', 'list_route_get'),
+ MethodNamesMap('list_route_post', 'list_route_post'),
+ MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
+ MethodNamesMap('detail_route_get', 'detail_route_get'),
+ MethodNamesMap('detail_route_post', 'detail_route_post')
+ ]):
+ route = decorator_routes[i]
+ # check url listing
+ method_name = endpoint.method_name
+ url_path = endpoint.url_path
+
+ if method_name.startswith('list_'):
+ self.assertEqual(route.url,
+ '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path))
+ else:
+ self.assertEqual(route.url,
+ '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path))
+ # check method to function mapping
+ if method_name.endswith('_post'):
+ method_map = 'post'
+ else:
+ method_map = 'get'
+ self.assertEqual(route.mapping[method_map], method_name)
+
+ def test_list_and_detail_route_decorators(self):
+ self._test_list_and_detail_route_decorators(DynamicListAndDetailViewSet)
+
+ def test_inherited_list_and_detail_route_decorators(self):
+ self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet)
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
new file mode 100644
index 00000000..b7a0484b
--- /dev/null
+++ b/tests/test_serializer.py
@@ -0,0 +1,297 @@
+# coding: utf-8
+from __future__ import unicode_literals
+from .utils import MockObject
+from rest_framework import serializers
+from rest_framework.compat import unicode_repr
+import pickle
+import pytest
+
+
+# Tests for core functionality.
+# -----------------------------
+
+class TestSerializer:
+ def setup(self):
+ class ExampleSerializer(serializers.Serializer):
+ char = serializers.CharField()
+ integer = serializers.IntegerField()
+ self.Serializer = ExampleSerializer
+
+ def test_valid_serializer(self):
+ serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'char': 'abc', 'integer': 123}
+ assert serializer.errors == {}
+
+ def test_invalid_serializer(self):
+ serializer = self.Serializer(data={'char': 'abc'})
+ assert not serializer.is_valid()
+ assert serializer.validated_data == {}
+ assert serializer.errors == {'integer': ['This field is required.']}
+
+ def test_partial_validation(self):
+ serializer = self.Serializer(data={'char': 'abc'}, partial=True)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'char': 'abc'}
+ assert serializer.errors == {}
+
+ def test_empty_serializer(self):
+ serializer = self.Serializer()
+ assert serializer.data == {'char': '', 'integer': None}
+
+ def test_missing_attribute_during_serialization(self):
+ class MissingAttributes:
+ pass
+ instance = MissingAttributes()
+ serializer = self.Serializer(instance)
+ with pytest.raises(AttributeError):
+ serializer.data
+
+
+class TestValidateMethod:
+ def test_non_field_error_validate_method(self):
+ class ExampleSerializer(serializers.Serializer):
+ char = serializers.CharField()
+ integer = serializers.IntegerField()
+
+ def validate(self, attrs):
+ raise serializers.ValidationError('Non field error')
+
+ serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
+ assert not serializer.is_valid()
+ assert serializer.errors == {'non_field_errors': ['Non field error']}
+
+ def test_field_error_validate_method(self):
+ class ExampleSerializer(serializers.Serializer):
+ char = serializers.CharField()
+ integer = serializers.IntegerField()
+
+ def validate(self, attrs):
+ raise serializers.ValidationError({'char': 'Field error'})
+
+ serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
+ assert not serializer.is_valid()
+ assert serializer.errors == {'char': ['Field error']}
+
+
+class TestBaseSerializer:
+ def setup(self):
+ class ExampleSerializer(serializers.BaseSerializer):
+ def to_representation(self, obj):
+ return {
+ 'id': obj['id'],
+ 'email': obj['name'] + '@' + obj['domain']
+ }
+
+ def to_internal_value(self, data):
+ name, domain = str(data['email']).split('@')
+ return {
+ 'id': int(data['id']),
+ 'name': name,
+ 'domain': domain,
+ }
+
+ self.Serializer = ExampleSerializer
+
+ def test_serialize_instance(self):
+ instance = {'id': 1, 'name': 'tom', 'domain': 'example.com'}
+ serializer = self.Serializer(instance)
+ assert serializer.data == {'id': 1, 'email': 'tom@example.com'}
+
+ def test_serialize_list(self):
+ instances = [
+ {'id': 1, 'name': 'tom', 'domain': 'example.com'},
+ {'id': 2, 'name': 'ann', 'domain': 'example.com'},
+ ]
+ serializer = self.Serializer(instances, many=True)
+ assert serializer.data == [
+ {'id': 1, 'email': 'tom@example.com'},
+ {'id': 2, 'email': 'ann@example.com'}
+ ]
+
+ def test_validate_data(self):
+ data = {'id': 1, 'email': 'tom@example.com'}
+ serializer = self.Serializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'id': 1,
+ 'name': 'tom',
+ 'domain': 'example.com'
+ }
+
+ def test_validate_list(self):
+ data = [
+ {'id': 1, 'email': 'tom@example.com'},
+ {'id': 2, 'email': 'ann@example.com'},
+ ]
+ serializer = self.Serializer(data=data, many=True)
+ assert serializer.is_valid()
+ assert serializer.validated_data == [
+ {'id': 1, 'name': 'tom', 'domain': 'example.com'},
+ {'id': 2, 'name': 'ann', 'domain': 'example.com'}
+ ]
+
+
+class TestStarredSource:
+ """
+ Tests for `source='*'` argument, which is used for nested representations.
+
+ For example:
+
+ nested_field = NestedField(source='*')
+ """
+ data = {
+ 'nested1': {'a': 1, 'b': 2},
+ 'nested2': {'c': 3, 'd': 4}
+ }
+
+ def setup(self):
+ class NestedSerializer1(serializers.Serializer):
+ a = serializers.IntegerField()
+ b = serializers.IntegerField()
+
+ class NestedSerializer2(serializers.Serializer):
+ c = serializers.IntegerField()
+ d = serializers.IntegerField()
+
+ class TestSerializer(serializers.Serializer):
+ nested1 = NestedSerializer1(source='*')
+ nested2 = NestedSerializer2(source='*')
+
+ self.Serializer = TestSerializer
+
+ def test_nested_validate(self):
+ """
+ A nested representation is validated into a flat internal object.
+ """
+ serializer = self.Serializer(data=self.data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'a': 1,
+ 'b': 2,
+ 'c': 3,
+ 'd': 4
+ }
+
+ def test_nested_serialize(self):
+ """
+ An object can be serialized into a nested representation.
+ """
+ instance = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
+ serializer = self.Serializer(instance)
+ assert serializer.data == self.data
+
+
+class TestIncorrectlyConfigured:
+ def test_incorrect_field_name(self):
+ class ExampleSerializer(serializers.Serializer):
+ incorrect_name = serializers.IntegerField()
+
+ class ExampleObject:
+ def __init__(self):
+ self.correct_name = 123
+
+ instance = ExampleObject()
+ serializer = ExampleSerializer(instance)
+ with pytest.raises(AttributeError) as exc_info:
+ serializer.data
+ msg = str(exc_info.value)
+ assert msg.startswith(
+ "Got AttributeError when attempting to get a value for field `incorrect_name` on serializer `ExampleSerializer`.\n"
+ "The serializer field might be named incorrectly and not match any attribute or key on the `ExampleObject` instance.\n"
+ "Original exception text was:"
+ )
+
+
+class TestUnicodeRepr:
+ def test_unicode_repr(self):
+ class ExampleSerializer(serializers.Serializer):
+ example = serializers.CharField()
+
+ class ExampleObject:
+ def __init__(self):
+ self.example = '한국'
+
+ def __repr__(self):
+ return unicode_repr(self.example)
+
+ instance = ExampleObject()
+ serializer = ExampleSerializer(instance)
+ repr(serializer) # Should not error.
+
+
+class TestNotRequiredOutput:
+ def test_not_required_output_for_dict(self):
+ """
+ 'required=False' should allow a dictionary key to be missing in output.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(required=False)
+ included = serializers.CharField()
+
+ serializer = ExampleSerializer(data={'included': 'abc'})
+ serializer.is_valid()
+ assert serializer.data == {'included': 'abc'}
+
+ def test_not_required_output_for_object(self):
+ """
+ 'required=False' should allow an object attribute to be missing in output.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(required=False)
+ included = serializers.CharField()
+
+ def create(self, validated_data):
+ return MockObject(**validated_data)
+
+ serializer = ExampleSerializer(data={'included': 'abc'})
+ serializer.is_valid()
+ serializer.save()
+ assert serializer.data == {'included': 'abc'}
+
+ def test_default_required_output_for_dict(self):
+ """
+ 'default="something"' should require dictionary key.
+
+ We need to handle this as the field will have an implicit
+ 'required=False', but it should still have a value.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(default='abc')
+ included = serializers.CharField()
+
+ serializer = ExampleSerializer({'included': 'abc'})
+ with pytest.raises(KeyError):
+ serializer.data
+
+ def test_default_required_output_for_object(self):
+ """
+ 'default="something"' should require object attribute.
+
+ We need to handle this as the field will have an implicit
+ 'required=False', but it should still have a value.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ omitted = serializers.CharField(default='abc')
+ included = serializers.CharField()
+
+ instance = MockObject(included='abc')
+ serializer = ExampleSerializer(instance)
+ with pytest.raises(AttributeError):
+ serializer.data
+
+
+class TestCacheSerializerData:
+ def test_cache_serializer_data(self):
+ """
+ Caching serializer data with pickle will drop the serializer info,
+ but does preserve the data itself.
+ """
+ class ExampleSerializer(serializers.Serializer):
+ field1 = serializers.CharField()
+ field2 = serializers.CharField()
+
+ serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'})
+ pickled = pickle.dumps(serializer.data)
+ data = pickle.loads(pickled)
+ assert data == {'field1': 'a', 'field2': 'b'}
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
new file mode 100644
index 00000000..bc955b2e
--- /dev/null
+++ b/tests/test_serializer_bulk_update.py
@@ -0,0 +1,123 @@
+"""
+Tests to cover bulk create and update using serializers.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.utils import six
+from rest_framework import serializers
+
+
+class BulkCreateSerializerTests(TestCase):
+ """
+ Creating multiple instances using serializers.
+ """
+
+ def setUp(self):
+ class BookSerializer(serializers.Serializer):
+ id = serializers.IntegerField()
+ title = serializers.CharField(max_length=100)
+ author = serializers.CharField(max_length=100)
+
+ self.BookSerializer = BookSerializer
+
+ def test_bulk_create_success(self):
+ """
+ Correct bulk update serialization should return the input data.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 2,
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), True)
+ self.assertEqual(serializer.validated_data, data)
+
+ def test_bulk_create_errors(self):
+ """
+ Incorrect bulk create serialization should return errors.
+ """
+
+ data = [
+ {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }, {
+ 'id': 1,
+ 'title': 'If this is a man',
+ 'author': 'Primo Levi'
+ }, {
+ 'id': 'foo',
+ 'title': 'The wind-up bird chronicle',
+ 'author': 'Haruki Murakami'
+ }
+ ]
+ expected_errors = [
+ {},
+ {},
+ {'id': ['A valid integer is required.']}
+ ]
+
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_list_datatype(self):
+ """
+ Data containing list of incorrect data type should return errors.
+ """
+ data = ['foo', 'bar', 'baz']
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ text_type_string = six.text_type.__name__
+ message = 'Invalid data. Expected a dictionary, but got %s.' % text_type_string
+ expected_errors = [
+ {'non_field_errors': [message]},
+ {'non_field_errors': [message]},
+ {'non_field_errors': [message]}
+ ]
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_datatype(self):
+ """
+ Data containing a single incorrect data type should return errors.
+ """
+ data = 123
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
+
+ self.assertEqual(serializer.errors, expected_errors)
+
+ def test_invalid_single_object(self):
+ """
+ Data containing only a single object, instead of a list of objects
+ should return errors.
+ """
+ data = {
+ 'id': 0,
+ 'title': 'The electric kool-aid acid test',
+ 'author': 'Tom Wolfe'
+ }
+ serializer = self.BookSerializer(data=data, many=True)
+ self.assertEqual(serializer.is_valid(), False)
+
+ expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
+
+ self.assertEqual(serializer.errors, expected_errors)
diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py
new file mode 100644
index 00000000..35b68ae7
--- /dev/null
+++ b/tests/test_serializer_lists.py
@@ -0,0 +1,290 @@
+from rest_framework import serializers
+from django.utils.datastructures import MultiValueDict
+
+
+class BasicObject:
+ """
+ A mock object for testing serializer save behavior.
+ """
+ def __init__(self, **kwargs):
+ self._data = kwargs
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def __eq__(self, other):
+ if self._data.keys() != other._data.keys():
+ return False
+ for key in self._data.keys():
+ if self._data[key] != other._data[key]:
+ return False
+ return True
+
+
+class TestListSerializer:
+ """
+ Tests for using a ListSerializer as a top-level serializer.
+ Note that this is in contrast to using ListSerializer as a field.
+ """
+
+ def setup(self):
+ class IntegerListSerializer(serializers.ListSerializer):
+ child = serializers.IntegerField()
+ self.Serializer = IntegerListSerializer
+
+ def test_validate(self):
+ """
+ Validating a list of items should return a list of validated items.
+ """
+ input_data = ["123", "456"]
+ expected_output = [123, 456]
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+ def test_validate_html_input(self):
+ """
+ HTML input should be able to mock list structures using [x] style ids.
+ """
+ input_data = MultiValueDict({"[0]": ["123"], "[1]": ["456"]})
+ expected_output = [123, 456]
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+
+class TestListSerializerContainingNestedSerializer:
+ """
+ Tests for using a ListSerializer containing another serializer.
+ """
+
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ integer = serializers.IntegerField()
+ boolean = serializers.BooleanField()
+
+ def create(self, validated_data):
+ return BasicObject(**validated_data)
+
+ class ObjectListSerializer(serializers.ListSerializer):
+ child = TestSerializer()
+
+ self.Serializer = ObjectListSerializer
+
+ def test_validate(self):
+ """
+ Validating a list of dictionaries should return a list of
+ validated dictionaries.
+ """
+ input_data = [
+ {"integer": "123", "boolean": "true"},
+ {"integer": "456", "boolean": "false"}
+ ]
+ expected_output = [
+ {"integer": 123, "boolean": True},
+ {"integer": 456, "boolean": False}
+ ]
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+ def test_create(self):
+ """
+ Creating from a list of dictionaries should return a list of objects.
+ """
+ input_data = [
+ {"integer": "123", "boolean": "true"},
+ {"integer": "456", "boolean": "false"}
+ ]
+ expected_output = [
+ BasicObject(integer=123, boolean=True),
+ BasicObject(integer=456, boolean=False),
+ ]
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.save() == expected_output
+
+ def test_serialize(self):
+ """
+ Serialization of a list of objects should return a list of dictionaries.
+ """
+ input_objects = [
+ BasicObject(integer=123, boolean=True),
+ BasicObject(integer=456, boolean=False)
+ ]
+ expected_output = [
+ {"integer": 123, "boolean": True},
+ {"integer": 456, "boolean": False}
+ ]
+ serializer = self.Serializer(input_objects)
+ assert serializer.data == expected_output
+
+ def test_validate_html_input(self):
+ """
+ HTML input should be able to mock list structures using [x]
+ style prefixes.
+ """
+ input_data = MultiValueDict({
+ "[0]integer": ["123"],
+ "[0]boolean": ["true"],
+ "[1]integer": ["456"],
+ "[1]boolean": ["false"]
+ })
+ expected_output = [
+ {"integer": 123, "boolean": True},
+ {"integer": 456, "boolean": False}
+ ]
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+
+class TestNestedListSerializer:
+ """
+ Tests for using a ListSerializer as a field.
+ """
+
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ integers = serializers.ListSerializer(child=serializers.IntegerField())
+ booleans = serializers.ListSerializer(child=serializers.BooleanField())
+
+ def create(self, validated_data):
+ return BasicObject(**validated_data)
+
+ self.Serializer = TestSerializer
+
+ def test_validate(self):
+ """
+ Validating a list of items should return a list of validated items.
+ """
+ input_data = {
+ "integers": ["123", "456"],
+ "booleans": ["true", "false"]
+ }
+ expected_output = {
+ "integers": [123, 456],
+ "booleans": [True, False]
+ }
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+ def test_create(self):
+ """
+ Creation with a list of items return an object with an attribute that
+ is a list of items.
+ """
+ input_data = {
+ "integers": ["123", "456"],
+ "booleans": ["true", "false"]
+ }
+ expected_output = BasicObject(
+ integers=[123, 456],
+ booleans=[True, False]
+ )
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.save() == expected_output
+
+ def test_serialize(self):
+ """
+ Serialization of a list of items should return a list of items.
+ """
+ input_object = BasicObject(
+ integers=[123, 456],
+ booleans=[True, False]
+ )
+ expected_output = {
+ "integers": [123, 456],
+ "booleans": [True, False]
+ }
+ serializer = self.Serializer(input_object)
+ assert serializer.data == expected_output
+
+ def test_validate_html_input(self):
+ """
+ HTML input should be able to mock list structures using [x]
+ style prefixes.
+ """
+ input_data = MultiValueDict({
+ "integers[0]": ["123"],
+ "integers[1]": ["456"],
+ "booleans[0]": ["true"],
+ "booleans[1]": ["false"]
+ })
+ expected_output = {
+ "integers": [123, 456],
+ "booleans": [True, False]
+ }
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+
+class TestNestedListOfListsSerializer:
+ def setup(self):
+ class TestSerializer(serializers.Serializer):
+ integers = serializers.ListSerializer(
+ child=serializers.ListSerializer(
+ child=serializers.IntegerField()
+ )
+ )
+ booleans = serializers.ListSerializer(
+ child=serializers.ListSerializer(
+ child=serializers.BooleanField()
+ )
+ )
+
+ self.Serializer = TestSerializer
+
+ def test_validate(self):
+ input_data = {
+ 'integers': [['123', '456'], ['789', '0']],
+ 'booleans': [['true', 'true'], ['false', 'true']]
+ }
+ expected_output = {
+ "integers": [[123, 456], [789, 0]],
+ "booleans": [[True, True], [False, True]]
+ }
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+ def test_validate_html_input(self):
+ """
+ HTML input should be able to mock lists of lists using [x][y]
+ style prefixes.
+ """
+ input_data = MultiValueDict({
+ "integers[0][0]": ["123"],
+ "integers[0][1]": ["456"],
+ "integers[1][0]": ["789"],
+ "integers[1][1]": ["000"],
+ "booleans[0][0]": ["true"],
+ "booleans[0][1]": ["true"],
+ "booleans[1][0]": ["false"],
+ "booleans[1][1]": ["true"]
+ })
+ expected_output = {
+ "integers": [[123, 456], [789, 0]],
+ "booleans": [[True, True], [False, True]]
+ }
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_output
+
+
+class TestListSerializerClass:
+ """Tests for a custom list_serializer_class."""
+ def test_list_serializer_class_validate(self):
+ class CustomListSerializer(serializers.ListSerializer):
+ def validate(self, attrs):
+ raise serializers.ValidationError('Non field error')
+
+ class TestSerializer(serializers.Serializer):
+ class Meta:
+ list_serializer_class = CustomListSerializer
+
+ serializer = TestSerializer(data=[], many=True)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'non_field_errors': ['Non field error']}
diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py
new file mode 100644
index 00000000..f5e4b26a
--- /dev/null
+++ b/tests/test_serializer_nested.py
@@ -0,0 +1,40 @@
+from rest_framework import serializers
+
+
+class TestNestedSerializer:
+ def setup(self):
+ class NestedSerializer(serializers.Serializer):
+ one = serializers.IntegerField(max_value=10)
+ two = serializers.IntegerField(max_value=10)
+
+ class TestSerializer(serializers.Serializer):
+ nested = NestedSerializer()
+
+ self.Serializer = TestSerializer
+
+ def test_nested_validate(self):
+ input_data = {
+ 'nested': {
+ 'one': '1',
+ 'two': '2',
+ }
+ }
+ expected_data = {
+ 'nested': {
+ 'one': 1,
+ 'two': 2,
+ }
+ }
+ serializer = self.Serializer(data=input_data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == expected_data
+
+ def test_nested_serialize_empty(self):
+ expected_data = {
+ 'nested': {
+ 'one': None,
+ 'two': None
+ }
+ }
+ serializer = self.Serializer()
+ assert serializer.data == expected_data
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 00000000..f2ff4ca1
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,17 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.settings import APISettings
+
+
+class TestSettings(TestCase):
+ def test_import_error_message_maintained(self):
+ """
+ Make sure import errors are captured and raised sensibly.
+ """
+ settings = APISettings({
+ 'DEFAULT_RENDERER_CLASSES': [
+ 'tests.invalid_module.InvalidClassName'
+ ]
+ })
+ with self.assertRaises(ImportError):
+ settings.DEFAULT_RENDERER_CLASSES
diff --git a/tests/test_status.py b/tests/test_status.py
new file mode 100644
index 00000000..721a6e30
--- /dev/null
+++ b/tests/test_status.py
@@ -0,0 +1,33 @@
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.status import (
+ is_informational, is_success, is_redirect, is_client_error, is_server_error
+)
+
+
+class TestStatus(TestCase):
+ def test_status_categories(self):
+ self.assertFalse(is_informational(99))
+ self.assertTrue(is_informational(100))
+ self.assertTrue(is_informational(199))
+ self.assertFalse(is_informational(200))
+
+ self.assertFalse(is_success(199))
+ self.assertTrue(is_success(200))
+ self.assertTrue(is_success(299))
+ self.assertFalse(is_success(300))
+
+ self.assertFalse(is_redirect(299))
+ self.assertTrue(is_redirect(300))
+ self.assertTrue(is_redirect(399))
+ self.assertFalse(is_redirect(400))
+
+ self.assertFalse(is_client_error(399))
+ self.assertTrue(is_client_error(400))
+ self.assertTrue(is_client_error(499))
+ self.assertFalse(is_client_error(500))
+
+ self.assertFalse(is_server_error(499))
+ self.assertTrue(is_server_error(500))
+ self.assertTrue(is_server_error(599))
+ self.assertFalse(is_server_error(600))
diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py
new file mode 100644
index 00000000..0cee91f1
--- /dev/null
+++ b/tests/test_templatetags.py
@@ -0,0 +1,75 @@
+# encoding: utf-8
+from __future__ import unicode_literals
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.templatetags.rest_framework import add_query_param, urlize_quoted_links
+
+
+factory = APIRequestFactory()
+
+
+class TemplateTagTests(TestCase):
+
+ def test_add_query_param_with_non_latin_charactor(self):
+ # Ensure we don't double-escape non-latin characters
+ # that are present in the querystring.
+ # See #1314.
+ request = factory.get("/", {'q': '查询'})
+ json_url = add_query_param(request, "format", "json")
+ self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
+ self.assertIn("format=json", json_url)
+
+
+class Issue1386Tests(TestCase):
+ """
+ Covers #1386
+ """
+
+ def test_issue_1386(self):
+ """
+ Test function urlize_quoted_links with different args
+ """
+ correct_urls = [
+ "asdf.com",
+ "asdf.net",
+ "www.as_df.org",
+ "as.d8f.ghj8.gov",
+ ]
+ for i in correct_urls:
+ res = urlize_quoted_links(i)
+ self.assertNotEqual(res, i)
+ self.assertIn(i, res)
+
+ incorrect_urls = [
+ "mailto://asdf@fdf.com",
+ "asdf.netnet",
+ ]
+ for i in incorrect_urls:
+ res = urlize_quoted_links(i)
+ self.assertEqual(i, res)
+
+ # example from issue #1386, this shouldn't raise an exception
+ urlize_quoted_links("asdf:[/p]zxcv.com")
+
+
+class URLizerTests(TestCase):
+ """
+ Test if JSON URLs are transformed into links well
+ """
+ def _urlize_dict_check(self, data):
+ """
+ For all items in dict test assert that the value is urlized key
+ """
+ for original, urlized in data.items():
+ assert urlize_quoted_links(original, nofollow=False) == urlized
+
+ def test_json_with_url(self):
+ """
+ Test if JSON URLs are transformed into links well
+ """
+ data = {}
+ data['"url": "http://api/users/1/", '] = \
+ '&quot;url&quot;: &quot;<a href="http://api/users/1/">http://api/users/1/</a>&quot;, '
+ data['"foo_set": [\n "http://api/foos/1/"\n], '] = \
+ '&quot;foo_set&quot;: [\n &quot;<a href="http://api/foos/1/">http://api/foos/1/</a>&quot;\n], '
+ self._urlize_dict_check(data)
diff --git a/tests/test_testing.py b/tests/test_testing.py
new file mode 100644
index 00000000..87d2b61f
--- /dev/null
+++ b/tests/test_testing.py
@@ -0,0 +1,234 @@
+# encoding: utf-8
+from __future__ import unicode_literals
+from django.conf.urls import patterns, url
+from django.contrib.auth.models import User
+from django.shortcuts import redirect
+from django.test import TestCase
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.test import APIClient, APIRequestFactory, force_authenticate
+from io import BytesIO
+
+
+@api_view(['GET', 'POST'])
+def view(request):
+ return Response({
+ 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
+ 'user': request.user.username
+ })
+
+
+@api_view(['GET', 'POST'])
+def session_view(request):
+ active_session = request.session.get('active_session', False)
+ request.session['active_session'] = True
+ return Response({
+ 'active_session': active_session
+ })
+
+
+@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
+def redirect_view(request):
+ return redirect('/view/')
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^view/$', view),
+ url(r'^session-view/$', session_view),
+ url(r'^redirect-view/$', redirect_view),
+)
+
+
+class TestAPITestClient(TestCase):
+ urls = 'tests.test_testing'
+
+ def setUp(self):
+ self.client = APIClient()
+
+ def test_credentials(self):
+ """
+ Setting `.credentials()` adds the required headers to each request.
+ """
+ self.client.credentials(HTTP_AUTHORIZATION='example')
+ for _ in range(0, 3):
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')
+
+ def test_force_authenticate(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_force_authenticate_with_sessions(self):
+ """
+ Setting `.force_authenticate()` forcibly authenticates each request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ self.client.force_authenticate(user)
+
+ # First request does not yet have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ # Subsequant requests have an active session
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], True)
+
+ # Force authenticating as `None` should also logout the user session.
+ self.client.force_authenticate(None)
+ response = self.client.get('/session-view/')
+ self.assertEqual(response.data['active_session'], False)
+
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ User.objects.create_user('example', 'example@example.com', 'password')
+ self.client.login(username='example', password='password')
+ response = self.client.post('/view/')
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ client = APIClient(enforce_csrf_checks=True)
+ User.objects.create_user('example', 'example@example.com', 'password')
+ client.login(username='example', password='password')
+ response = client.post('/view/')
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+ def test_can_logout(self):
+ """
+ `logout()` resets stored credentials
+ """
+ self.client.credentials(HTTP_AUTHORIZATION='example')
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], 'example')
+ self.client.logout()
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['auth'], b'')
+
+ def test_logout_resets_force_authenticate(self):
+ """
+ `logout()` resets any `force_authenticate`
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ self.client.force_authenticate(user)
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['user'], 'example')
+ self.client.logout()
+ response = self.client.get('/view/')
+ self.assertEqual(response.data['user'], '')
+
+ def test_follow_redirect(self):
+ """
+ Follow redirect by setting follow argument.
+ """
+ response = self.client.get('/redirect-view/')
+ self.assertEqual(response.status_code, 302)
+ response = self.client.get('/redirect-view/', follow=True)
+ self.assertIsNotNone(response.redirect_chain)
+ self.assertEqual(response.status_code, 200)
+
+ response = self.client.post('/redirect-view/')
+ self.assertEqual(response.status_code, 302)
+ response = self.client.post('/redirect-view/', follow=True)
+ self.assertIsNotNone(response.redirect_chain)
+ self.assertEqual(response.status_code, 200)
+
+ response = self.client.put('/redirect-view/')
+ self.assertEqual(response.status_code, 302)
+ response = self.client.put('/redirect-view/', follow=True)
+ self.assertIsNotNone(response.redirect_chain)
+ self.assertEqual(response.status_code, 200)
+
+ response = self.client.patch('/redirect-view/')
+ self.assertEqual(response.status_code, 302)
+ response = self.client.patch('/redirect-view/', follow=True)
+ self.assertIsNotNone(response.redirect_chain)
+ self.assertEqual(response.status_code, 200)
+
+ response = self.client.delete('/redirect-view/')
+ self.assertEqual(response.status_code, 302)
+ response = self.client.delete('/redirect-view/', follow=True)
+ self.assertIsNotNone(response.redirect_chain)
+ self.assertEqual(response.status_code, 200)
+
+ response = self.client.options('/redirect-view/')
+ self.assertEqual(response.status_code, 302)
+ response = self.client.options('/redirect-view/', follow=True)
+ self.assertIsNotNone(response.redirect_chain)
+ self.assertEqual(response.status_code, 200)
+
+
+class TestAPIRequestFactory(TestCase):
+ def test_csrf_exempt_by_default(self):
+ """
+ By default, the test client is CSRF exempt.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory()
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ self.assertEqual(response.status_code, 200)
+
+ def test_explicitly_enforce_csrf_checks(self):
+ """
+ The test client can enforce CSRF checks.
+ """
+ user = User.objects.create_user('example', 'example@example.com', 'password')
+ factory = APIRequestFactory(enforce_csrf_checks=True)
+ request = factory.post('/view/')
+ request.user = user
+ response = view(request)
+ expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ self.assertEqual(response.status_code, 403)
+ self.assertEqual(response.data, expected)
+
+ def test_invalid_format(self):
+ """
+ Attempting to use a format that is not configured will raise an
+ assertion error.
+ """
+ factory = APIRequestFactory()
+ self.assertRaises(
+ AssertionError, factory.post,
+ path='/view/', data={'example': 1}, format='xml'
+ )
+
+ def test_force_authenticate(self):
+ """
+ Setting `force_authenticate()` forcibly authenticates the request.
+ """
+ user = User.objects.create_user('example', 'example@example.com')
+ factory = APIRequestFactory()
+ request = factory.get('/view')
+ force_authenticate(request, user=user)
+ response = view(request)
+ self.assertEqual(response.data['user'], 'example')
+
+ def test_upload_file(self):
+ # This is a 1x1 black png
+ simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82')
+ simple_png.name = 'test.png'
+ factory = APIRequestFactory()
+ factory.post('/', data={'image': simple_png})
+
+ def test_request_factory_url_arguments(self):
+ """
+ This is a non regression test against #1461
+ """
+ factory = APIRequestFactory()
+ request = factory.get('/view/?demo=test')
+ self.assertEqual(dict(request.GET), {'demo': ['test']})
+ request = factory.get('/view/', {'demo': 'test'})
+ self.assertEqual(dict(request.GET), {'demo': ['test']})
diff --git a/tests/test_throttling.py b/tests/test_throttling.py
new file mode 100644
index 00000000..50a53b3e
--- /dev/null
+++ b/tests/test_throttling.py
@@ -0,0 +1,353 @@
+"""
+Tests for the throttling implementations in the permissions module.
+"""
+from __future__ import unicode_literals
+from django.test import TestCase
+from django.contrib.auth.models import User
+from django.core.cache import cache
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from rest_framework.views import APIView
+from rest_framework.throttling import BaseThrottle, UserRateThrottle, ScopedRateThrottle
+from rest_framework.response import Response
+
+
+class User3SecRateThrottle(UserRateThrottle):
+ rate = '3/sec'
+ scope = 'seconds'
+
+
+class User3MinRateThrottle(UserRateThrottle):
+ rate = '3/min'
+ scope = 'minutes'
+
+
+class NonTimeThrottle(BaseThrottle):
+ def allow_request(self, request, view):
+ if not hasattr(self.__class__, 'called'):
+ self.__class__.called = True
+ return True
+ return False
+
+
+class MockView(APIView):
+ throttle_classes = (User3SecRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_MinuteThrottling(APIView):
+ throttle_classes = (User3MinRateThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class MockView_NonTimeThrottling(APIView):
+ throttle_classes = (NonTimeThrottle,)
+
+ def get(self, request):
+ return Response('foo')
+
+
+class ThrottlingTests(TestCase):
+ def setUp(self):
+ """
+ Reset the cache so that no throttles will be active
+ """
+ cache.clear()
+ self.factory = APIRequestFactory()
+
+ def test_requests_are_throttled(self):
+ """
+ Ensure request rate is limited
+ """
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ def set_throttle_timer(self, view, value):
+ """
+ Explicitly set the timer, overriding time.time()
+ """
+ view.throttle_classes[0].timer = lambda self: value
+
+ def test_request_throttling_expires(self):
+ """
+ Ensure request rate is limited for a limited duration only
+ """
+ self.set_throttle_timer(MockView, 0)
+
+ request = self.factory.get('/')
+ for dummy in range(4):
+ response = MockView.as_view()(request)
+ self.assertEqual(429, response.status_code)
+
+ # Advance the timer by one second
+ self.set_throttle_timer(MockView, 1)
+
+ response = MockView.as_view()(request)
+ self.assertEqual(200, response.status_code)
+
+ def ensure_is_throttled(self, view, expect):
+ request = self.factory.get('/')
+ request.user = User.objects.create(username='a')
+ for dummy in range(3):
+ view.as_view()(request)
+ request.user = User.objects.create(username='b')
+ response = view.as_view()(request)
+ self.assertEqual(expect, response.status_code)
+
+ def test_request_throttling_is_per_user(self):
+ """
+ Ensure request rate is only limited per user, not globally for
+ PerUserThrottles
+ """
+ self.ensure_is_throttled(MockView, 200)
+
+ def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ """
+ Ensure the response returns an Retry-After field with status and next attributes
+ set properly.
+ """
+ request = self.factory.get('/')
+ for timer, expect in expected_headers:
+ self.set_throttle_timer(view, timer)
+ response = view.as_view()(request)
+ if expect is not None:
+ self.assertEqual(response['Retry-After'], expect)
+ else:
+ self.assertFalse('Retry-After' in response)
+
+ def test_seconds_fields(self):
+ """
+ Ensure for second based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(
+ MockView, (
+ (0, None),
+ (0, None),
+ (0, None),
+ (0, '1')
+ )
+ )
+
+ def test_minutes_fields(self):
+ """
+ Ensure for minute based throttles.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(
+ MockView_MinuteThrottling, (
+ (0, None),
+ (0, None),
+ (0, None),
+ (0, '60')
+ )
+ )
+
+ def test_next_rate_remains_constant_if_followed(self):
+ """
+ If a client follows the recommended next request rate,
+ the throttling rate should stay constant.
+ """
+ self.ensure_response_header_contains_proper_throttle_field(
+ MockView_MinuteThrottling, (
+ (0, None),
+ (20, None),
+ (40, None),
+ (60, None),
+ (80, None)
+ )
+ )
+
+ def test_non_time_throttle(self):
+ """
+ Ensure for second based throttles.
+ """
+ request = self.factory.get('/')
+
+ self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('Retry-After' in response)
+
+ self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
+
+ response = MockView_NonTimeThrottling.as_view()(request)
+ self.assertFalse('Retry-After' in response)
+
+
+class ScopedRateThrottleTests(TestCase):
+ """
+ Tests for ScopedRateThrottle.
+ """
+
+ def setUp(self):
+ class XYScopedRateThrottle(ScopedRateThrottle):
+ TIMER_SECONDS = 0
+ THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+
+ def timer(self):
+ return self.TIMER_SECONDS
+
+ class XView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'x'
+
+ def get(self, request):
+ return Response('x')
+
+ class YView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+ throttle_scope = 'y'
+
+ def get(self, request):
+ return Response('y')
+
+ class UnscopedView(APIView):
+ throttle_classes = (XYScopedRateThrottle,)
+
+ def get(self, request):
+ return Response('y')
+
+ self.throttle_class = XYScopedRateThrottle
+ self.factory = APIRequestFactory()
+ self.x_view = XView.as_view()
+ self.y_view = YView.as_view()
+ self.unscoped_view = UnscopedView.as_view()
+
+ def increment_timer(self, seconds=1):
+ self.throttle_class.TIMER_SECONDS += seconds
+
+ def test_scoped_rate_throttle(self):
+ request = self.factory.get('/')
+
+ # Should be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Ensure throttles properly reset by advancing the rest of the minute
+ self.increment_timer(55)
+
+ # Should still be able to hit x view 3 times per minute.
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.x_view(request)
+ self.assertEqual(429, response.status_code)
+
+ # Should still be able to hit y view 1 time per minute.
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(200, response.status_code)
+
+ self.increment_timer()
+ response = self.y_view(request)
+ self.assertEqual(429, response.status_code)
+
+ def test_unscoped_view_not_throttled(self):
+ request = self.factory.get('/')
+
+ for idx in range(10):
+ self.increment_timer()
+ response = self.unscoped_view(request)
+ self.assertEqual(200, response.status_code)
+
+
+class XffTestingBase(TestCase):
+ def setUp(self):
+
+ class Throttle(ScopedRateThrottle):
+ THROTTLE_RATES = {'test_limit': '1/day'}
+ TIMER_SECONDS = 0
+
+ def timer(self):
+ return self.TIMER_SECONDS
+
+ class View(APIView):
+ throttle_classes = (Throttle,)
+ throttle_scope = 'test_limit'
+
+ def get(self, request):
+ return Response('test_limit')
+
+ cache.clear()
+ self.throttle = Throttle()
+ self.view = View.as_view()
+ self.request = APIRequestFactory().get('/some_uri')
+ self.request.META['REMOTE_ADDR'] = '3.3.3.3'
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'
+
+ def config_proxy(self, num_proxies):
+ setattr(api_settings, 'NUM_PROXIES', num_proxies)
+
+
+class IdWithXffBasicTests(XffTestingBase):
+ def test_accepts_request_under_limit(self):
+ self.config_proxy(0)
+ self.assertEqual(200, self.view(self.request).status_code)
+
+ def test_denies_request_over_limit(self):
+ self.config_proxy(0)
+ self.view(self.request)
+ self.assertEqual(429, self.view(self.request).status_code)
+
+
+class XffSpoofingTests(XffTestingBase):
+ def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
+ self.config_proxy(1)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
+ self.assertEqual(429, self.view(self.request).status_code)
+
+ def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
+ self.config_proxy(2)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
+ self.assertEqual(429, self.view(self.request).status_code)
+
+
+class XffUniqueMachinesTest(XffTestingBase):
+ def test_unique_clients_are_counted_independently_with_one_proxy(self):
+ self.config_proxy(1)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
+ self.assertEqual(200, self.view(self.request).status_code)
+
+ def test_unique_clients_are_counted_independently_with_two_proxies(self):
+ self.config_proxy(2)
+ self.view(self.request)
+ self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
+ self.assertEqual(200, self.view(self.request).status_code)
diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py
new file mode 100644
index 00000000..e0060e69
--- /dev/null
+++ b/tests/test_urlpatterns.py
@@ -0,0 +1,76 @@
+from __future__ import unicode_literals
+from collections import namedtuple
+from django.conf.urls import patterns, url, include
+from django.core import urlresolvers
+from django.test import TestCase
+from rest_framework.test import APIRequestFactory
+from rest_framework.urlpatterns import format_suffix_patterns
+
+
+# A container class for test paths for the test case
+URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+
+
+def dummy_view(request, *args, **kwargs):
+ pass
+
+
+class FormatSuffixTests(TestCase):
+ """
+ Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters.
+ """
+ def _resolve_urlpatterns(self, urlpatterns, test_paths):
+ factory = APIRequestFactory()
+ try:
+ urlpatterns = format_suffix_patterns(urlpatterns)
+ except Exception:
+ self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
+ resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns)
+ for test_path in test_paths:
+ request = factory.get(test_path.path)
+ try:
+ callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ except Exception:
+ self.fail("Failed to resolve URL: %s" % request.path_info)
+ self.assertEqual(callback_args, test_path.args)
+ self.assertEqual(callback_kwargs, test_path.kwargs)
+
+ def test_format_suffix(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {}),
+ URLTestPath('/test.api', (), {'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_default_args(self):
+ urlpatterns = patterns(
+ '',
+ url(r'^test$', dummy_view, {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test', (), {'foo': 'bar', }),
+ URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
+
+ def test_included_urls(self):
+ nested_patterns = patterns(
+ '',
+ url(r'^path$', dummy_view)
+ )
+ urlpatterns = patterns(
+ '',
+ url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
+ )
+ test_paths = [
+ URLTestPath('/test/path', (), {'foo': 'bar', }),
+ URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
+ URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ ]
+ self._resolve_urlpatterns(urlpatterns, test_paths)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..8c286ea4
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,166 @@
+from __future__ import unicode_literals
+from django.core.exceptions import ImproperlyConfigured
+from django.conf.urls import patterns, url
+from django.test import TestCase
+from django.utils import six
+from rest_framework.utils.model_meta import _resolve_model
+from rest_framework.utils.breadcrumbs import get_breadcrumbs
+from rest_framework.views import APIView
+from tests.models import BasicModel
+
+import rest_framework.utils.model_meta
+
+
+class Root(APIView):
+ pass
+
+
+class ResourceRoot(APIView):
+ pass
+
+
+class ResourceInstance(APIView):
+ pass
+
+
+class NestedResourceRoot(APIView):
+ pass
+
+
+class NestedResourceInstance(APIView):
+ pass
+
+
+urlpatterns = patterns(
+ '',
+ url(r'^$', Root.as_view()),
+ url(r'^resource/$', ResourceRoot.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
+ url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()),
+)
+
+
+class BreadcrumbTests(TestCase):
+ """
+ Tests the breadcrumb functionality used by the HTML renderer.
+ """
+ urls = 'tests.test_utils'
+
+ def test_root_breadcrumbs(self):
+ url = '/'
+ self.assertEqual(
+ get_breadcrumbs(url),
+ [('Root', '/')]
+ )
+
+ def test_resource_root_breadcrumbs(self):
+ url = '/resource/'
+ self.assertEqual(
+ get_breadcrumbs(url),
+ [
+ ('Root', '/'),
+ ('Resource Root', '/resource/')
+ ]
+ )
+
+ def test_resource_instance_breadcrumbs(self):
+ url = '/resource/123'
+ self.assertEqual(
+ get_breadcrumbs(url),
+ [
+ ('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123')
+ ]
+ )
+
+ def test_nested_resource_breadcrumbs(self):
+ url = '/resource/123/'
+ self.assertEqual(
+ get_breadcrumbs(url),
+ [
+ ('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/')
+ ]
+ )
+
+ def test_nested_resource_instance_breadcrumbs(self):
+ url = '/resource/123/abc'
+ self.assertEqual(
+ get_breadcrumbs(url),
+ [
+ ('Root', '/'),
+ ('Resource Root', '/resource/'),
+ ('Resource Instance', '/resource/123'),
+ ('Nested Resource Root', '/resource/123/'),
+ ('Nested Resource Instance', '/resource/123/abc')
+ ]
+ )
+
+ def test_broken_url_breadcrumbs_handled_gracefully(self):
+ url = '/foobar'
+ self.assertEqual(
+ get_breadcrumbs(url),
+ [('Root', '/')]
+ )
+
+
+class ResolveModelTests(TestCase):
+ """
+ `_resolve_model` should return a Django model class given the
+ provided argument is a Django model class itself, or a properly
+ formatted string representation of one.
+ """
+ def test_resolve_django_model(self):
+ resolved_model = _resolve_model(BasicModel)
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_string_representation(self):
+ resolved_model = _resolve_model('tests.BasicModel')
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_unicode_representation(self):
+ resolved_model = _resolve_model(six.text_type('tests.BasicModel'))
+ self.assertEqual(resolved_model, BasicModel)
+
+ def test_resolve_non_django_model(self):
+ with self.assertRaises(ValueError):
+ _resolve_model(TestCase)
+
+ def test_resolve_improper_string_representation(self):
+ with self.assertRaises(ValueError):
+ _resolve_model('BasicModel')
+
+
+class ResolveModelWithPatchedDjangoTests(TestCase):
+ """
+ Test coverage for when Django's `get_model` returns `None`.
+
+ Under certain circumstances Django may return `None` with `get_model`:
+ http://git.io/get-model-source
+
+ It usually happens with circular imports so it is important that DRF
+ excepts early, otherwise fault happens downstream and is much more
+ difficult to debug.
+
+ """
+
+ def setUp(self):
+ """Monkeypatch get_model."""
+ self.get_model = rest_framework.utils.model_meta.models.get_model
+
+ def get_model(app_label, model_name):
+ return None
+
+ rest_framework.utils.model_meta.models.get_model = get_model
+
+ def tearDown(self):
+ """Revert monkeypatching."""
+ rest_framework.utils.model_meta.models.get_model = self.get_model
+
+ def test_blows_up_if_model_does_not_resolve(self):
+ with self.assertRaises(ImproperlyConfigured):
+ _resolve_model('tests.BasicModel')
diff --git a/tests/test_validation.py b/tests/test_validation.py
new file mode 100644
index 00000000..4234efd3
--- /dev/null
+++ b/tests/test_validation.py
@@ -0,0 +1,183 @@
+from __future__ import unicode_literals
+from django.core.validators import RegexValidator, MaxValueValidator
+from django.db import models
+from django.test import TestCase
+from rest_framework import generics, serializers, status
+from rest_framework.test import APIRequestFactory
+import re
+
+factory = APIRequestFactory()
+
+
+# Regression for #666
+
+class ValidationModel(models.Model):
+ blank_validated_field = models.CharField(max_length=255)
+
+
+class ValidationModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationModel
+ fields = ('blank_validated_field',)
+ read_only_fields = ('blank_validated_field',)
+
+
+class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ queryset = ValidationModel.objects.all()
+ serializer_class = ValidationModelSerializer
+
+
+# Regression for #653
+
+class ShouldValidateModel(models.Model):
+ should_validate_field = models.CharField(max_length=255)
+
+
+class ShouldValidateModelSerializer(serializers.ModelSerializer):
+ renamed = serializers.CharField(source='should_validate_field', required=False)
+
+ def validate_renamed(self, value):
+ if len(value) < 3:
+ raise serializers.ValidationError('Minimum 3 characters.')
+ return value
+
+ class Meta:
+ model = ShouldValidateModel
+ fields = ('renamed',)
+
+
+class TestPreSaveValidationExclusionsSerializer(TestCase):
+ def test_renamed_fields_are_model_validated(self):
+ """
+ Ensure fields with 'source' applied do get still get model validation.
+ """
+ # We've set `required=False` on the serializer, but the model
+ # does not have `blank=True`, so this serializer should not validate.
+ serializer = ShouldValidateModelSerializer(data={'renamed': ''})
+ self.assertEqual(serializer.is_valid(), False)
+ self.assertIn('renamed', serializer.errors)
+ self.assertNotIn('should_validate_field', serializer.errors)
+
+
+class TestCustomValidationMethods(TestCase):
+ def test_custom_validation_method_is_executed(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
+ self.assertFalse(serializer.is_valid())
+ self.assertIn('renamed', serializer.errors)
+
+ def test_custom_validation_method_passing(self):
+ serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
+ self.assertTrue(serializer.is_valid())
+
+
+class ValidationSerializer(serializers.Serializer):
+ foo = serializers.CharField()
+
+ def validate_foo(self, attrs, source):
+ raise serializers.ValidationError("foo invalid")
+
+ def validate(self, attrs):
+ raise serializers.ValidationError("serializer invalid")
+
+
+class TestAvoidValidation(TestCase):
+ """
+ If serializer was initialized with invalid data (None or non dict-like), it
+ should avoid validation layer (validate_<field> and validate methods)
+ """
+ def test_serializer_errors_has_only_invalid_data_error(self):
+ serializer = ValidationSerializer(data='invalid data')
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual(serializer.errors, {
+ 'non_field_errors': [
+ 'Invalid data. Expected a dictionary, but got %s.' % type('').__name__
+ ]
+ })
+
+
+# regression tests for issue: 1493
+
+class ValidationMaxValueValidatorModel(models.Model):
+ number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)])
+
+
+class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ValidationMaxValueValidatorModel
+
+
+class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
+ queryset = ValidationMaxValueValidatorModel.objects.all()
+ serializer_class = ValidationMaxValueValidatorModelSerializer
+
+
+class TestMaxValueValidatorValidation(TestCase):
+
+ def test_max_value_validation_serializer_success(self):
+ serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99})
+ self.assertTrue(serializer.is_valid())
+
+ def test_max_value_validation_serializer_fails(self):
+ serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101})
+ self.assertFalse(serializer.is_valid())
+ self.assertDictEqual({'number_value': ['Ensure this value is less than or equal to 100.']}, serializer.errors)
+
+ def test_max_value_validation_success(self):
+ obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
+ request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json')
+ view = UpdateMaxValueValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+
+ def test_max_value_validation_fail(self):
+ obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
+ request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json')
+ view = UpdateMaxValueValidationModel().as_view()
+ response = view(request, pk=obj.pk).render()
+ self.assertEqual(response.content, b'{"number_value":["Ensure this value is less than or equal to 100."]}')
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+
+
+class TestChoiceFieldChoicesValidate(TestCase):
+ CHOICES = [
+ (0, 'Small'),
+ (1, 'Medium'),
+ (2, 'Large'),
+ ]
+
+ CHOICES_NESTED = [
+ ('Category', (
+ (1, 'First'),
+ (2, 'Second'),
+ (3, 'Third'),
+ )),
+ (4, 'Fourth'),
+ ]
+
+ def test_choices(self):
+ """
+ Make sure a value for choices works as expected.
+ """
+ f = serializers.ChoiceField(choices=self.CHOICES)
+ value = self.CHOICES[0][0]
+ try:
+ f.to_internal_value(value)
+ except serializers.ValidationError:
+ self.fail("Value %s does not validate" % str(value))
+
+
+class RegexSerializer(serializers.Serializer):
+ pin = serializers.CharField(
+ validators=[RegexValidator(regex=re.compile('^[0-9]{4,6}$'),
+ message='A PIN is 4-6 digits')])
+
+expected_repr = """
+RegexSerializer():
+ pin = CharField(validators=[<django.core.validators.RegexValidator object>])
+""".strip()
+
+
+class TestRegexSerializer(TestCase):
+ def test_regex_repr(self):
+ serializer_repr = repr(RegexSerializer())
+ assert serializer_repr == expected_repr
diff --git a/tests/test_validators.py b/tests/test_validators.py
new file mode 100644
index 00000000..127ec6f8
--- /dev/null
+++ b/tests/test_validators.py
@@ -0,0 +1,347 @@
+from django.db import models
+from django.test import TestCase
+from rest_framework import serializers
+import datetime
+
+
+def dedent(blocktext):
+ return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
+
+
+# Tests for `UniqueValidator`
+# ---------------------------
+
+class UniquenessModel(models.Model):
+ username = models.CharField(unique=True, max_length=100)
+
+
+class UniquenessSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = UniquenessModel
+
+
+class AnotherUniquenessModel(models.Model):
+ code = models.IntegerField(unique=True)
+
+
+class AnotherUniquenessSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = AnotherUniquenessModel
+
+
+class TestUniquenessValidation(TestCase):
+ def setUp(self):
+ self.instance = UniquenessModel.objects.create(username='existing')
+
+ def test_repr(self):
+ serializer = UniquenessSerializer()
+ expected = dedent("""
+ UniquenessSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ username = CharField(max_length=100, validators=[<UniqueValidator(queryset=UniquenessModel.objects.all())>])
+ """)
+ assert repr(serializer) == expected
+
+ def test_is_not_unique(self):
+ data = {'username': 'existing'}
+ serializer = UniquenessSerializer(data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {'username': ['This field must be unique.']}
+
+ def test_is_unique(self):
+ data = {'username': 'other'}
+ serializer = UniquenessSerializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'username': 'other'}
+
+ def test_updated_instance_excluded(self):
+ data = {'username': 'existing'}
+ serializer = UniquenessSerializer(self.instance, data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {'username': 'existing'}
+
+ def test_doesnt_pollute_model(self):
+ instance = AnotherUniquenessModel.objects.create(code='100')
+ serializer = AnotherUniquenessSerializer(instance)
+ self.assertEqual(
+ AnotherUniquenessModel._meta.get_field('code').validators, [])
+
+ # Accessing data shouldn't effect validators on the model
+ serializer.data
+ self.assertEqual(
+ AnotherUniquenessModel._meta.get_field('code').validators, [])
+
+
+# Tests for `UniqueTogetherValidator`
+# -----------------------------------
+
+class UniquenessTogetherModel(models.Model):
+ race_name = models.CharField(max_length=100)
+ position = models.IntegerField()
+
+ class Meta:
+ unique_together = ('race_name', 'position')
+
+
+class NullUniquenessTogetherModel(models.Model):
+ """
+ Used to ensure that null values are not included when checking
+ unique_together constraints.
+
+ Ignoring items which have a null in any of the validated fields is the same
+ behavior that database backends will use when they have the
+ unique_together constraint added.
+
+ Example case: a null position could indicate a non-finisher in the race,
+ there could be many non-finishers in a race, but all non-NULL
+ values *should* be unique against the given `race_name`.
+ """
+ date_of_birth = models.DateField(null=True) # Not part of the uniqueness constraint
+ race_name = models.CharField(max_length=100)
+ position = models.IntegerField(null=True)
+
+ class Meta:
+ unique_together = ('race_name', 'position')
+
+
+class UniquenessTogetherSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = UniquenessTogetherModel
+
+
+class NullUniquenessTogetherSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullUniquenessTogetherModel
+
+
+class TestUniquenessTogetherValidation(TestCase):
+ def setUp(self):
+ self.instance = UniquenessTogetherModel.objects.create(
+ race_name='example',
+ position=1
+ )
+ UniquenessTogetherModel.objects.create(
+ race_name='example',
+ position=2
+ )
+ UniquenessTogetherModel.objects.create(
+ race_name='other',
+ position=1
+ )
+
+ def test_repr(self):
+ serializer = UniquenessTogetherSerializer()
+ expected = dedent("""
+ UniquenessTogetherSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ race_name = CharField(max_length=100, required=True)
+ position = IntegerField(required=True)
+ class Meta:
+ validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('race_name', 'position'))>]
+ """)
+ assert repr(serializer) == expected
+
+ def test_is_not_unique_together(self):
+ """
+ Failing unique together validation should result in non field errors.
+ """
+ data = {'race_name': 'example', 'position': 2}
+ serializer = UniquenessTogetherSerializer(data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {
+ 'non_field_errors': [
+ 'The fields race_name, position must make a unique set.'
+ ]
+ }
+
+ def test_is_unique_together(self):
+ """
+ In a unique together validation, one field may be non-unique
+ so long as the set as a whole is unique.
+ """
+ data = {'race_name': 'other', 'position': 2}
+ serializer = UniquenessTogetherSerializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'race_name': 'other',
+ 'position': 2
+ }
+
+ def test_updated_instance_excluded_from_unique_together(self):
+ """
+ When performing an update, the existing instance does not count
+ as a match against uniqueness.
+ """
+ data = {'race_name': 'example', 'position': 1}
+ serializer = UniquenessTogetherSerializer(self.instance, data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'race_name': 'example',
+ 'position': 1
+ }
+
+ def test_unique_together_is_required(self):
+ """
+ In a unique together validation, all fields are required.
+ """
+ data = {'position': 2}
+ serializer = UniquenessTogetherSerializer(data=data, partial=True)
+ assert not serializer.is_valid()
+ assert serializer.errors == {
+ 'race_name': ['This field is required.']
+ }
+
+ def test_ignore_excluded_fields(self):
+ """
+ When model fields are not included in a serializer, then uniqueness
+ validators should not be added for that field.
+ """
+ class ExcludedFieldSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = UniquenessTogetherModel
+ fields = ('id', 'race_name',)
+ serializer = ExcludedFieldSerializer()
+ expected = dedent("""
+ ExcludedFieldSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ race_name = CharField(max_length=100)
+ """)
+ assert repr(serializer) == expected
+
+ def test_ignore_validation_for_null_fields(self):
+ # None values that are on fields which are part of the uniqueness
+ # constraint cause the instance to ignore uniqueness validation.
+ NullUniquenessTogetherModel.objects.create(
+ date_of_birth=datetime.date(2000, 1, 1),
+ race_name='Paris Marathon',
+ position=None
+ )
+ data = {
+ 'date': datetime.date(2000, 1, 1),
+ 'race_name': 'Paris Marathon',
+ 'position': None
+ }
+ serializer = NullUniquenessTogetherSerializer(data=data)
+ assert serializer.is_valid()
+
+ def test_do_not_ignore_validation_for_null_fields(self):
+ # None values that are not on fields part of the uniqueness constraint
+ # do not cause the instance to skip validation.
+ NullUniquenessTogetherModel.objects.create(
+ date_of_birth=datetime.date(2000, 1, 1),
+ race_name='Paris Marathon',
+ position=1
+ )
+ data = {'date': None, 'race_name': 'Paris Marathon', 'position': 1}
+ serializer = NullUniquenessTogetherSerializer(data=data)
+ assert not serializer.is_valid()
+
+
+# Tests for `UniqueForDateValidator`
+# ----------------------------------
+
+class UniqueForDateModel(models.Model):
+ slug = models.CharField(max_length=100, unique_for_date='published')
+ published = models.DateField()
+
+
+class UniqueForDateSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = UniqueForDateModel
+
+
+class TestUniquenessForDateValidation(TestCase):
+ def setUp(self):
+ self.instance = UniqueForDateModel.objects.create(
+ slug='existing',
+ published='2000-01-01'
+ )
+
+ def test_repr(self):
+ serializer = UniqueForDateSerializer()
+ expected = dedent("""
+ UniqueForDateSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ slug = CharField(max_length=100)
+ published = DateField(required=True)
+ class Meta:
+ validators = [<UniqueForDateValidator(queryset=UniqueForDateModel.objects.all(), field='slug', date_field='published')>]
+ """)
+ assert repr(serializer) == expected
+
+ def test_is_not_unique_for_date(self):
+ """
+ Failing unique for date validation should result in field error.
+ """
+ data = {'slug': 'existing', 'published': '2000-01-01'}
+ serializer = UniqueForDateSerializer(data=data)
+ assert not serializer.is_valid()
+ assert serializer.errors == {
+ 'slug': ['This field must be unique for the "published" date.']
+ }
+
+ def test_is_unique_for_date(self):
+ """
+ Passing unique for date validation.
+ """
+ data = {'slug': 'existing', 'published': '2000-01-02'}
+ serializer = UniqueForDateSerializer(data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'slug': 'existing',
+ 'published': datetime.date(2000, 1, 2)
+ }
+
+ def test_updated_instance_excluded_from_unique_for_date(self):
+ """
+ When performing an update, the existing instance does not count
+ as a match against unique_for_date.
+ """
+ data = {'slug': 'existing', 'published': '2000-01-01'}
+ serializer = UniqueForDateSerializer(instance=self.instance, data=data)
+ assert serializer.is_valid()
+ assert serializer.validated_data == {
+ 'slug': 'existing',
+ 'published': datetime.date(2000, 1, 1)
+ }
+
+
+class HiddenFieldUniqueForDateModel(models.Model):
+ slug = models.CharField(max_length=100, unique_for_date='published')
+ published = models.DateTimeField(auto_now_add=True)
+
+
+class TestHiddenFieldUniquenessForDateValidation(TestCase):
+ def test_repr_date_field_not_included(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = HiddenFieldUniqueForDateModel
+ fields = ('id', 'slug')
+
+ serializer = TestSerializer()
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ slug = CharField(max_length=100)
+ published = HiddenField(default=CreateOnlyDefault(<function now>))
+ class Meta:
+ validators = [<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>]
+ """)
+ assert repr(serializer) == expected
+
+ def test_repr_date_field_included(self):
+ class TestSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = HiddenFieldUniqueForDateModel
+ fields = ('id', 'slug', 'published')
+
+ serializer = TestSerializer()
+ expected = dedent("""
+ TestSerializer():
+ id = IntegerField(label='ID', read_only=True)
+ slug = CharField(max_length=100)
+ published = DateTimeField(default=CreateOnlyDefault(<function now>), read_only=True)
+ class Meta:
+ validators = [<UniqueForDateValidator(queryset=HiddenFieldUniqueForDateModel.objects.all(), field='slug', date_field='published')>]
+ """)
+ assert repr(serializer) == expected
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
new file mode 100644
index 00000000..90ad8afd
--- /dev/null
+++ b/tests/test_versioning.py
@@ -0,0 +1,264 @@
+from .utils import UsingURLPatterns
+from django.conf.urls import include, url
+from rest_framework import serializers
+from rest_framework import status, versioning
+from rest_framework.decorators import APIView
+from rest_framework.response import Response
+from rest_framework.reverse import reverse
+from rest_framework.test import APIRequestFactory, APITestCase
+from rest_framework.versioning import NamespaceVersioning
+import pytest
+
+
+class RequestVersionView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'version': request.version})
+
+
+class ReverseView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'url': reverse('another', request=request)})
+
+
+class RequestInvalidVersionView(APIView):
+ def determine_version(self, request, *args, **kwargs):
+ scheme = self.versioning_class()
+ scheme.allowed_versions = ('v1', 'v2')
+ return (scheme.determine_version(request, *args, **kwargs), scheme)
+
+ def get(self, request, *args, **kwargs):
+ return Response({'version': request.version})
+
+
+factory = APIRequestFactory()
+
+
+def dummy_view(request):
+ pass
+
+
+def dummy_pk_view(request, pk):
+ pass
+
+
+class TestRequestVersion:
+ def test_unversioned(self):
+ view = RequestVersionView.as_view()
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_query_param_versioning(self):
+ scheme = versioning.QueryParameterVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/?version=1.2.3')
+ response = view(request)
+ assert response.data == {'version': '1.2.3'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_host_name_versioning(self):
+ scheme = versioning.HostNameVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_HOST='v1.example.org')
+ response = view(request)
+ assert response.data == {'version': 'v1'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_accept_header_versioning(self):
+ scheme = versioning.AcceptHeaderVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3')
+ response = view(request)
+ assert response.data == {'version': '1.2.3'}
+
+ request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_url_path_versioning(self):
+ scheme = versioning.URLPathVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/1.2.3/endpoint/')
+ response = view(request, version='1.2.3')
+ assert response.data == {'version': '1.2.3'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+ def test_namespace_versioning(self):
+ class FakeResolverMatch:
+ namespace = 'v1'
+
+ scheme = versioning.NamespaceVersioning
+ view = RequestVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v1/endpoint/')
+ request.resolver_match = FakeResolverMatch
+ response = view(request, version='v1')
+ assert response.data == {'version': 'v1'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'version': None}
+
+
+class TestURLReversing(UsingURLPatterns, APITestCase):
+ included = [
+ url(r'^namespaced/$', dummy_view, name='another'),
+ url(r'^example/(?P<pk>\d+)/$', dummy_pk_view, name='example-detail')
+ ]
+
+ urlpatterns = [
+ url(r'^v1/', include(included, namespace='v1')),
+ url(r'^another/$', dummy_view, name='another'),
+ url(r'^(?P<version>[^/]+)/another/$', dummy_view, name='another'),
+ ]
+
+ def test_reverse_unversioned(self):
+ view = ReverseView.as_view()
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_query_param_versioning(self):
+ scheme = versioning.QueryParameterVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/?version=v1')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/?version=v1'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_host_name_versioning(self):
+ scheme = versioning.HostNameVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_HOST='v1.example.org')
+ response = view(request)
+ assert response.data == {'url': 'http://v1.example.org/another/'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_url_path_versioning(self):
+ scheme = versioning.URLPathVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v1/endpoint/')
+ response = view(request, version='v1')
+ assert response.data == {'url': 'http://testserver/v1/another/'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+ def test_reverse_namespace_versioning(self):
+ class FakeResolverMatch:
+ namespace = 'v1'
+
+ scheme = versioning.NamespaceVersioning
+ view = ReverseView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v1/endpoint/')
+ request.resolver_match = FakeResolverMatch
+ response = view(request, version='v1')
+ assert response.data == {'url': 'http://testserver/v1/namespaced/'}
+
+ request = factory.get('/endpoint/')
+ response = view(request)
+ assert response.data == {'url': 'http://testserver/another/'}
+
+
+class TestInvalidVersion:
+ def test_invalid_query_param_versioning(self):
+ scheme = versioning.QueryParameterVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/?version=v3')
+ response = view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_invalid_host_name_versioning(self):
+ scheme = versioning.HostNameVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_HOST='v3.example.org')
+ response = view(request)
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_invalid_accept_header_versioning(self):
+ scheme = versioning.AcceptHeaderVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3')
+ response = view(request)
+ assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
+
+ def test_invalid_url_path_versioning(self):
+ scheme = versioning.URLPathVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v3/endpoint/')
+ response = view(request, version='v3')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+ def test_invalid_namespace_versioning(self):
+ class FakeResolverMatch:
+ namespace = 'v3'
+
+ scheme = versioning.NamespaceVersioning
+ view = RequestInvalidVersionView.as_view(versioning_class=scheme)
+
+ request = factory.get('/v3/endpoint/')
+ request.resolver_match = FakeResolverMatch
+ response = view(request, version='v3')
+ assert response.status_code == status.HTTP_404_NOT_FOUND
+
+
+class TestHyperlinkedRelatedField(UsingURLPatterns, APITestCase):
+ included = [
+ url(r'^namespaced/(?P<pk>\d+)/$', dummy_view, name='namespaced'),
+ ]
+
+ urlpatterns = [
+ url(r'^v1/', include(included, namespace='v1')),
+ url(r'^v2/', include(included, namespace='v2'))
+ ]
+
+ def setUp(self):
+ super(TestHyperlinkedRelatedField, self).setUp()
+
+ class MockQueryset(object):
+ def get(self, pk):
+ return 'object %s' % pk
+
+ self.field = serializers.HyperlinkedRelatedField(
+ view_name='namespaced',
+ queryset=MockQueryset()
+ )
+ request = factory.get('/')
+ request.versioning_scheme = NamespaceVersioning()
+ request.version = 'v1'
+ self.field._context = {'request': request}
+
+ def test_bug_2489(self):
+ assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3'
+ with pytest.raises(serializers.ValidationError):
+ self.field.to_internal_value('/v2/namespaced/3/')
diff --git a/tests/test_views.py b/tests/test_views.py
new file mode 100644
index 00000000..77b113ee
--- /dev/null
+++ b/tests/test_views.py
@@ -0,0 +1,148 @@
+from __future__ import unicode_literals
+
+import sys
+import copy
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.decorators import api_view
+from rest_framework.response import Response
+from rest_framework.settings import api_settings
+from rest_framework.test import APIRequestFactory
+from rest_framework.views import APIView
+
+factory = APIRequestFactory()
+
+if sys.version_info[:2] >= (3, 4):
+ JSON_ERROR = 'JSON parse error - Expecting value:'
+else:
+ JSON_ERROR = 'JSON parse error - No JSON object could be decoded'
+
+
+class BasicView(APIView):
+ def get(self, request, *args, **kwargs):
+ return Response({'method': 'GET'})
+
+ def post(self, request, *args, **kwargs):
+ return Response({'method': 'POST', 'data': request.DATA})
+
+
+@api_view(['GET', 'POST', 'PUT', 'PATCH'])
+def basic_view(request):
+ if request.method == 'GET':
+ return {'method': 'GET'}
+ elif request.method == 'POST':
+ return {'method': 'POST', 'data': request.DATA}
+ elif request.method == 'PUT':
+ return {'method': 'PUT', 'data': request.DATA}
+ elif request.method == 'PATCH':
+ return {'method': 'PATCH', 'data': request.DATA}
+
+
+class ErrorView(APIView):
+ def get(self, request, *args, **kwargs):
+ raise Exception
+
+
+@api_view(['GET'])
+def error_view(request):
+ raise Exception
+
+
+def sanitise_json_error(error_dict):
+ """
+ Exact contents of JSON error messages depend on the installed version
+ of json.
+ """
+ ret = copy.copy(error_dict)
+ chop = len(JSON_ERROR)
+ ret['detail'] = ret['detail'][:chop]
+ return ret
+
+
+class ClassBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = BasicView.as_view()
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': JSON_ERROR
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': JSON_ERROR
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class FunctionBasedViewIntegrationTests(TestCase):
+ def setUp(self):
+ self.view = basic_view
+
+ def test_400_parse_error(self):
+ request = factory.post('/', 'f00bar', content_type='application/json')
+ response = self.view(request)
+ expected = {
+ 'detail': JSON_ERROR
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+ def test_400_parse_error_tunneled_content(self):
+ content = 'f00bar'
+ content_type = 'application/json'
+ form_data = {
+ api_settings.FORM_CONTENT_OVERRIDE: content,
+ api_settings.FORM_CONTENTTYPE_OVERRIDE: content_type
+ }
+ request = factory.post('/', form_data)
+ response = self.view(request)
+ expected = {
+ 'detail': JSON_ERROR
+ }
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(sanitise_json_error(response.data), expected)
+
+
+class TestCustomExceptionHandler(TestCase):
+ def setUp(self):
+ self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
+
+ def exception_handler(exc):
+ return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
+
+ api_settings.EXCEPTION_HANDLER = exception_handler
+
+ def tearDown(self):
+ api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
+
+ def test_class_based_view_exception_handler(self):
+ view = ErrorView.as_view()
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
+
+ def test_function_based_view_exception_handler(self):
+ view = error_view
+
+ request = factory.get('/', content_type='application/json')
+ response = view(request)
+ expected = 'Error!'
+ self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
+ self.assertEqual(response.data, expected)
diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py
new file mode 100644
index 00000000..4d18a955
--- /dev/null
+++ b/tests/test_viewsets.py
@@ -0,0 +1,35 @@
+from django.test import TestCase
+from rest_framework import status
+from rest_framework.response import Response
+from rest_framework.test import APIRequestFactory
+from rest_framework.viewsets import GenericViewSet
+
+
+factory = APIRequestFactory()
+
+
+class BasicViewSet(GenericViewSet):
+ def list(self, request, *args, **kwargs):
+ return Response({'ACTION': 'LIST'})
+
+
+class InitializeViewSetsTestCase(TestCase):
+ def test_initialize_view_set_with_actions(self):
+ request = factory.get('/', '', content_type='application/json')
+ my_view = BasicViewSet.as_view(actions={
+ 'get': 'list',
+ })
+
+ response = my_view(request)
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertEqual(response.data, {'ACTION': 'LIST'})
+
+ def test_initialize_view_set_with_empty_actions(self):
+ try:
+ BasicViewSet.as_view()
+ except TypeError as e:
+ self.assertEqual(str(e), "The `actions` argument must be provided "
+ "when calling `.as_view()` on a ViewSet. "
+ "For example `.as_view({'get': 'list'})`")
+ else:
+ self.fail("actions must not be empty.")
diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py
new file mode 100644
index 00000000..dd3bbd6e
--- /dev/null
+++ b/tests/test_write_only_fields.py
@@ -0,0 +1,31 @@
+from django.test import TestCase
+from rest_framework import serializers
+
+
+class WriteOnlyFieldTests(TestCase):
+ def setUp(self):
+ class ExampleSerializer(serializers.Serializer):
+ email = serializers.EmailField()
+ password = serializers.CharField(write_only=True)
+
+ def create(self, attrs):
+ return attrs
+
+ self.Serializer = ExampleSerializer
+
+ def write_only_fields_are_present_on_input(self):
+ data = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = self.Serializer(data=data)
+ self.assertTrue(serializer.is_valid())
+ self.assertEquals(serializer.validated_data, data)
+
+ def write_only_fields_are_not_present_on_output(self):
+ instance = {
+ 'email': 'foo@example.com',
+ 'password': '123'
+ }
+ serializer = self.Serializer(instance)
+ self.assertEquals(serializer.data, {'email': 'foo@example.com'})
diff --git a/tests/urls.py b/tests/urls.py
new file mode 100644
index 00000000..41f527df
--- /dev/null
+++ b/tests/urls.py
@@ -0,0 +1,6 @@
+"""
+Blank URLConf just to keep the test suite happy
+"""
+from django.conf.urls import patterns
+
+urlpatterns = patterns('')
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..b9034996
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,77 @@
+from django.core.exceptions import ObjectDoesNotExist
+from django.core.urlresolvers import NoReverseMatch
+
+
+class UsingURLPatterns(object):
+ """
+ Isolates URL patterns used during testing on the test class itself.
+ For example:
+
+ class MyTestCase(UsingURLPatterns, TestCase):
+ urlpatterns = [
+ ...
+ ]
+
+ def test_something(self):
+ ...
+ """
+ urls = __name__
+
+ def setUp(self):
+ global urlpatterns
+ urlpatterns = self.urlpatterns
+
+ def tearDown(self):
+ global urlpatterns
+ urlpatterns = []
+
+
+class MockObject(object):
+ def __init__(self, **kwargs):
+ self._kwargs = kwargs
+ for key, val in kwargs.items():
+ setattr(self, key, val)
+
+ def __str__(self):
+ kwargs_str = ', '.join([
+ '%s=%s' % (key, value)
+ for key, value in sorted(self._kwargs.items())
+ ])
+ return '<MockObject %s>' % kwargs_str
+
+
+class MockQueryset(object):
+ def __init__(self, iterable):
+ self.items = iterable
+
+ def get(self, **lookup):
+ for item in self.items:
+ if all([
+ getattr(item, key, None) == value
+ for key, value in lookup.items()
+ ]):
+ return item
+ raise ObjectDoesNotExist()
+
+
+class BadType(object):
+ """
+ When used as a lookup with a `MockQueryset`, these objects
+ will raise a `TypeError`, as occurs in Django when making
+ queryset lookups with an incorrect type for the lookup value.
+ """
+ def __eq__(self):
+ raise TypeError()
+
+
+def mock_reverse(view_name, args=None, kwargs=None, request=None, format=None):
+ args = args or []
+ kwargs = kwargs or {}
+ value = (args + list(kwargs.values()) + ['-'])[0]
+ prefix = 'http://example.org' if request else ''
+ suffix = ('.' + format) if (format is not None) else ''
+ return '%s/%s/%s%s/' % (prefix, view_name, value, suffix)
+
+
+def fail_reverse(view_name, args=None, kwargs=None, request=None, format=None):
+ raise NoReverseMatch()